Skip to content

Commit d86b960

Browse files
adil-a and thomasdhc authored
feat: nano-v3 custom model (#1091)
Signed-off-by: adil-a <adil.asif2000@hotmail.com> Signed-off-by: adil-a <adil-a@users.noreply.github.com> Signed-off-by: Dong Hyuk Chang <donghyukc@nvidia.com> Signed-off-by: thomasdhc <thomasdhc@users.noreply.github.com> Co-authored-by: adil-a <adil-a@users.noreply.github.com> Co-authored-by: Dong Hyuk Chang <donghyukc@nvidia.com> Co-authored-by: thomasdhc <thomasdhc@users.noreply.github.com>
1 parent 5fd6482 commit d86b960

File tree

26 files changed

+3259
-129
lines changed

26 files changed

+3259
-129
lines changed

.github/actions/build-container/action.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ runs:
130130
context: .
131131
build-args: |
132132
BASE_IMAGE=pytorch
133+
INSTALL_MAMBA=True
133134
cache-from: |
134135
type=registry,ref=${{ env.container-registry }}/${{ env.REPO_LOWER }}:${{ fromJSON(steps.get-pr-info.outputs.pr-info || '{}').number || 0 }}-buildcache,mode=max
135136
type=registry,ref=${{ env.container-registry }}/${{ env.REPO_LOWER }}:main-buildcache,mode=max

.github/actions/test-template/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,4 +254,4 @@ runs:
254254
path: |
255255
coverage.xml
256256
.coverage
257-
include-hidden-files: true
257+
include-hidden-files: true

docker/common/uv-pytorch.lock

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,17 @@ wheels = [
575575
{ url = "https://files.pythonhosted.org/packages/96/c5/1e741d26306c42e2bf6ab740b2202872727e0f606033c9dd713f8b93f5a8/cachetools-6.2.1-py3-none-any.whl", hash = "sha256:09868944b6dde876dfd44e1d47e18484541eaf12f26f29b7af91b26cc892d701", size = 11280, upload-time = "2025-10-12T14:55:28.382Z" },
576576
]
577577

578+
[[package]]
579+
name = "causal-conv1d"
580+
version = "1.6.0"
581+
source = { registry = "https://pypi.org/simple" }
582+
dependencies = [
583+
{ name = "ninja" },
584+
{ name = "packaging" },
585+
{ name = "torch", marker = "sys_platform == 'never'" },
586+
]
587+
sdist = { url = "https://files.pythonhosted.org/packages/db/df/63a384c49743b9fc8fec4c05dbd0b515e1c1c2b07e4559acc4fc37c69223/causal_conv1d-1.6.0.tar.gz", hash = "sha256:4eae3220d08e1e88238f3a0a88783147cbdf47f612cc610add75127c7a37ca3e", size = 29356, upload-time = "2026-01-12T17:33:32.794Z" }
588+
578589
[[package]]
579590
name = "certifi"
580591
version = "2025.10.5"
@@ -2388,6 +2399,21 @@ wheels = [
23882399
{ url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" },
23892400
]
23902401

2402+
[[package]]
2403+
name = "mamba-ssm"
2404+
version = "2.3.0"
2405+
source = { registry = "https://pypi.org/simple" }
2406+
dependencies = [
2407+
{ name = "einops" },
2408+
{ name = "ninja" },
2409+
{ name = "packaging" },
2410+
{ name = "setuptools" },
2411+
{ name = "torch", marker = "sys_platform == 'never'" },
2412+
{ name = "transformers" },
2413+
{ name = "triton", marker = "sys_platform == 'never'" },
2414+
]
2415+
sdist = { url = "https://files.pythonhosted.org/packages/54/69/a87f06d9dba78c041adb81f2228e978aab179477c64f1a210c0fe0d63e8d/mamba_ssm-2.3.0.tar.gz", hash = "sha256:8294e12125f76021e4e190f4137e84a84935920eeda5d0037a6917524456b303", size = 121116, upload-time = "2026-01-12T17:07:22.152Z" }
2416+
23912417
[[package]]
23922418
name = "markdown-it-py"
23932419
version = "3.0.0"
@@ -3068,7 +3094,9 @@ all = [
30683094
{ name = "torchcodec", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" },
30693095
]
30703096
cuda = [
3097+
{ name = "causal-conv1d" },
30713098
{ name = "flash-attn", marker = "sys_platform == 'never'" },
3099+
{ name = "mamba-ssm" },
30723100
{ name = "transformer-engine", extra = ["pytorch"] },
30733101
]
30743102
delta-databricks = [
@@ -3126,6 +3154,7 @@ test = [
31263154
requires-dist = [
31273155
{ name = "albumentations", marker = "extra == 'vlm'" },
31283156
{ name = "backoff", marker = "extra == 'vlm'" },
3157+
{ name = "causal-conv1d", marker = "extra == 'cuda'" },
31293158
{ name = "databricks-sql-connector", marker = "extra == 'delta-databricks'", specifier = ">=3.0.0" },
31303159
{ name = "datasets", specifier = ">=4.0.0" },
31313160
{ name = "deltalake", marker = "extra == 'delta-databricks'", specifier = ">=1.0.0" },
@@ -3135,6 +3164,7 @@ requires-dist = [
31353164
{ name = "ftfy" },
31363165
{ name = "imageio-ffmpeg" },
31373166
{ name = "liger-kernel", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'", specifier = ">=0.5.9" },
3167+
{ name = "mamba-ssm", marker = "extra == 'cuda'" },
31383168
{ name = "megatron-fsdp" },
31393169
{ name = "mistral-common", extras = ["opencv"], marker = "extra == 'vlm'" },
31403170
{ name = "mlflow" },
@@ -3216,6 +3246,32 @@ wheels = [
32163246
{ url = "https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl", hash = "sha256:0030d386a9a06dee3565298b4a734b68589749a544acbb6c412dc9e2489ec6ec", size = 2034406, upload-time = "2025-05-29T11:35:04.961Z" },
32173247
]
32183248

3249+
[[package]]
3250+
name = "ninja"
3251+
version = "1.13.0"
3252+
source = { registry = "https://pypi.org/simple" }
3253+
sdist = { url = "https://files.pythonhosted.org/packages/43/73/79a0b22fc731989c708068427579e840a6cf4e937fe7ae5c5d0b7356ac22/ninja-1.13.0.tar.gz", hash = "sha256:4a40ce995ded54d9dc24f8ea37ff3bf62ad192b547f6c7126e7e25045e76f978", size = 242558, upload-time = "2025-08-11T15:10:19.421Z" }
3254+
wheels = [
3255+
{ url = "https://files.pythonhosted.org/packages/3c/74/d02409ed2aa865e051b7edda22ad416a39d81a84980f544f8de717cab133/ninja-1.13.0-py3-none-macosx_10_9_universal2.whl", hash = "sha256:fa2a8bfc62e31b08f83127d1613d10821775a0eb334197154c4d6067b7068ff1", size = 310125, upload-time = "2025-08-11T15:09:50.971Z" },
3256+
{ url = "https://files.pythonhosted.org/packages/8e/de/6e1cd6b84b412ac1ef327b76f0641aeb5dcc01e9d3f9eee0286d0c34fd93/ninja-1.13.0-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3d00c692fb717fd511abeb44b8c5d00340c36938c12d6538ba989fe764e79630", size = 177467, upload-time = "2025-08-11T15:09:52.767Z" },
3257+
{ url = "https://files.pythonhosted.org/packages/c8/83/49320fb6e58ae3c079381e333575fdbcf1cca3506ee160a2dcce775046fa/ninja-1.13.0-py3-none-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:be7f478ff9f96a128b599a964fc60a6a87b9fa332ee1bd44fa243ac88d50291c", size = 187834, upload-time = "2025-08-11T15:09:54.115Z" },
3258+
{ url = "https://files.pythonhosted.org/packages/56/c7/ba22748fb59f7f896b609cd3e568d28a0a367a6d953c24c461fe04fc4433/ninja-1.13.0-py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:60056592cf495e9a6a4bea3cd178903056ecb0943e4de45a2ea825edb6dc8d3e", size = 202736, upload-time = "2025-08-11T15:09:55.745Z" },
3259+
{ url = "https://files.pythonhosted.org/packages/79/22/d1de07632b78ac8e6b785f41fa9aad7a978ec8c0a1bf15772def36d77aac/ninja-1.13.0-py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:1c97223cdda0417f414bf864cfb73b72d8777e57ebb279c5f6de368de0062988", size = 179034, upload-time = "2025-08-11T15:09:57.394Z" },
3260+
{ url = "https://files.pythonhosted.org/packages/ed/de/0e6edf44d6a04dabd0318a519125ed0415ce437ad5a1ec9b9be03d9048cf/ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fb46acf6b93b8dd0322adc3a4945452a4e774b75b91293bafcc7b7f8e6517dfa", size = 180716, upload-time = "2025-08-11T15:09:58.696Z" },
3261+
{ url = "https://files.pythonhosted.org/packages/54/28/938b562f9057aaa4d6bfbeaa05e81899a47aebb3ba6751e36c027a7f5ff7/ninja-1.13.0-py3-none-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:4be9c1b082d244b1ad7ef41eb8ab088aae8c109a9f3f0b3e56a252d3e00f42c1", size = 146843, upload-time = "2025-08-11T15:10:00.046Z" },
3262+
{ url = "https://files.pythonhosted.org/packages/2a/fb/d06a3838de4f8ab866e44ee52a797b5491df823901c54943b2adb0389fbb/ninja-1.13.0-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:6739d3352073341ad284246f81339a384eec091d9851a886dfa5b00a6d48b3e2", size = 154402, upload-time = "2025-08-11T15:10:01.657Z" },
3263+
{ url = "https://files.pythonhosted.org/packages/31/bf/0d7808af695ceddc763cf251b84a9892cd7f51622dc8b4c89d5012779f06/ninja-1.13.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:11be2d22027bde06f14c343f01d31446747dbb51e72d00decca2eb99be911e2f", size = 552388, upload-time = "2025-08-11T15:10:03.349Z" },
3264+
{ url = "https://files.pythonhosted.org/packages/9d/70/c99d0c2c809f992752453cce312848abb3b1607e56d4cd1b6cded317351a/ninja-1.13.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:aa45b4037b313c2f698bc13306239b8b93b4680eb47e287773156ac9e9304714", size = 472501, upload-time = "2025-08-11T15:10:04.735Z" },
3265+
{ url = "https://files.pythonhosted.org/packages/9f/43/c217b1153f0e499652f5e0766da8523ce3480f0a951039c7af115e224d55/ninja-1.13.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5f8e1e8a1a30835eeb51db05cf5a67151ad37542f5a4af2a438e9490915e5b72", size = 638280, upload-time = "2025-08-11T15:10:06.512Z" },
3266+
{ url = "https://files.pythonhosted.org/packages/8c/45/9151bba2c8d0ae2b6260f71696330590de5850e5574b7b5694dce6023e20/ninja-1.13.0-py3-none-musllinux_1_2_ppc64le.whl", hash = "sha256:3d7d7779d12cb20c6d054c61b702139fd23a7a964ec8f2c823f1ab1b084150db", size = 642420, upload-time = "2025-08-11T15:10:08.35Z" },
3267+
{ url = "https://files.pythonhosted.org/packages/3c/fb/95752eb635bb8ad27d101d71bef15bc63049de23f299e312878fc21cb2da/ninja-1.13.0-py3-none-musllinux_1_2_riscv64.whl", hash = "sha256:d741a5e6754e0bda767e3274a0f0deeef4807f1fec6c0d7921a0244018926ae5", size = 585106, upload-time = "2025-08-11T15:10:09.818Z" },
3268+
{ url = "https://files.pythonhosted.org/packages/c1/31/aa56a1a286703800c0cbe39fb4e82811c277772dc8cd084f442dd8e2938a/ninja-1.13.0-py3-none-musllinux_1_2_s390x.whl", hash = "sha256:e8bad11f8a00b64137e9b315b137d8bb6cbf3086fbdc43bf1f90fd33324d2e96", size = 707138, upload-time = "2025-08-11T15:10:11.366Z" },
3269+
{ url = "https://files.pythonhosted.org/packages/34/6f/5f5a54a1041af945130abdb2b8529cbef0cdcbbf9bcf3f4195378319d29a/ninja-1.13.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b4f2a072db3c0f944c32793e91532d8948d20d9ab83da9c0c7c15b5768072200", size = 581758, upload-time = "2025-08-11T15:10:13.295Z" },
3270+
{ url = "https://files.pythonhosted.org/packages/95/97/51359c77527d45943fe7a94d00a3843b81162e6c4244b3579fe8fc54cb9c/ninja-1.13.0-py3-none-win32.whl", hash = "sha256:8cfbb80b4a53456ae8a39f90ae3d7a2129f45ea164f43fadfa15dc38c4aef1c9", size = 267201, upload-time = "2025-08-11T15:10:15.158Z" },
3271+
{ url = "https://files.pythonhosted.org/packages/29/45/c0adfbfb0b5895aa18cec400c535b4f7ff3e52536e0403602fc1a23f7de9/ninja-1.13.0-py3-none-win_amd64.whl", hash = "sha256:fb8ee8719f8af47fed145cced4a85f0755dd55d45b2bddaf7431fa89803c5f3e", size = 309975, upload-time = "2025-08-11T15:10:16.697Z" },
3272+
{ url = "https://files.pythonhosted.org/packages/df/93/a7b983643d1253bb223234b5b226e69de6cda02b76cdca7770f684b795f5/ninja-1.13.0-py3-none-win_arm64.whl", hash = "sha256:3c0b40b1f0bba764644385319028650087b4c1b18cdfa6f45cb39a3669b81aa9", size = 290806, upload-time = "2025-08-11T15:10:18.018Z" },
3273+
]
3274+
32193275
[[package]]
32203276
name = "nodeenv"
32213277
version = "1.9.1"

examples/llm_finetune/nemotron/nemotron_nano_v3_squad.yaml renamed to examples/llm_finetune/nemotron/nemotron_nano_v3_hellaswag.yaml

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,34 +35,31 @@ rng:
3535
seed: 1111
3636
ranked: true
3737

38+
parallelizer:
39+
_target_: nemo_automodel.components.moe.parallelizer.parallelize_model
40+
activation_checkpointing: false
41+
3842
model:
3943
_target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
4044
pretrained_model_name_or_path: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
4145
trust_remote_code: true
4246

43-
# torch.compile configuration
44-
compile:
45-
enabled: false
46-
mode: "default" # Options: "default", "reduce-overhead", "max-autotune"
47-
fullgraph: false
48-
dynamic: true # Set to false for better performance with fixed shapes
49-
backend: null # Use default backend (inductor)
50-
5147
distributed:
5248
_target_: nemo_automodel.components.distributed.fsdp2.FSDP2Manager
5349
dp_size: none
5450
dp_replicate_size: 1 # dp_shard_size = dp_size / dp_replicate_size and dp_shard_size < dp_size. For DDP usecase, use DDPManager
5551
tp_size: 1
5652
cp_size: 1
53+
ep_size: 8
5754
sequence_parallel: false
5855
defer_fsdp_grad_sync: false
5956

6057
loss_fn:
6158
_target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy
6259

6360
dataset:
64-
_target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset
65-
dataset_name: rajpurkar/squad
61+
_target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag
62+
path_or_dataset: rowan/hellaswag
6663
split: train
6764

6865
packed_sequence:
@@ -74,10 +71,10 @@ dataloader:
7471
shuffle: True
7572

7673
validation_dataset:
77-
_target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset
78-
dataset_name: rajpurkar/squad
74+
_target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag
75+
path_or_dataset: rowan/hellaswag
7976
split: validation
80-
limit_dataset_samples: 64
77+
num_samples_limit: 64
8178

8279
validation_dataloader:
8380
_target_: torchdata.stateful_dataloader.StatefulDataLoader

examples/llm_finetune/nemotron/nemotron_nano_v3_squad_peft.yaml renamed to examples/llm_finetune/nemotron/nemotron_nano_v3_hellaswag_peft.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ loss_fn:
6767
_target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy
6868

6969
dataset:
70-
_target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset
71-
dataset_name: rajpurkar/squad
70+
_target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag
71+
path_or_dataset: rowan/hellaswag
7272
split: train
7373

7474
packed_sequence:
@@ -80,10 +80,10 @@ dataloader:
8080
shuffle: True
8181

8282
validation_dataset:
83-
_target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset
84-
dataset_name: rajpurkar/squad
83+
_target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag
84+
path_or_dataset: rowan/hellaswag
8585
split: validation
86-
limit_dataset_samples: 64
86+
num_samples_limit: 64
8787

8888
validation_dataloader:
8989
_target_: torchdata.stateful_dataloader.StatefulDataLoader

nemo_automodel/_transformers/auto_model.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,30 @@ def wrapper(self, *args, **kwargs):
131131
return obj
132132

133133

134+
def _is_config_compatible_with_custom_model(arch_name: str, config) -> bool:
135+
"""
136+
Check if a HuggingFace config is compatible with our custom model implementation.
137+
138+
Some architectures (e.g., NemotronHForCausalLM) are shared between different model versions
139+
(v2 vs v3) but our custom implementation only supports specific versions. This function
140+
validates that the config has the required attributes for the custom implementation.
141+
142+
Args:
143+
arch_name: The architecture name (e.g., "NemotronHForCausalLM")
144+
config: The HuggingFace config object
145+
146+
Returns:
147+
True if the config is compatible with our custom implementation, False otherwise
148+
"""
149+
# NemotronHForCausalLM: Our custom implementation is for v3 (MoE model)
150+
# v3 requires n_routed_experts, v2 does not have this attribute
151+
if arch_name == "NemotronHForCausalLM":
152+
return hasattr(config, "n_routed_experts") and config.n_routed_experts is not None
153+
154+
# All other architectures are assumed compatible
155+
return True
156+
157+
134158
def _patch_liger_kernel(model):
135159
"""
136160
Patches a model with liger-kernel and sdpa_kernel
@@ -503,7 +527,12 @@ def _retry(**override):
503527
)
504528
architectures = get_architectures(hf_config)
505529
# 2. If we have a custom model implementation available, we prioritize that over HF
506-
if len(architectures) > 0 and architectures[0] in ModelRegistry.model_arch_name_to_cls:
530+
arch_name = architectures[0] if len(architectures) > 0 else None
531+
if (
532+
arch_name is not None
533+
and arch_name in ModelRegistry.model_arch_name_to_cls
534+
and _is_config_compatible_with_custom_model(arch_name, hf_config)
535+
):
507536
# if we are able to init the custom model, we will now download the model weights on local rank 0
508537
_download_model_weights(hf_config, pretrained_model_name_or_path)
509538
logger.info(f"Using custom model implementation for {architectures[0]}")
@@ -673,7 +702,12 @@ def _retry(**override):
673702

674703
# 2. If we have a custom model implementation available, we prioritize that over HF
675704
architectures = get_architectures(config)
676-
if len(architectures) > 0 and architectures[0] in ModelRegistry.model_arch_name_to_cls:
705+
arch_name = architectures[0] if len(architectures) > 0 else None
706+
if (
707+
arch_name is not None
708+
and arch_name in ModelRegistry.model_arch_name_to_cls
709+
and _is_config_compatible_with_custom_model(arch_name, config)
710+
):
677711
model_cls = ModelRegistry.model_arch_name_to_cls[architectures[0]]
678712
init_param_names = _get_init_param_names(model_cls)
679713
_consume_config_overrides(config, kwargs, init_param_names=init_param_names)

nemo_automodel/components/checkpoint/checkpointing.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -358,15 +358,19 @@ def load_base_model(
358358
# But because we initialize on meta device, these are erroneously set to True.
359359
# We need to set them to False and call initialize_weights to re-initialize the weights.
360360

361-
# Gemma3ForConditionalGeneration cannot be pretrained currently. The pinned torch version
362-
# doesn't support initialize_weights when the model is sharded. This is because Gemma's
363-
# initialize_weights method requires setting a row to zeros in the embedding matrix.
364-
# This index selection op is not supported for DTensors in the pinned torch version.
361+
# Some models cannot call initialize_weights when sharded with DTensors:
362+
# - Gemma3ForConditionalGeneration: requires setting a row to zeros in the embedding matrix,
363+
# which is not supported for DTensors in the pinned torch version.
364+
# - NemotronHForCausalLM: the HF remote code's _init_weights uses dt_bias.copy_()
365+
# which fails with DTensors. Note: v3 (MoE) has n_routed_experts and uses our custom
366+
# implementation which handles this correctly.
365367
try:
366368
model_class = model.config.architectures[0]
367369
except:
368370
model_class = ""
369-
if model_class not in ["Gemma3ForConditionalGeneration", "NemotronHForCausalLM"]:
371+
is_nemotron_v2 = model_class == "NemotronHForCausalLM" and not getattr(model.config, "n_routed_experts", None)
372+
skip_initialize_weights = model_class in ["Gemma3ForConditionalGeneration"] or is_nemotron_v2
373+
if not skip_initialize_weights:
370374
for _, module in model.named_modules():
371375
if hasattr(module, "_is_hf_initialized"):
372376
module._is_hf_initialized = False
@@ -651,9 +655,15 @@ def _get_original_model_path(self, model_state: ModelState) -> str | None:
651655
"""
652656
Get the path to the original model from the Hugging Face checkpoint.
653657
"""
654-
if not hasattr(model_state.model[0], "name_or_path"):
658+
if not hasattr(model_state.model[0], "name_or_path") and not hasattr(
659+
getattr(model_state.model[0], "config", None), "name_or_path"
660+
):
655661
return None
656-
pretrained_model_name_or_path = getattr(model_state.model[0], "name_or_path")
662+
pretrained_model_name_or_path = getattr(model_state.model[0], "name_or_path", None) or getattr(
663+
getattr(model_state.model[0], "config", None), "name_or_path", None
664+
)
665+
if os.path.isdir(pretrained_model_name_or_path):
666+
return pretrained_model_name_or_path
657667
return get_safetensors_index_path(
658668
getattr(self.config, "original_model_root_dir", None) or TRANSFORMERS_CACHE, pretrained_model_name_or_path
659669
)

nemo_automodel/components/models/common/te_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ def __init__(self, normalized_shape, eps, device, params_dtype):
8181
object.__setattr__(self, "te_norm", te_norm)
8282
object.__setattr__(self, "torch_norm", torch_norm)
8383

84+
def reset_parameters(self) -> None:
85+
"""Reset parameters by delegating to the underlying torch norm."""
86+
torch_norm = object.__getattribute__(self, "torch_norm")
87+
torch_norm.reset_parameters()
88+
8489
def forward(self, x: torch.Tensor) -> torch.Tensor:
8590
if is_tensor_unallocated(x):
8691
# Shape inference only - return empty tensor with same shape
@@ -113,6 +118,8 @@ def __init__(self, in_features, out_features, bias, device, params_dtype):
113118
if bias:
114119
self.bias = te_linear.bias
115120
torch_linear.bias = self.bias
121+
else:
122+
self.bias = None
116123

117124
# Use object.__setattr__ to prevent submodules from being registered in self._modules.
118125
object.__setattr__(self, "te_linear", te_linear)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

0 commit comments

Comments (0)