
ci: 1806 #2230

Open
chtruong814 wants to merge 11 commits into main from chtruong/ci-1806

Conversation

chtruong814 (Contributor) commented Feb 5, 2026

What does this PR do?

Add a one-line overview of what this PR aims to accomplish.

Changelog

  • Add specific line-by-line info of the high-level changes in this PR.

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (e.g., Numba, Pynini, Apex, etc.)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to # (issue)

Summary by CodeRabbit

Release Notes

  • New Features

    • Added distributed saving capability for HuggingFace model exports, enabling efficient parallel weight saving across multiple GPU/TPU ranks
    • Introduced distributed_save option to control distributed weight saving behavior across processes
    • New save_every_n_ranks parameter provides flexible control over save frequency and rank distribution
  • Tests

    • Added comprehensive functional tests validating distributed model weight saving and checkpoint integrity


copy-pr-bot bot commented Feb 5, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

chtruong814 (Contributor, Author) commented:

/ok to test 2997516


coderabbitai bot commented Feb 5, 2026

📝 Walkthrough

Introduces distributed saving capabilities to HuggingFace model weight exports via two new parameters (distributed_save and save_every_n_ranks) to save_hf_pretrained and save_hf_weights. When enabled, tensors are distributed across multiple ranks with synchronization barriers and individual shard files per rank. Default behavior (disabled) preserves existing rank-0-only saving semantics.
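For orientation, a minimal usage sketch of the new options is shown below. It mirrors the call made in the functional test later in this PR; `bridge`, `model`, and `output_path` are assumed to have been prepared as in that test.

# Hedged sketch: assumes `bridge` is an AutoBridge and `model` a Megatron model,
# set up as in tests/functional_tests/training/test_distributed_save_hf_weights.py.
bridge.save_hf_weights(
    model,
    str(output_path),
    distributed_save=True,    # enable multi-rank shard writing
    save_every_n_ranks=2,     # every 2nd rank acts as a saver rank
)
# With distributed_save=False (the default), only rank 0 writes, preserving the old behavior.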

Changes

  • Core Distributed Saving Implementation — src/megatron/bridge/models/conversion/auto_bridge.py
    Added distributed_save and save_every_n_ranks parameters to the save_hf_pretrained and save_hf_weights methods. The parameters are propagated through the call chain from save_hf_pretrained to save_hf_weights to SafeTensorsStateSource.save_generator. Expanded docstrings to document the distributed saving semantics.
  • Distributed Save Backend — src/megatron/bridge/models/hf_pretrained/state.py
    Extended the save_generator method signature to accept distributed_save and save_every_n_ranks parameters. Added a new internal _save_generator_distributed method implementing multi-rank shard allocation, buffering, synchronized barrier coordination before/after saves, and distributed file writing. Includes strict-mode error handling and post-save index.json rebuilding (a sketch of the rebuilt index layout follows this list). Added a Set type import for typing support.
  • Functional Integration Test — tests/functional_tests/training/test_distributed_save_hf_weights.py
    New functional test validating the distributed HF weight saving workflow: initializes the distributed environment, constructs a toy Qwen2 model, loads it via AutoBridge, invokes the distributed save, and validates the output (artifact presence, reloaded model state-dict equality within a numerical tolerance of ~1e-5). Includes error handling with pytest skips and temporary-directory cleanup.
  • Unit Test Updates — tests/unit_tests/models/test_auto_bridge.py
    Updated test expectations to include the new distributed_save and save_every_n_ranks keyword arguments in all save_hf_weights call-site assertions. Preserves the existing argument order with backward-compatible defaults.
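To make the index.json rebuilding mentioned for state.py concrete, here is a rough sketch of the artifact such a rebuild produces. The weight_map format (tensor name mapped to shard file) is the standard Hugging Face safetensors index layout; the per-rank shard file names and sizes below are illustrative assumptions, not taken from the PR.

import json

# Illustrative only: actual shard naming and sizes depend on the PR's implementation.
index = {
    "metadata": {"total_size": 123456789},  # total bytes across all shards
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "model.layers.0.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
    },
}
with open("model.safetensors.index.json", "w") as f:
    json.dump(index, f, indent=2)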

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • suiyoubi
  • thomasdhc
🚥 Pre-merge checks | ✅ 1 | ❌ 3
❌ Failed checks (2 warnings, 1 inconclusive)
  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 75.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Test Results For Major Changes — ⚠️ Warning: This performance-related PR lacks benchmarking results, before-and-after metrics, or other evidence that distributed saving improves performance over the non-distributed approach. Resolution: add benchmarking results comparing distributed_save=False vs. True across different rank configurations, with configuration details and timing measurements.
  • Title check — ❓ Inconclusive: The PR title "ci: 1806" is vague and generic, using only a ticket identifier without describing the actual changes. Resolution: revise the title to clearly describe the main change, such as "Add distributed saving capabilities to HuggingFace weight export" or "Implement distributed save feature for HF weights".
✅ Passed checks (1 passed)
  • Description Check — ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.


coderabbitai bot left a comment

Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@src/megatron/bridge/models/hf_pretrained/state.py`:
- Around line 865-890: In _save_generator_distributed validate that the
save_every_n_ranks argument is an integer >= 1 before it’s used in
division/modulo; add an early check (e.g., at the top of
_save_generator_distributed) that raises a ValueError with a clear message if
save_every_n_ranks is <= 0 or not an int, so subsequent calculations of
num_nodes, is_saver_rank, saver_ranks and saver_index are safe.

In `@tests/functional_tests/training/test_distributed_save_hf_weights.py`:
- Around line 221-225: Replace the flaky assertion with an explicit exception:
where the code checks mismatched_weights (the block that logs via logger.warning
and iterates over mismatched_weights), remove "assert False, f'Weight mismatch
detected. Max difference: {max_diff}'" and instead raise an AssertionError with
the same message (e.g. raise AssertionError(f"Weight mismatch detected. Max
difference: {max_diff}")). This ensures the test always fails even when Python
is run with -O optimizations.
- Around line 135-240: Replace the broad "except Exception as e" in the
save/reload block with targeted exception handling and add precondition skips:
at the start of the test (before the complex setup/try) check
torch.cuda.is_available() and skip if false, and check
torch.distributed.is_initialized() or expected world size and skip if the
distributed environment is not present; then restrict the except clause around
the save/load sequence to only environment-related errors (e.g., except
(RuntimeError, OSError) as e) so real test failures (assertions, logic errors)
are not swallowed; also update the test docstring to state that CUDA and a
distributed environment are required.
🧹 Nitpick comments (6)
tests/functional_tests/training/test_distributed_save_hf_weights.py (3)

41-69: Rename globals to the required G_ prefix.

logger and HF_QWEN2_TOY_MODEL_CONFIG are module-level globals and should follow the mandated G_ prefix convention.

Suggested rename
-logger = logging.getLogger(__name__)
+G_LOGGER = logging.getLogger(__name__)

-HF_QWEN2_TOY_MODEL_CONFIG = {
+G_HF_QWEN2_TOY_MODEL_CONFIG = {
@@
-        config = Qwen2Config(**HF_QWEN2_TOY_MODEL_CONFIG)
+        config = Qwen2Config(**G_HF_QWEN2_TOY_MODEL_CONFIG)
@@
-                    logger.info(f"  {item.name} {item.is_file()}")
+                    G_LOGGER.info(f"  {item.name} {item.is_file()}")

As per coding guidelines: Use upper snake_case and prefix 'G' for global variables (e.g., G_MY_GLOBAL).


72-106: Add type hints + Google-style docstring to init_parallel_state.

The helper is missing required annotations and a docstring.

Suggested update
-def init_parallel_state(tp_size, pp_size):
+def init_parallel_state(tp_size: int, pp_size: int) -> None:
+    """Initialize torch.distributed and Megatron model-parallel state.
+
+    Args:
+        tp_size: Tensor model-parallel size.
+        pp_size: Pipeline model-parallel size.
+    """

As per coding guidelines: Use Google style docstrings (parseable by Sphinx) for classes and functions; Use type hints for function arguments and return types.


242-248: Tear down distributed state in finally and narrow cleanup exceptions.

Leaving the process group and model-parallel state initialized can contaminate subsequent tests.

Suggested cleanup
-        finally:
-            if torch.distributed.get_rank() == 0 and output_path.exists():
-                try:
-                    shutil.rmtree(os.path.dirname(output_path))
-                    logger.info(f"Successfully cleaned up temporary directory: {output_path}")
-                except Exception as e:
-                    logger.warning(f"Failed to clean up temporary directory {output_path}: {e}")
+        finally:
+            rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
+            if dist.is_available() and dist.is_initialized():
+                parallel_state.destroy_model_parallel()
+                dist.destroy_process_group()
+            if rank == 0 and output_path.exists():
+                try:
+                    shutil.rmtree(os.path.dirname(output_path))
+                    logger.info(f"Successfully cleaned up temporary directory: {output_path}")
+                except OSError as e:
+                    logger.warning(f"Failed to clean up temporary directory {output_path}: {e}")

As per coding guidelines: When using try-except blocks, limit the except clause to the smallest set of errors possible.

src/megatron/bridge/models/hf_pretrained/state.py (3)

23-31: Use built-in generics for the new annotations.

The new code introduces typing.Set/Dict usage; prefer built-in set/dict per style guidelines.

Suggested refactor
-from typing import (
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Pattern,
-    Set,
-    Tuple,
-    Union,
-    overload,
-)
+from typing import (
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Pattern,
+    Tuple,
+    Union,
+    overload,
+)
@@
-        all_expected_keys: Set[str] = set(key_to_filename_map.keys())
+        all_expected_keys: set[str] = set(key_to_filename_map.keys())
-        filename_to_keys_map: Dict[str, Set[str]] = defaultdict(set)
+        filename_to_keys_map: dict[str, set[str]] = defaultdict(set)
@@
-            assigned_expected_keys: Set[str] = (
+            assigned_expected_keys: set[str] = (
@@
-        buffered_tensors: Dict[str, torch.Tensor] = {}
-        actually_saved_keys: Set[str] = set()
+        buffered_tensors: dict[str, torch.Tensor] = {}
+        actually_saved_keys: set[str] = set()

As per coding guidelines: Use built-in generics (list, dict, tuple) instead of typing equivalents.

Also applies to: 912-935


940-960: Use rank-aware logging instead of bare print in distributed saves.

These prints will be emitted by multiple ranks; use print_rank_0 (or an equivalent rank-0 guard) to avoid noisy duplicates.

Suggested change
-                    print(f"Warning: tensor '{name}' from generator not found in original model structure. Skipping.")
+                    print_rank_0(
+                        f"Warning: tensor '{name}' from generator not found in original model structure. Skipping."
+                    )
@@
-                print(f"Rank {rank}: Missing tensors for keys: {missing_str}", flush=True)
+                print_rank_0(f"Rank {rank}: Missing tensors for keys: {missing_str}")

As per coding guidelines: Use 'print_rank_0' for logging in model bridge to avoid duplicate output across ranks.


865-871: Add a Google-style docstring and return type for _save_generator_distributed.

The new helper lacks the required documentation and explicit return type.

Suggested docstring
-    def _save_generator_distributed(
+    def _save_generator_distributed(
         self,
         generator: Iterable[Tuple[str, torch.Tensor]],
         output_path: Union[str, Path],
         strict: bool = True,
         save_every_n_ranks: int = 1,
-    ):
+    ) -> None:
+        """Save tensors across ranks in distributed mode.
+
+        Args:
+            generator: Iterable of (tensor_name, tensor) tuples.
+            output_path: Directory to write shard files into.
+            strict: Whether to error on unexpected tensor names.
+            save_every_n_ranks: Interval for saver ranks.
+        """

As per coding guidelines: Use Google style docstrings (parseable by Sphinx) for classes and functions; Use type hints for function arguments and return types.

Comment on lines +865 to +890
    def _save_generator_distributed(
        self,
        generator: Iterable[Tuple[str, torch.Tensor]],
        output_path: Union[str, Path],
        strict: bool = True,
        save_every_n_ranks: int = 1,
    ):
        is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
        if is_distributed:
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
        else:
            world_size = 1
            rank = 0

        from safetensors.torch import save_file

        output_path = Path(output_path)

        # Calculate which ranks should participate in saving
        # Only rank % save_every_n_ranks == 0 will save
        num_nodes = (world_size + save_every_n_ranks - 1) // save_every_n_ranks
        is_saver_rank = rank % save_every_n_ranks == 0
        saver_ranks = [i * save_every_n_ranks for i in range(num_nodes) if i * save_every_n_ranks < world_size]
        num_savers = len(saver_ranks)
        saver_index = rank // save_every_n_ranks if is_saver_rank else -1
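For intuition, here is a small standalone sketch of the rank arithmetic quoted above; the world size and save_every_n_ranks values are illustrative, not taken from the PR.

# Illustrative example: with 8 ranks and save_every_n_ranks=2,
# ranks 0, 2, 4 and 6 become saver ranks and each writes its own shard.
world_size = 8
save_every_n_ranks = 2
num_nodes = (world_size + save_every_n_ranks - 1) // save_every_n_ranks  # 4
saver_ranks = [i * save_every_n_ranks for i in range(num_nodes) if i * save_every_n_ranks < world_size]
print(saver_ranks)  # [0, 2, 4, 6]
for rank in range(world_size):
    is_saver_rank = rank % save_every_n_ranks == 0
    saver_index = rank // save_every_n_ranks if is_saver_rank else -1
    print(rank, is_saver_rank, saver_index)  # e.g. (0, True, 0), (1, False, -1), ...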

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, verify the file exists and examine the relevant method
find . -type f -name "state.py" | grep -E "bridge/models/hf_pretrained"

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 122


🏁 Script executed:

# Once found, let's look at the method in detail
cat -n src/megatron/bridge/models/hf_pretrained/state.py | sed -n '865,920p'

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2631


🏁 Script executed:

# Let's also check the entire method to understand its context better
cat -n src/megatron/bridge/models/hf_pretrained/state.py | sed -n '855,900p'

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2163


🏁 Script executed:

# Check if there's any existing validation for save_every_n_ranks in the file
rg "save_every_n_ranks" src/megatron/bridge/models/hf_pretrained/state.py -A 2 -B 2

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1420


🏁 Script executed:

# Check how this method is called throughout the codebase
rg "_save_generator_distributed" --type py -A 3 -B 3

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1167


Validate save_every_n_ranks >= 1 before using in division and modulo operations.

Lines 886–890 use save_every_n_ranks in division (//) and modulo (%) operations. A value of zero raises ZeroDivisionError at runtime, and negative values silently produce incorrect saver-rank assignments.

Suggested fix
     def _save_generator_distributed(
         self,
         generator: Iterable[Tuple[str, torch.Tensor]],
         output_path: Union[str, Path],
         strict: bool = True,
         save_every_n_ranks: int = 1,
     ):
+        if save_every_n_ranks < 1:
+            raise ValueError("save_every_n_ranks must be >= 1")
         is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
🤖 Prompt for AI Agents
In `@src/megatron/bridge/models/hf_pretrained/state.py` around lines 865 - 890, In
_save_generator_distributed validate that the save_every_n_ranks argument is an
integer >= 1 before it’s used in division/modulo; add an early check (e.g., at
the top of _save_generator_distributed) that raises a ValueError with a clear
message if save_every_n_ranks is <= 0 or not an int, so subsequent calculations
of num_nodes, is_saver_rank, saver_ranks and saver_index are safe.

Comment on lines +135 to +240
        try:
            bridge = AutoBridge.from_hf_pretrained(
                toy_dir,
                trust_remote_code=True,
            )

            provider = bridge.to_megatron_provider()
            provider.tensor_model_parallel_size = tp_size
            provider.pipeline_model_parallel_size = pp_size
            provider.finalize()

            model = provider.provide_distributed_model(wrap_with_ddp=False)

            torch.cuda.synchronize()
            before_save = time.time()
            bridge.save_hf_weights(
                model,
                str(output_path),
                merge_adapter_weights=True,
                distributed_save=distributed_save,
                save_every_n_ranks=save_every_n_ranks,
            )
            torch.distributed.barrier()
            torch.cuda.synchronize()
            after_save = time.time()

            assert output_path.exists(), f"Output directory {output_path} was not created"
            if torch.distributed.get_rank() == 0:
                for item in output_path.iterdir():
                    logger.info(f"  {item.name} {item.is_file()}")

            weight_files = list(output_path.glob("model*.safetensors")) or list(
                output_path.glob("pytorch_model*.bin")
            )
            assert len(weight_files) > 0, "No model weight files found in output directory"

            shutil.copy(Path(toy_dir) / "config.json", output_path / "config.json")

            reloaded_model = AutoModelForCausalLM.from_pretrained(
                str(output_path),
                device_map="cpu",
                trust_remote_code=True,
            )
            assert reloaded_model is not None, "Failed to load model from saved checkpoint"

            assert hasattr(reloaded_model, "model"), "Reloaded model missing 'model' attribute"
            assert hasattr(reloaded_model.model, "layers"), "Reloaded model missing 'layers' attribute"

            # Compare weights between toy_model and reloaded_model
            toy_model_cpu = toy_model.cpu()
            toy_state_dict = toy_model_cpu.state_dict()
            reloaded_state_dict = reloaded_model.state_dict()

            # Check if all keys match
            toy_keys = set(toy_state_dict.keys())
            reloaded_keys = set(reloaded_state_dict.keys())

            missing_keys = toy_keys - reloaded_keys
            extra_keys = reloaded_keys - toy_keys

            if missing_keys:
                logger.warning(f"Missing keys in reloaded model: {missing_keys}")
            if extra_keys:
                logger.warning(f"Extra keys in reloaded model: {extra_keys}")

            assert toy_keys == reloaded_keys, f"Key mismatch: missing={missing_keys}, extra={extra_keys}"

            # Compare weight values
            max_diff = 0.0
            mismatched_weights = []
            for key in toy_keys:
                toy_weight = toy_state_dict[key]
                reloaded_weight = reloaded_state_dict[key]

                # Convert to same dtype for comparison
                if toy_weight.dtype != reloaded_weight.dtype:
                    reloaded_weight = reloaded_weight.to(toy_weight.dtype)

                diff = torch.abs(toy_weight - reloaded_weight).max().item()
                if diff > max_diff:
                    max_diff = diff

                # Allow small numerical differences (e.g., 1e-5)
                if diff > 1e-5:
                    mismatched_weights.append((key, diff))

            if mismatched_weights:
                logger.warning(f"Found {len(mismatched_weights)} mismatched weights:")
                for key, diff in mismatched_weights[:10]:  # Print first 10
                    logger.warning(f"  {key}: max_diff={diff}")
                assert False, f"Weight mismatch detected. Max difference: {max_diff}"

            logger.info(f"Weight comparison passed! Max difference: {max_diff:.2e}")

            logger.info(
                f"Distributed_save test passed: Model successfully saved to {output_path} using time {(after_save - before_save):.2f}s"
            )
            logger.info(f"  - Weight files: {len(weight_files)} file(s)")
            logger.info("  - Model successfully reloaded and validated")

        except Exception as e:
            import traceback

            traceback.print_exc()
            logger.error(f"Distributed_save test skipped due to: {e}")
            pytest.skip(f"Distributed_save test skipped due to: {e}")

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, find and read the test file
find . -name "test_distributed_save_hf_weights.py" -type f

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 140


🏁 Script executed:

# Check the file size to understand how much to read
wc -l tests/functional_tests/training/test_distributed_save_hf_weights.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 142


🏁 Script executed:

# Read lines around 135-240 to verify the code snippet
sed -n '130,245p' tests/functional_tests/training/test_distributed_save_hf_weights.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 5115


🏁 Script executed:

# Also check if there are explicit precondition checks elsewhere in the test or module
grep -n "pytest.skip\|torch.cuda.is_available\|torch.distributed" tests/functional_tests/training/test_distributed_save_hf_weights.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 447


🏁 Script executed:

# Check the entire function/test to understand context better
sed -n '100,250p' tests/functional_tests/training/test_distributed_save_hf_weights.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 6427


Catch specific exception types instead of broad Exception to avoid masking real failures.

The broad exception handler masks legitimate failures that should fail the test. For distributed training tests, limit exception handling to the specific types that can occur from expected environment issues:

Suggested fix
-        except Exception as e:
-            import traceback
-
-            traceback.print_exc()
-            logger.error(f"Distributed_save test skipped due to: {e}")
-            pytest.skip(f"Distributed_save test skipped due to: {e}")
+        except (RuntimeError, OSError, FileNotFoundError) as e:
+            logger.exception("Distributed_save test skipped due to: %s", e)
+            pytest.skip(f"Distributed_save test skipped due to: {e}")

Consider adding precondition checks at the start of the test to skip before attempting complex setup:

if not torch.cuda.is_available():
    pytest.skip("CUDA is required for this test")

As per coding guidelines: "When using try-except blocks, limit the except clause to the smallest set of errors possible." Also document GPU requirements in the test docstring.

🧰 Tools
🪛 Ruff (0.14.14)

[warning] 225-225: Do not assert False (python -O removes these calls), raise AssertionError()

Replace assert False

(B011)


[warning] 235-235: Do not catch blind exception: Exception

(BLE001)


[warning] 239-239: Use logging.exception instead of logging.error

Replace with exception

(TRY400)

🤖 Prompt for AI Agents
In `@tests/functional_tests/training/test_distributed_save_hf_weights.py` around
lines 135 - 240, Replace the broad "except Exception as e" in the save/reload
block with targeted exception handling and add precondition skips: at the start
of the test (before the complex setup/try) check torch.cuda.is_available() and
skip if false, and check torch.distributed.is_initialized() or expected world
size and skip if the distributed environment is not present; then restrict the
except clause around the save/load sequence to only environment-related errors
(e.g., except (RuntimeError, OSError) as e) so real test failures (assertions,
logic errors) are not swallowed; also update the test docstring to state that
CUDA and a distributed environment are required.

Comment on lines +221 to +225
if mismatched_weights:
logger.warning(f"Found {len(mismatched_weights)} mismatched weights:")
for key, diff in mismatched_weights[:10]: # Print first 10
logger.warning(f" {key}: max_diff={diff}")
assert False, f"Weight mismatch detected. Max difference: {max_diff}"

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

cat -n tests/functional_tests/training/test_distributed_save_hf_weights.py | sed -n '210,230p'

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1321


🏁 Script executed:

# Let's also check what's before to understand the context
cat -n tests/functional_tests/training/test_distributed_save_hf_weights.py | sed -n '200,235p'

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2066


🏁 Script executed:

# Check if there are similar patterns in other test files
rg "assert False" tests/ -A 1 -B 1

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 25300


🏁 Script executed:

# Check the coding guidelines for any specific guidance on assertions in tests
grep -i "assert" CODING_GUIDELINES.md 2>/dev/null || echo "File not found or no matches"

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 99


Replace assert False with an explicit exception.

assert statements can be stripped with the -O optimization flag; use raise AssertionError to ensure failure behavior is preserved.

Suggested fix
-                    assert False, f"Weight mismatch detected. Max difference: {max_diff}"
+                    raise AssertionError(f"Weight mismatch detected. Max difference: {max_diff}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
            if mismatched_weights:
                logger.warning(f"Found {len(mismatched_weights)} mismatched weights:")
                for key, diff in mismatched_weights[:10]:  # Print first 10
                    logger.warning(f"  {key}: max_diff={diff}")
                assert False, f"Weight mismatch detected. Max difference: {max_diff}"

            if mismatched_weights:
                logger.warning(f"Found {len(mismatched_weights)} mismatched weights:")
                for key, diff in mismatched_weights[:10]:  # Print first 10
                    logger.warning(f"  {key}: max_diff={diff}")
                raise AssertionError(f"Weight mismatch detected. Max difference: {max_diff}")
🧰 Tools
🪛 Ruff (0.14.14)

[warning] 225-225: Do not assert False (python -O removes these calls), raise AssertionError()

Replace assert False

(B011)

🤖 Prompt for AI Agents
In `@tests/functional_tests/training/test_distributed_save_hf_weights.py` around
lines 221 - 225, Replace the flaky assertion with an explicit exception: where
the code checks mismatched_weights (the block that logs via logger.warning and
iterates over mismatched_weights), remove "assert False, f'Weight mismatch
detected. Max difference: {max_diff}'" and instead raise an AssertionError with
the same message (e.g. raise AssertionError(f"Weight mismatch detected. Max
difference: {max_diff}")). This ensures the test always fails even when Python
is run with -O optimizations.
