
Commit 33e86f6 (1 parent: 60eaacd)

Address PR review feedback

- Add shebang and NVIDIA copyright headers to gemma3_vl shell scripts
- Fix grammar in scripts/training/README.md
- Refactor run_recipe.py: use direct imports instead of importlib for step functions
- Add error message constants and return type hints
- Fix load_recipe to handle recipes without a peft argument via signature inspection
- Fix qwen3_vl.py: use _dataset_choice consistently for dataset selection logic

File tree

7 files changed (+108, -26 lines)

examples/models/vlm/gemma3_vl/conversion.sh
Lines changed: 15 additions & 0 deletions

@@ -1,3 +1,18 @@
+#!/usr/bin/env bash
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # Workspace directory for checkpoints and results
 WORKSPACE=${WORKSPACE:-/workspace}

examples/models/vlm/gemma3_vl/inference.sh
Lines changed: 15 additions & 0 deletions

@@ -1,3 +1,18 @@
+#!/usr/bin/env bash
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # Workspace directory for checkpoints and results
 WORKSPACE=${WORKSPACE:-/workspace}

examples/models/vlm/gemma3_vl/peft.sh
Lines changed: 15 additions & 0 deletions

@@ -1,3 +1,18 @@
+#!/usr/bin/env bash
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # Workspace directory for checkpoints and results
 WORKSPACE=${WORKSPACE:-/workspace}

examples/models/vlm/gemma3_vl/sft.sh
Lines changed: 15 additions & 0 deletions

@@ -1,3 +1,18 @@
+#!/usr/bin/env bash
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # Workspace directory for checkpoints and results
 WORKSPACE=${WORKSPACE:-/workspace}

scripts/training/README.md
Lines changed: 1 addition & 1 deletion

@@ -251,4 +251,4 @@ Generic scripts call recipes with no arguments passed to the recipe function.
 
 All customization happens through CLI overrides after the config is built.
 
-If you need to pass arguments to the recipe constructor itself (e.g., custom parallelism at recipe build time), use model-specific examples, create a custom script.
+If you need to pass arguments to the recipe constructor itself (e.g., custom parallelism at recipe build time), use model-specific examples or create a custom script.
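
For context, a minimal sketch contrasting the two approaches the README distinguishes. The recipe name "qwen3_vl_finetune" and the "lora" value are hypothetical placeholders; only the recipes module lookup and the peft keyword are confirmed by the run_recipe.py changes in this commit.

import megatron.bridge.recipes as recipes

# Generic-script path: build the config with recipe defaults; customize afterwards via CLI overrides.
cfg_default = getattr(recipes, "qwen3_vl_finetune")()

# Custom-script path: pass constructor arguments (e.g., a PEFT scheme) at recipe build time.
cfg_lora = getattr(recipes, "qwen3_vl_finetune")(peft="lora")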

scripts/training/run_recipe.py
Lines changed: 41 additions & 19 deletions

@@ -48,26 +48,37 @@
 """
 
 import argparse
-import importlib
+import inspect
+from typing import Callable
 
 import megatron.bridge.recipes as recipes
 from megatron.bridge.training.config import ConfigContainer
 from megatron.bridge.training.finetune import finetune
+from megatron.bridge.training.gpt_step import forward_step as gpt_forward_step
+from megatron.bridge.training.llava_step import forward_step as llava_forward_step
 from megatron.bridge.training.pretrain import pretrain
 from megatron.bridge.training.utils.omegaconf_utils import process_config_with_overrides
+from megatron.bridge.training.vlm_step import forward_step as vlm_forward_step
 
 
-STEP_MODULES = {
-    "gpt_step": "megatron.bridge.training.gpt_step",
-    "vlm_step": "megatron.bridge.training.vlm_step",
-    "llava_step": "megatron.bridge.training.llava_step",
+STEP_FUNCTIONS: dict[str, Callable] = {
+    "gpt_step": gpt_forward_step,
+    "vlm_step": vlm_forward_step,
+    "llava_step": llava_forward_step,
 }
 
 TRAIN_MODES = {
     "pretrain": pretrain,
     "finetune": finetune,
 }
 
+# Error message constants
+ERR_UNKNOWN_STEP = "Unknown step type: {step_type}. Choose from: {choices}"
+ERR_INFER_MODE_FAILED = (
+    "Unable to infer training mode from recipe name. "
+    "Please include 'pretrain' or 'finetune' in the recipe name or pass --mode explicitly."
+)
+
 
 def parse_args() -> tuple[argparse.Namespace, list[str]]:
     """Parse command-line arguments."""

@@ -92,7 +103,7 @@ def parse_args() -> tuple[argparse.Namespace, list[str]]:
         "--step_func",
         type=str,
         default="gpt_step",
-        choices=sorted(STEP_MODULES.keys()),
+        choices=sorted(STEP_FUNCTIONS.keys()),
         help="Step function: gpt_step (text-only), vlm_step (vision-language), or llava_step (LLaVA models)",
     )
     parser.add_argument(

@@ -127,18 +138,32 @@ def load_recipe(recipe_name: str, peft_scheme: str | None) -> ConfigContainer:
     )
 
     config_builder = getattr(recipes, recipe_name)
-    return config_builder(peft=peft_scheme)
-
 
-def load_forward_step(step_type: str):
+    # Check if the recipe accepts a 'peft' argument
+    try:
+        sig = inspect.signature(config_builder)
+        params = sig.parameters
+        accepts_peft = "peft" in params or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())
+    except (ValueError, TypeError):
+        # If signature inspection fails, fall back to try/except
+        accepts_peft = True
+
+    if accepts_peft:
+        try:
+            return config_builder(peft=peft_scheme)
+        except TypeError:
+            # Fallback if peft is not accepted despite signature inspection
+            return config_builder()
+    else:
+        return config_builder()
+
+
+def load_forward_step(step_type: str) -> Callable:
     """Load forward_step function based on the requested step type."""
     step_key = step_type.lower()
-    if step_key not in STEP_MODULES:
-        raise ValueError(f"Unknown step type: {step_type}. Choose from: {', '.join(STEP_MODULES)}")
-    module = importlib.import_module(STEP_MODULES[step_key])
-    if not hasattr(module, "forward_step"):
-        raise AttributeError(f"{STEP_MODULES[step_key]} does not define forward_step")
-    return module.forward_step
+    if step_key not in STEP_FUNCTIONS:
+        raise ValueError(ERR_UNKNOWN_STEP.format(step_type=step_type, choices=", ".join(STEP_FUNCTIONS)))
+    return STEP_FUNCTIONS[step_key]
 
 
 def infer_train_mode(recipe_name: str) -> str:

@@ -148,10 +173,7 @@ def infer_train_mode(recipe_name: str) -> str:
     has_finetune = "finetune" in lowered
     if has_pretrain ^ has_finetune:
         return "pretrain" if has_pretrain else "finetune"
-    raise ValueError(
-        "Unable to infer training mode from recipe name. "
-        "Please include 'pretrain' or 'finetune' in the recipe name or pass --mode explicitly."
-    )
+    raise ValueError(ERR_INFER_MODE_FAILED)
 
 
 def main() -> None:
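
As a side note, here is a minimal, standalone sketch of the signature-inspection pattern used in load_recipe above. The build_config helper and the toy pretrain_recipe/finetune_recipe builders are illustrative stand-ins, not part of the repository.

import inspect


def build_config(config_builder, peft_scheme=None):
    """Call a recipe builder, passing `peft` only when the builder can accept it."""
    try:
        params = inspect.signature(config_builder).parameters
        accepts_peft = "peft" in params or any(
            p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
        )
    except (ValueError, TypeError):
        # Some callables (e.g., builtins) expose no inspectable signature; assume peft and fall back.
        accepts_peft = True
    if accepts_peft:
        try:
            return config_builder(peft=peft_scheme)
        except TypeError:
            return config_builder()
    return config_builder()


def pretrain_recipe():  # toy builder without a peft argument
    return {"mode": "pretrain"}


def finetune_recipe(peft=None):  # toy builder that accepts peft
    return {"mode": "finetune", "peft": peft}


print(build_config(pretrain_recipe))          # {'mode': 'pretrain'}
print(build_config(finetune_recipe, "lora"))  # {'mode': 'finetune', 'peft': 'lora'}

Checking the signature up front avoids masking unrelated TypeErrors raised inside the recipe itself, which a bare try/except around the call would otherwise swallow.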

src/megatron/bridge/recipes/qwen_vl/qwen3_vl.py
Lines changed: 6 additions & 6 deletions

@@ -378,9 +378,9 @@ def _qwen3_vl_common(
 
     # Determine dataset selection strategy.
     _processor_model = tokenizer_model or hf_path
-    mock = mock or dataset_type == "hf"
+    _dataset_choice = dataset_type or ("mock" if mock else "hf")
 
-    if mock:
+    if _dataset_choice == "mock":
         dataset_cfg: DatasetProvider = MockVLMConversationProvider(
             seq_length=seq_length,
             hf_processor_path=_processor_model,

@@ -393,7 +393,7 @@
             create_attention_mask=True,
             pad_to_max_length=True,
         )
-    elif dataset_type == "preloaded":
+    elif _dataset_choice == "preloaded":
         dataset_cfg = PreloadedVLMConversationProvider(
             seq_length=seq_length,
             hf_processor_path=_processor_model,

@@ -407,7 +407,7 @@
             pin_memory=True,
             persistent_workers=False,
         )
-    elif dataset_type == "hf":
+    elif _dataset_choice == "hf":
         dataset_cfg = HFDatasetConversationProvider(
             seq_length=seq_length,
             hf_processor_path=_processor_model,

@@ -418,7 +418,7 @@
             pin_memory=True,
             persistent_workers=False,
         )
-    elif dataset_type == "energon":
+    elif _dataset_choice == "energon":
         tokenizer = AutoTokenizer.from_pretrained(_processor_model)
         # Use from_pretrained to ensure correct normalization (mean/std) and config (min_pixels)
         # matching Preloaded provider behavior.

@@ -441,7 +441,7 @@
         )
     else:
         raise ValueError(
-            f"Unsupported dataset_type '{dataset_type}'. Expected one of ['mock', 'preloaded', 'hf', 'energon']."
+            f"Unsupported dataset_type '{_dataset_choice}'. Expected one of ['mock', 'preloaded', 'hf', 'energon']."
         )
     # Config Container
     cfg = ConfigContainer(
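
For illustration, a standalone sketch of the dataset-selection fix above. The resolve_dataset_choice helper is hypothetical; it mirrors only the _dataset_choice expression and validation, with the provider construction omitted.

def resolve_dataset_choice(dataset_type, mock):
    """Mirror of `_dataset_choice = dataset_type or ("mock" if mock else "hf")` plus validation."""
    choice = dataset_type or ("mock" if mock else "hf")
    if choice not in ("mock", "preloaded", "hf", "energon"):
        raise ValueError(
            f"Unsupported dataset_type '{choice}'. Expected one of ['mock', 'preloaded', 'hf', 'energon']."
        )
    return choice


assert resolve_dataset_choice(None, mock=True) == "mock"          # mock only when no explicit type is given
assert resolve_dataset_choice(None, mock=False) == "hf"           # default to the HF dataset provider
assert resolve_dataset_choice("energon", mock=True) == "energon"  # an explicit dataset_type wins over mock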
