Skip to content

Commit 362e46a

Browse files
committed
fix straggler detection reporting
Signed-off-by: Ananth Subramaniam <[email protected]>
1 parent 15398e0 commit 362e46a

File tree

3 files changed: +60 additions, -28 deletions

src/megatron/bridge/training/nvrx_straggler.py

Lines changed: 17 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -94,42 +94,46 @@ def wrap_train_step_function(self, train_step_func: Callable) -> Callable:
9494
"""
9595
Wrap the training step function with straggler detection monitoring.
9696
97+
The NVRx straggler detector instruments functions to measure CUDA kernel
98+
execution times. This method wraps the train_step function so that each
99+
call is profiled for straggler detection.
100+
97101
Args:
98102
train_step_func: The actual training step function to wrap for monitoring.
99103
100104
Returns:
101-
The wrapped training step function.
105+
The wrapped training step function that should be used instead of the original.
106+
If wrapping fails or is disabled, returns the original function.
102107
"""
103108

104109
if not self.initialized or not self.config.enabled:
105110
return train_step_func
106111

107112
if self.wrapped_function is not None:
108113
self.logger.warning("Train step function already wrapped. Skipping.")
109-
return train_step_func
114+
return self.wrapped_function
110115

111116
try:
112-
# Create a wrapper object with train_step method for nvidia-resiliency-ext
113-
# TODO: See if NVRx can support functions directly without needing them attached to a class
117+
# Create a wrapper object with train_step method for nvidia-resiliency-ext.
118+
# NVRx requires a method on an object, not a standalone function.
119+
# We store the wrapper object to prevent garbage collection.
114120
class TrainStepWrapper:
115121
def __init__(self, func):
116122
self.train_step = func
117-
self._original_func = func
118123

119124
def __call__(self, *args, **kwargs):
120-
return self._original_func(*args, **kwargs)
125+
return self.train_step(*args, **kwargs)
121126

122127
wrapper_obj = TrainStepWrapper(train_step_func)
123128

124129
# Create a callable ID for the training step function
130+
# wrap_callables modifies wrapper_obj.train_step in-place to add instrumentation
125131
callable_id = straggler.CallableId(wrapper_obj, "train_step")
126132
straggler.Detector.wrap_callables(callable_ids=[callable_id])
127133

128-
self.wrapped_function = train_step_func
129-
self.logger.debug("Train step function wrapped for NVRx straggler detection.")
130-
131-
# Return the original function since the wrapper is just for nvidia-resiliency-ext
132-
return train_step_func
134+
# Store the wrapped function to prevent garbage collection and enable reuse
135+
self.wrapped_function = wrapper_obj
136+
return wrapper_obj
133137

134138
except Exception as e:
135139
self.logger.warning(f"Failed to wrap train step function with NVRx: {e}. Continuing without wrapping.")
@@ -145,7 +149,6 @@ def check_stragglers(self, global_rank: int) -> bool:
145149
Returns:
146150
True if stragglers were detected and stop_if_detected is True, False otherwise.
147151
"""
148-
149152
if not self.initialized or not self.config.enabled:
150153
return False
151154

@@ -242,7 +245,7 @@ def _print_gpu_scores(self, report) -> None:
242245
num_best=self.config.num_gpu_perf_scores_to_print,
243246
num_worst=self.config.num_gpu_perf_scores_to_print,
244247
)
245-
self.logger.info(f"\nGPU relative performance:\n{rel_perf_str}")
248+
self.logger.info(f"GPU relative performance:\n{rel_perf_str}")
246249

247250
if self.config.calc_individual_gpu_perf:
248251
indiv_perf_str = self._format_gpu_scores(
@@ -251,7 +254,7 @@ def _print_gpu_scores(self, report) -> None:
251254
num_best=self.config.num_gpu_perf_scores_to_print,
252255
num_worst=self.config.num_gpu_perf_scores_to_print,
253256
)
254-
self.logger.info(f"\nGPU individual performance:\n{indiv_perf_str}")
257+
self.logger.info(f"GPU individual performance:\n{indiv_perf_str}")
255258

256259
def _log_gpu_scores(self, report) -> None:
257260
"""Log GPU performance scores as structured data."""
@@ -299,11 +302,9 @@ def _gather_flag_from_rank0(self, flag: bool) -> bool:
299302
def shutdown(self) -> None:
300303
"""Shutdown the straggler detector."""
301304
if self.initialized and self.config.enabled:
302-
self.logger.info("Shutting down NVRx straggler detection...")
303305
straggler.Detector.shutdown()
304306
self.initialized = False
305307
self.wrapped_function = None
306-
self.logger.info("NVRx straggler detection shutdown complete.")
307308

308309

309310
def check_nvrx_straggler_detection(nvrx_straggler_manager: Optional["NVRxStragglerDetectionManager"]) -> bool:

src/megatron/bridge/training/train.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -184,13 +184,14 @@ def train(
184184

185185
# Initialize NVRx straggler detection if enabled
186186
nvrx_straggler_manager = global_state.nvrx_straggler_manager
187+
wrapped_train_step = train_step # Default to original function
187188
if nvrx_straggler_manager is not None:
188189
try:
189190
# Initialize the straggler detector first
190191
nvrx_straggler_manager.initialize()
191192
# Wrap the train_step function for monitoring
192-
# Note: The nvidia-resiliency-ext library will monitor the actual train_step calls
193-
nvrx_straggler_manager.wrap_train_step_function(train_step)
193+
# The wrapped function must be used instead of the original to collect profiling data
194+
wrapped_train_step = nvrx_straggler_manager.wrap_train_step_function(train_step)
194195
except Exception as e:
195196
print_rank_0(f"Failed to initialize NVRx straggler detection: {e}")
196197
# Set to None to disable further checks
@@ -335,7 +336,7 @@ def train(
335336
grad_norm,
336337
num_zeros_in_grad,
337338
log_max_attention_logit,
338-
) = train_step(
339+
) = wrapped_train_step(
339340
wrapped_forward_step_func,
340341
train_data_iterator,
341342
model,

tests/unit_tests/training/test_nvrx_straggler.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -280,41 +280,66 @@ def dummy_func():
280280
def test_wrap_train_step_function_already_wrapped(self, manager, mock_straggler_module):
281281
"""Test wrapping when already wrapped."""
282282
manager.initialized = True
283-
manager.wrapped_function = Mock()
283+
existing_wrapper = Mock()
284+
manager.wrapped_function = existing_wrapper
284285

285286
def dummy_func():
286287
pass
287288

288289
result = manager.wrap_train_step_function(dummy_func)
289290

290-
assert result is dummy_func
291+
# Should return the already-wrapped function, not the new dummy_func
292+
assert result is existing_wrapper
291293
mock_straggler_module.Detector.wrap_callables.assert_not_called()
292294

293295
def test_wrap_train_step_function_success(self, manager, mock_straggler_module):
294296
"""Test successful wrapping."""
295297
manager.initialized = True
296298

297299
def dummy_func():
298-
pass
300+
return "original_result"
299301

300302
result = manager.wrap_train_step_function(dummy_func)
301303

302-
assert result is dummy_func
303-
assert manager.wrapped_function is dummy_func
304+
# The result should be a callable wrapper object, not the original function
305+
assert result is not dummy_func
306+
assert callable(result)
304307

305-
# Verify CallableId was called with a TrainStepWrapper object, not the original function
308+
# The wrapped_function should be the wrapper object
309+
assert manager.wrapped_function is result
310+
311+
# Verify CallableId was called with the wrapper object
306312
mock_straggler_module.CallableId.assert_called_once()
307313
call_args = mock_straggler_module.CallableId.call_args[0]
308314
assert len(call_args) == 2
309315
wrapper_obj, method_name = call_args
310316
assert method_name == "train_step"
311317

312-
# Verify the wrapper object has the train_step method that wraps our dummy_func
318+
# Verify the wrapper object has the train_step method
313319
assert hasattr(wrapper_obj, "train_step")
314-
assert wrapper_obj.train_step is dummy_func
315320

316321
mock_straggler_module.Detector.wrap_callables.assert_called_once()
317322

323+
def test_wrap_train_step_function_callable(self, manager, mock_straggler_module):
324+
"""Test that the wrapped function is callable and routes through train_step."""
325+
manager.initialized = True
326+
327+
call_count = [0]
328+
329+
def dummy_func(*args, **kwargs):
330+
call_count[0] += 1
331+
return "result"
332+
333+
result = manager.wrap_train_step_function(dummy_func)
334+
335+
# The wrapper should be callable
336+
assert callable(result)
337+
338+
# Calling the wrapper should route through train_step (which is the original func before wrap_callables)
339+
# Note: In real usage, wrap_callables modifies train_step in-place, but in tests it's mocked
340+
result()
341+
assert call_count[0] == 1
342+
318343
def test_check_stragglers_disabled(self, manager, mock_straggler_module):
319344
"""Test check_stragglers when disabled."""
320345
manager.config.enabled = False
@@ -469,6 +494,7 @@ def test_shutdown_disabled(self, manager, mock_straggler_module):
469494
def test_shutdown_success(self, manager, mock_straggler_module):
470495
"""Test successful shutdown."""
471496
manager.initialized = True
497+
manager.wrapped_function = Mock() # Simulate a wrapped function
472498

473499
manager.shutdown()
474500

@@ -584,7 +610,11 @@ def train_step():
584610
return "training"
585611

586612
wrapped_func = manager.wrap_train_step_function(train_step)
587-
assert wrapped_func is train_step
613+
# The wrapped function should NOT be the same as the original - it should be a wrapper
614+
assert wrapped_func is not train_step
615+
assert callable(wrapped_func)
616+
# The wrapper should be stored
617+
assert manager.wrapped_function is wrapped_func
588618

589619
# Test straggler detection function
590620
with patch("torch.distributed.is_initialized", return_value=False):

0 commit comments

Comments
 (0)