|
53 | 53 |
|
54 | 54 | from nemo_automodel._transformers.utils import apply_cache_compatibility_patches |
55 | 55 | from nemo_automodel.components.distributed.parallelizer import _get_parallel_plan |
| 56 | +from nemo_automodel.components.models.llama.model import LlamaForCausalLM as CustomLlamaForCausalLM |
56 | 57 | from nemo_automodel.components.models.mistral3.model import Ministral3Config, Ministral3ForCausalLM |
| 58 | +from transformers import LlamaConfig |
57 | 59 | from transformers.models.qwen3.configuration_qwen3 import Qwen3Config |
58 | 60 | from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM, Qwen3ForSequenceClassification |
59 | 61 |
|
60 | | -ModelKind = Literal["qwen3", "qwen3_seq_cls", "ministral3"] |
| 62 | +ModelKind = Literal["qwen3", "qwen3_seq_cls", "ministral3", "llama"] |
61 | 63 | SPMode = Literal["true", "false", "both"] |
62 | 64 |
|
63 | 65 |
|
@@ -205,6 +207,22 @@ def _build_minified_model(kind: ModelKind): |
205 | 207 | ) |
206 | 208 | return cfg, Qwen3ForSequenceClassification(cfg) |
207 | 209 |
|
| 210 | + if kind == "llama": |
| 211 | + cfg = LlamaConfig( |
| 212 | + vocab_size=128, |
| 213 | + hidden_size=64, |
| 214 | + intermediate_size=256, |
| 215 | + num_hidden_layers=2, |
| 216 | + num_attention_heads=4, |
| 217 | + num_key_value_heads=2, |
| 218 | + max_position_embeddings=128, |
| 219 | + use_cache=False, |
| 220 | + tie_word_embeddings=True, |
| 221 | + attention_bias=False, |
| 222 | + attn_implementation="eager", |
| 223 | + ) |
| 224 | + return cfg, CustomLlamaForCausalLM(cfg) |
| 225 | + |
208 | 226 | raise ValueError(f"Unknown model kind: {kind}") |
209 | 227 |
|
210 | 228 |
|
@@ -273,8 +291,8 @@ def main(argv: Sequence[str] | None = None) -> int: |
273 | 291 | parser.add_argument( |
274 | 292 | "--models", |
275 | 293 | nargs="+", |
276 | | - default=["qwen3", "qwen3_seq_cls", "ministral3"], |
277 | | - choices=["qwen3", "qwen3_seq_cls", "ministral3"], |
| 294 | + default=["qwen3", "qwen3_seq_cls", "ministral3", "llama"], |
| 295 | + choices=["qwen3", "qwen3_seq_cls", "ministral3", "llama"], |
278 | 296 | help="Which models to test.", |
279 | 297 | ) |
280 | 298 | parser.add_argument( |
|
0 commit comments