Skip to content

Commit a18c307

Browse files
committed
add llama
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent e7677e9 commit a18c307

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

tests/functional_tests/llm_pretrain_and_kd/L2_TP_Output_Parity_Minified.sh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ KL_THRESHOLD="${KL_THRESHOLD:-1e-6}"
2323

2424
torchrun --nproc_per_node=2 --nnodes=1 \
2525
tests/functional_tests/llm_pretrain_and_kd/run_tp_output_parity_minified.py \
26-
--models qwen3 qwen3_seq_cls ministral3 \
26+
--models qwen3 qwen3_seq_cls ministral3 llama \
2727
--sequence_parallel both \
2828
--kl_threshold "${KL_THRESHOLD}"
2929

tests/functional_tests/llm_pretrain_and_kd/run_tp_output_parity_minified.py

Lines changed: 21 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -53,11 +53,13 @@
5353

5454
from nemo_automodel._transformers.utils import apply_cache_compatibility_patches
5555
from nemo_automodel.components.distributed.parallelizer import _get_parallel_plan
56+
from nemo_automodel.components.models.llama.model import LlamaForCausalLM as CustomLlamaForCausalLM
5657
from nemo_automodel.components.models.mistral3.model import Ministral3Config, Ministral3ForCausalLM
58+
from transformers import LlamaConfig
5759
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
5860
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM, Qwen3ForSequenceClassification
5961

60-
ModelKind = Literal["qwen3", "qwen3_seq_cls", "ministral3"]
62+
ModelKind = Literal["qwen3", "qwen3_seq_cls", "ministral3", "llama"]
6163
SPMode = Literal["true", "false", "both"]
6264

6365

@@ -205,6 +207,22 @@ def _build_minified_model(kind: ModelKind):
205207
)
206208
return cfg, Qwen3ForSequenceClassification(cfg)
207209

210+
if kind == "llama":
211+
cfg = LlamaConfig(
212+
vocab_size=128,
213+
hidden_size=64,
214+
intermediate_size=256,
215+
num_hidden_layers=2,
216+
num_attention_heads=4,
217+
num_key_value_heads=2,
218+
max_position_embeddings=128,
219+
use_cache=False,
220+
tie_word_embeddings=True,
221+
attention_bias=False,
222+
attn_implementation="eager",
223+
)
224+
return cfg, CustomLlamaForCausalLM(cfg)
225+
208226
raise ValueError(f"Unknown model kind: {kind}")
209227

210228

@@ -273,8 +291,8 @@ def main(argv: Sequence[str] | None = None) -> int:
273291
parser.add_argument(
274292
"--models",
275293
nargs="+",
276-
default=["qwen3", "qwen3_seq_cls", "ministral3"],
277-
choices=["qwen3", "qwen3_seq_cls", "ministral3"],
294+
default=["qwen3", "qwen3_seq_cls", "ministral3", "llama"],
295+
choices=["qwen3", "qwen3_seq_cls", "ministral3", "llama"],
278296
help="Which models to test.",
279297
)
280298
parser.add_argument(

0 commit comments

Comments
 (0)