Conversation
remove useless log
/ok to test 2997516
📝 Walkthrough

Introduces distributed saving capabilities to HuggingFace model weight exports via two new parameters (`distributed_save` and `save_every_n_ranks`).
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 1 | ❌ 3

❌ Failed checks (2 warnings, 1 inconclusive)
✅ Passed checks (1 passed)
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@src/megatron/bridge/models/hf_pretrained/state.py`:
- Around line 865-890: In _save_generator_distributed validate that the
save_every_n_ranks argument is an integer >= 1 before it’s used in
division/modulo; add an early check (e.g., at the top of
_save_generator_distributed) that raises a ValueError with a clear message if
save_every_n_ranks is <= 0 or not an int, so subsequent calculations of
num_nodes, is_saver_rank, saver_ranks and saver_index are safe.
In `@tests/functional_tests/training/test_distributed_save_hf_weights.py`:
- Around line 221-225: Replace the flaky assertion with an explicit exception:
where the code checks mismatched_weights (the block that logs via logger.warning
and iterates over mismatched_weights), remove "assert False, f'Weight mismatch
detected. Max difference: {max_diff}'" and instead raise an AssertionError with
the same message (e.g. raise AssertionError(f"Weight mismatch detected. Max
difference: {max_diff}")). This ensures the test always fails even when Python
is run with -O optimizations.
- Around line 135-240: Replace the broad "except Exception as e" in the
save/reload block with targeted exception handling and add precondition skips:
at the start of the test (before the complex setup/try) check
torch.cuda.is_available() and skip if false, and check
torch.distributed.is_initialized() or expected world size and skip if the
distributed environment is not present; then restrict the except clause around
the save/load sequence to only environment-related errors (e.g., except
(RuntimeError, OSError) as e) so real test failures (assertions, logic errors)
are not swallowed; also update the test docstring to state that CUDA and a
distributed environment are required.
🧹 Nitpick comments (6)
tests/functional_tests/training/test_distributed_save_hf_weights.py (3)
41-69: Rename globals to the required `G_` prefix.

`logger` and `HF_QWEN2_TOY_MODEL_CONFIG` are module-level globals and should follow the mandated `G_` prefix convention.

Suggested rename

```diff
-logger = logging.getLogger(__name__)
+G_LOGGER = logging.getLogger(__name__)

-HF_QWEN2_TOY_MODEL_CONFIG = {
+G_HF_QWEN2_TOY_MODEL_CONFIG = {
@@
-    config = Qwen2Config(**HF_QWEN2_TOY_MODEL_CONFIG)
+    config = Qwen2Config(**G_HF_QWEN2_TOY_MODEL_CONFIG)
@@
-            logger.info(f"  {item.name} {item.is_file()}")
+            G_LOGGER.info(f"  {item.name} {item.is_file()}")
```

As per coding guidelines: Use upper snake_case and prefix 'G' for global variables (e.g., G_MY_GLOBAL).
72-106: Add type hints and a Google-style docstring to `init_parallel_state`.

The helper is missing the required annotations and a docstring.

Suggested update

```diff
-def init_parallel_state(tp_size, pp_size):
+def init_parallel_state(tp_size: int, pp_size: int) -> None:
+    """Initialize torch.distributed and Megatron model-parallel state.
+
+    Args:
+        tp_size: Tensor model-parallel size.
+        pp_size: Pipeline model-parallel size.
+    """
```

As per coding guidelines: Use Google style docstrings (parseable by Sphinx) for classes and functions; Use type hints for function arguments and return types.
242-248: Tear down distributed state in `finally` and narrow cleanup exceptions.

Leaving the process group and model-parallel state initialized can contaminate subsequent tests.

Suggested cleanup

```diff
-    finally:
-        if torch.distributed.get_rank() == 0 and output_path.exists():
-            try:
-                shutil.rmtree(os.path.dirname(output_path))
-                logger.info(f"Successfully cleaned up temporary directory: {output_path}")
-            except Exception as e:
-                logger.warning(f"Failed to clean up temporary directory {output_path}: {e}")
+    finally:
+        rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
+        if dist.is_available() and dist.is_initialized():
+            parallel_state.destroy_model_parallel()
+            dist.destroy_process_group()
+        if rank == 0 and output_path.exists():
+            try:
+                shutil.rmtree(os.path.dirname(output_path))
+                logger.info(f"Successfully cleaned up temporary directory: {output_path}")
+            except OSError as e:
+                logger.warning(f"Failed to clean up temporary directory {output_path}: {e}")
```

As per coding guidelines: When using try-except blocks, limit the except clause to the smallest set of errors possible.
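For readers weighing this suggestion, the same teardown could also live in a pytest fixture instead of the `finally` block. This is only a sketch under assumptions (the fixture name and wiring are not part of the PR; `parallel_state` refers to the Megatron-Core module already used by the test):

```python
# Hypothetical fixture-based alternative to the inline finally-block teardown.
import pytest
import torch.distributed as dist
from megatron.core import parallel_state


@pytest.fixture
def distributed_teardown():
    """Yield to the test, then tear down model-parallel state and the process group."""
    yield
    if dist.is_available() and dist.is_initialized():
        parallel_state.destroy_model_parallel()
        dist.destroy_process_group()
```

A fixture keeps the cleanup in one place if more distributed tests are added to this module later.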
src/megatron/bridge/models/hf_pretrained/state.py (3)
23-31: Use built-in generics for the new annotations.

The new code introduces `typing.Set`/`Dict` usage; prefer built-in `set`/`dict` per style guidelines.

Suggested refactor

```diff
 from typing import (
     Dict,
     Iterable,
     List,
     Optional,
     Pattern,
-    Set,
     Tuple,
     Union,
     overload,
 )
@@
-        all_expected_keys: Set[str] = set(key_to_filename_map.keys())
+        all_expected_keys: set[str] = set(key_to_filename_map.keys())
-        filename_to_keys_map: Dict[str, Set[str]] = defaultdict(set)
+        filename_to_keys_map: dict[str, set[str]] = defaultdict(set)
@@
-        assigned_expected_keys: Set[str] = (
+        assigned_expected_keys: set[str] = (
@@
-        buffered_tensors: Dict[str, torch.Tensor] = {}
-        actually_saved_keys: Set[str] = set()
+        buffered_tensors: dict[str, torch.Tensor] = {}
+        actually_saved_keys: set[str] = set()
```

As per coding guidelines: Use built-in generics (list, dict, tuple) instead of typing equivalents.
Also applies to: 912-935
940-960: Use rank-aware logging instead of bare `print`.

These prints will be emitted by multiple ranks; use `print_rank_0` (or an equivalent rank-0 guard) to avoid noisy duplicates.

Suggested change

```diff
-                    print(f"Warning: tensor '{name}' from generator not found in original model structure. Skipping.")
+                    print_rank_0(
+                        f"Warning: tensor '{name}' from generator not found in original model structure. Skipping."
+                    )
@@
-            print(f"Rank {rank}: Missing tensors for keys: {missing_str}", flush=True)
+            print_rank_0(f"Rank {rank}: Missing tensors for keys: {missing_str}")
```

As per coding guidelines: Use 'print_rank_0' for logging in model bridge to avoid duplicate output across ranks.
865-871: Add a Google-style docstring and return type for `_save_generator_distributed`.

The new helper lacks the required documentation and explicit return type.

Suggested docstring

```diff
     def _save_generator_distributed(
         self,
         generator: Iterable[Tuple[str, torch.Tensor]],
         output_path: Union[str, Path],
         strict: bool = True,
         save_every_n_ranks: int = 1,
-    ):
+    ) -> None:
+        """Save tensors across ranks in distributed mode.
+
+        Args:
+            generator: Iterable of (tensor_name, tensor) tuples.
+            output_path: Directory to write shard files into.
+            strict: Whether to error on unexpected tensor names.
+            save_every_n_ranks: Interval for saver ranks.
+        """
```

As per coding guidelines: Use Google style docstrings (parseable by Sphinx) for classes and functions; Use type hints for function arguments and return types.
```python
    def _save_generator_distributed(
        self,
        generator: Iterable[Tuple[str, torch.Tensor]],
        output_path: Union[str, Path],
        strict: bool = True,
        save_every_n_ranks: int = 1,
    ):
        is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
        if is_distributed:
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
        else:
            world_size = 1
            rank = 0

        from safetensors.torch import save_file

        output_path = Path(output_path)

        # Calculate which ranks should participate in saving
        # Only rank % save_every_n_ranks == 0 will save
        num_nodes = (world_size + save_every_n_ranks - 1) // save_every_n_ranks
        is_saver_rank = rank % save_every_n_ranks == 0
        saver_ranks = [i * save_every_n_ranks for i in range(num_nodes) if i * save_every_n_ranks < world_size]
        num_savers = len(saver_ranks)
        saver_index = rank // save_every_n_ranks if is_saver_rank else -1
```
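To make the rank-partitioning above concrete, here is a small standalone sketch (not part of the PR) that replays the same arithmetic for an assumed 8-rank job with `save_every_n_ranks=4`:

```python
# Illustrative replay of the saver-rank selection logic above (values are assumptions).
world_size = 8
save_every_n_ranks = 4

num_nodes = (world_size + save_every_n_ranks - 1) // save_every_n_ranks  # ceil(8 / 4) = 2
saver_ranks = [i * save_every_n_ranks for i in range(num_nodes) if i * save_every_n_ranks < world_size]
print(saver_ranks)  # [0, 4]

for rank in range(world_size):
    is_saver_rank = rank % save_every_n_ranks == 0
    saver_index = rank // save_every_n_ranks if is_saver_rank else -1
    print(rank, is_saver_rank, saver_index)
# Ranks 0 and 4 are saver ranks (saver_index 0 and 1); every other rank reports
# is_saver_rank=False and saver_index=-1.
```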
🧩 Analysis chain
🏁 Script executed:

```shell
# First, verify the file exists and examine the relevant method
find . -type f -name "state.py" | grep -E "bridge/models/hf_pretrained"
```

Length of output: 122

🏁 Script executed:

```shell
# Once found, let's look at the method in detail
cat -n src/megatron/bridge/models/hf_pretrained/state.py | sed -n '865,920p'
```

Length of output: 2631

🏁 Script executed:

```shell
# Let's also check the entire method to understand its context better
cat -n src/megatron/bridge/models/hf_pretrained/state.py | sed -n '855,900p'
```

Length of output: 2163

🏁 Script executed:

```shell
# Check if there's any existing validation for save_every_n_ranks in the file
rg "save_every_n_ranks" src/megatron/bridge/models/hf_pretrained/state.py -A 2 -B 2
```

Length of output: 1420

🏁 Script executed:

```shell
# Check how this method is called throughout the codebase
rg "_save_generator_distributed" --type py -A 3 -B 3
```

Length of output: 1167
Validate `save_every_n_ranks` >= 1 before using it in division and modulo operations.

Lines 886–890 use `save_every_n_ranks` in floor-division (`//`) and modulo (`%`) operations. A value of zero raises `ZeroDivisionError` at runtime, and negative values silently produce nonsensical saver-rank assignments.
Suggested fix
```diff
     def _save_generator_distributed(
         self,
         generator: Iterable[Tuple[str, torch.Tensor]],
         output_path: Union[str, Path],
         strict: bool = True,
         save_every_n_ranks: int = 1,
     ):
+        if save_every_n_ranks < 1:
+            raise ValueError("save_every_n_ranks must be >= 1")
         is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
```

🤖 Prompt for AI Agents
In `@src/megatron/bridge/models/hf_pretrained/state.py` around lines 865 - 890, In
_save_generator_distributed validate that the save_every_n_ranks argument is an
integer >= 1 before it’s used in division/modulo; add an early check (e.g., at
the top of _save_generator_distributed) that raises a ValueError with a clear
message if save_every_n_ranks is <= 0 or not an int, so subsequent calculations
of num_nodes, is_saver_rank, saver_ranks and saver_index are safe.
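If the early check above is adopted, a small unit test along the following lines could pin the behavior down. This is only a sketch: the standalone `_validate_save_every_n_ranks` helper is hypothetical and simply mirrors the proposed guard, since calling the real private method requires a full model state object.

```python
import pytest


def _validate_save_every_n_ranks(save_every_n_ranks: int) -> None:
    """Hypothetical standalone copy of the guard proposed for _save_generator_distributed."""
    if not isinstance(save_every_n_ranks, int) or save_every_n_ranks < 1:
        raise ValueError(f"save_every_n_ranks must be an int >= 1, got {save_every_n_ranks!r}")


def test_save_every_n_ranks_validation() -> None:
    # Valid values pass through without raising.
    _validate_save_every_n_ranks(1)
    _validate_save_every_n_ranks(8)

    # Zero and negative values are rejected before any // or % arithmetic runs.
    with pytest.raises(ValueError):
        _validate_save_every_n_ranks(0)
    with pytest.raises(ValueError):
        _validate_save_every_n_ranks(-4)
```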
```python
try:
    bridge = AutoBridge.from_hf_pretrained(
        toy_dir,
        trust_remote_code=True,
    )

    provider = bridge.to_megatron_provider()
    provider.tensor_model_parallel_size = tp_size
    provider.pipeline_model_parallel_size = pp_size
    provider.finalize()

    model = provider.provide_distributed_model(wrap_with_ddp=False)

    torch.cuda.synchronize()
    before_save = time.time()
    bridge.save_hf_weights(
        model,
        str(output_path),
        merge_adapter_weights=True,
        distributed_save=distributed_save,
        save_every_n_ranks=save_every_n_ranks,
    )
    torch.distributed.barrier()
    torch.cuda.synchronize()
    after_save = time.time()

    assert output_path.exists(), f"Output directory {output_path} was not created"
    if torch.distributed.get_rank() == 0:
        for item in output_path.iterdir():
            logger.info(f"  {item.name} {item.is_file()}")

    weight_files = list(output_path.glob("model*.safetensors")) or list(
        output_path.glob("pytorch_model*.bin")
    )
    assert len(weight_files) > 0, "No model weight files found in output directory"

    shutil.copy(Path(toy_dir) / "config.json", output_path / "config.json")

    reloaded_model = AutoModelForCausalLM.from_pretrained(
        str(output_path),
        device_map="cpu",
        trust_remote_code=True,
    )
    assert reloaded_model is not None, "Failed to load model from saved checkpoint"

    assert hasattr(reloaded_model, "model"), "Reloaded model missing 'model' attribute"
    assert hasattr(reloaded_model.model, "layers"), "Reloaded model missing 'layers' attribute"

    # Compare weights between toy_model and reloaded_model
    toy_model_cpu = toy_model.cpu()
    toy_state_dict = toy_model_cpu.state_dict()
    reloaded_state_dict = reloaded_model.state_dict()

    # Check if all keys match
    toy_keys = set(toy_state_dict.keys())
    reloaded_keys = set(reloaded_state_dict.keys())

    missing_keys = toy_keys - reloaded_keys
    extra_keys = reloaded_keys - toy_keys

    if missing_keys:
        logger.warning(f"Missing keys in reloaded model: {missing_keys}")
    if extra_keys:
        logger.warning(f"Extra keys in reloaded model: {extra_keys}")

    assert toy_keys == reloaded_keys, f"Key mismatch: missing={missing_keys}, extra={extra_keys}"

    # Compare weight values
    max_diff = 0.0
    mismatched_weights = []
    for key in toy_keys:
        toy_weight = toy_state_dict[key]
        reloaded_weight = reloaded_state_dict[key]

        # Convert to same dtype for comparison
        if toy_weight.dtype != reloaded_weight.dtype:
            reloaded_weight = reloaded_weight.to(toy_weight.dtype)

        diff = torch.abs(toy_weight - reloaded_weight).max().item()
        if diff > max_diff:
            max_diff = diff

        # Allow small numerical differences (e.g., 1e-5)
        if diff > 1e-5:
            mismatched_weights.append((key, diff))

    if mismatched_weights:
        logger.warning(f"Found {len(mismatched_weights)} mismatched weights:")
        for key, diff in mismatched_weights[:10]:  # Print first 10
            logger.warning(f"  {key}: max_diff={diff}")
        assert False, f"Weight mismatch detected. Max difference: {max_diff}"

    logger.info(f"Weight comparison passed! Max difference: {max_diff:.2e}")

    logger.info(
        f"Distributed_save test passed: Model successfully saved to {output_path} using time {(after_save - before_save):.2f}s"
    )
    logger.info(f" - Weight files: {len(weight_files)} file(s)")
    logger.info(" - Model successfully reloaded and validated")

except Exception as e:
    import traceback

    traceback.print_exc()
    logger.error(f"Distributed_save test skipped due to: {e}")
    pytest.skip(f"Distributed_save test skipped due to: {e}")
```
🧩 Analysis chain
🏁 Script executed:

```shell
# First, find and read the test file
find . -name "test_distributed_save_hf_weights.py" -type f
```

Length of output: 140

🏁 Script executed:

```shell
# Check the file size to understand how much to read
wc -l tests/functional_tests/training/test_distributed_save_hf_weights.py
```

Length of output: 142

🏁 Script executed:

```shell
# Read lines around 135-240 to verify the code snippet
sed -n '130,245p' tests/functional_tests/training/test_distributed_save_hf_weights.py
```

Length of output: 5115

🏁 Script executed:

```shell
# Also check if there are explicit precondition checks elsewhere in the test or module
grep -n "pytest.skip\|torch.cuda.is_available\|torch.distributed" tests/functional_tests/training/test_distributed_save_hf_weights.py
```

Length of output: 447

🏁 Script executed:

```shell
# Check the entire function/test to understand context better
sed -n '100,250p' tests/functional_tests/training/test_distributed_save_hf_weights.py
```

Length of output: 6427
Catch specific exception types instead of broad Exception to avoid masking real failures.
The broad exception handler masks legitimate failures that should fail the test. For distributed training tests, limit exception handling to the specific types that can occur from expected environment issues:
Suggested fix
```diff
-    except Exception as e:
-        import traceback
-
-        traceback.print_exc()
-        logger.error(f"Distributed_save test skipped due to: {e}")
-        pytest.skip(f"Distributed_save test skipped due to: {e}")
+    except (RuntimeError, OSError, FileNotFoundError) as e:
+        logger.exception("Distributed_save test skipped due to: %s", e)
+        pytest.skip(f"Distributed_save test skipped due to: {e}")
```

Consider adding precondition checks at the start of the test to skip before attempting complex setup:

```python
if not torch.cuda.is_available():
    pytest.skip("CUDA is required for this test")
```

As per coding guidelines: "When using try-except blocks, limit the except clause to the smallest set of errors possible." Also document GPU requirements in the test docstring.

🧰 Tools
🪛 Ruff (0.14.14)

[warning] 225-225: Do not assert False (python -O removes these calls), raise AssertionError()
Replace assert False
(B011)

[warning] 235-235: Do not catch blind exception: Exception
(BLE001)

[warning] 239-239: Use logging.exception instead of logging.error
Replace with exception
(TRY400)

🤖 Prompt for AI Agents

In `@tests/functional_tests/training/test_distributed_save_hf_weights.py` around lines 135 - 240, Replace the broad "except Exception as e" in the save/reload block with targeted exception handling and add precondition skips: at the start of the test (before the complex setup/try) check torch.cuda.is_available() and skip if false, and check torch.distributed.is_initialized() or expected world size and skip if the distributed environment is not present; then restrict the except clause around the save/load sequence to only environment-related errors (e.g., except (RuntimeError, OSError) as e) so real test failures (assertions, logic errors) are not swallowed; also update the test docstring to state that CUDA and a distributed environment are required.
```python
    if mismatched_weights:
        logger.warning(f"Found {len(mismatched_weights)} mismatched weights:")
        for key, diff in mismatched_weights[:10]:  # Print first 10
            logger.warning(f"  {key}: max_diff={diff}")
        assert False, f"Weight mismatch detected. Max difference: {max_diff}"
```
🧩 Analysis chain
🏁 Script executed:

```shell
cat -n tests/functional_tests/training/test_distributed_save_hf_weights.py | sed -n '210,230p'
```

Length of output: 1321

🏁 Script executed:

```shell
# Let's also check what's before to understand the context
cat -n tests/functional_tests/training/test_distributed_save_hf_weights.py | sed -n '200,235p'
```

Length of output: 2066

🏁 Script executed:

```shell
# Check if there are similar patterns in other test files
rg "assert False" tests/ -A 1 -B 1
```

Length of output: 25300

🏁 Script executed:

```shell
# Check the coding guidelines for any specific guidance on assertions in tests
grep -i "assert" CODING_GUIDELINES.md 2>/dev/null || echo "File not found or no matches"
```

Length of output: 99
Replace assert False with an explicit exception.
assert statements can be stripped with the -O optimization flag; use raise AssertionError to ensure failure behavior is preserved.
Suggested fix

```diff
-        assert False, f"Weight mismatch detected. Max difference: {max_diff}"
+        raise AssertionError(f"Weight mismatch detected. Max difference: {max_diff}")
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
     if mismatched_weights:
         logger.warning(f"Found {len(mismatched_weights)} mismatched weights:")
         for key, diff in mismatched_weights[:10]:  # Print first 10
             logger.warning(f"  {key}: max_diff={diff}")
-        assert False, f"Weight mismatch detected. Max difference: {max_diff}"
+        raise AssertionError(f"Weight mismatch detected. Max difference: {max_diff}")
```
🧰 Tools
🪛 Ruff (0.14.14)
[warning] 225-225: Do not assert False (python -O removes these calls), raise AssertionError()
Replace assert False
(B011)
🤖 Prompt for AI Agents
In `@tests/functional_tests/training/test_distributed_save_hf_weights.py` around
lines 221 - 225, Replace the flaky assertion with an explicit exception: where
the code checks mismatched_weights (the block that logs via logger.warning and
iterates over mismatched_weights), remove "assert False, f'Weight mismatch
detected. Max difference: {max_diff}'" and instead raise an AssertionError with
the same message (e.g. raise AssertionError(f"Weight mismatch detected. Max
difference: {max_diff}")). This ensures the test always fails even when Python
is run with -O optimizations.
What does this PR do ?
Add a one line overview of what this PR aims to accomplish.
Changelog
GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Additional Information
Summary by CodeRabbit
Release Notes
New Features
- `distributed_save` option to control distributed weight saving behavior across processes
- `save_every_n_ranks` parameter provides flexible control over save frequency and rank distribution

Tests

- Added a functional test that saves HuggingFace weights through the distributed path and verifies the reloaded checkpoint against the source model
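As a quick orientation for reviewers, the new flags are consumed through `save_hf_weights` roughly as in the functional test above; this fragment just restates that call (the bridge/provider/model setup is elided, and the flag values are illustrative):

```python
# Mirrors the call exercised in test_distributed_save_hf_weights.py above.
# `bridge`, `model`, and `output_path` come from the usual AutoBridge/provider setup.
bridge.save_hf_weights(
    model,
    str(output_path),
    merge_adapter_weights=True,
    distributed_save=True,    # new: route the export through the distributed save path
    save_every_n_ranks=8,     # new: roughly one saver rank per group of 8 ranks (illustrative)
)
```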