[pyTorch] CPU performance optimizations #2439
base: main
Conversation
/te-ci pytorch
def fast_set_attr(self, name: str, value: Any) -> None:
    self.__dict__[name] = value
I assume we are separating out this function so we can manually avoid overheads from __setattr__ and dict? Doing some benchmarking:
- Dict read: 9 ns
- Dict write: 13 ns
- Dict in: 9 ns
- dict.get: 14 ns
- Function call: 9 ns
- Class attr read: 3 ns
- Class attr write: 5 ns
- Class custom getattr: 101 ns
- Class custom setattr: 134 ns
Benchmarking script
I ran the following on a GB200 node. For the dict times, I subtracted out the overhead from list reads. For the class getattr/setattr times, I subtracted out the overhead from range.
import contextlib
import time


class Timer:
    """Measure time interval."""

    def __init__(self) -> None:
        self._start = None
        self._end = None

    def time(self) -> float:
        """CPU time interval in seconds."""
        return self._end - self._start

    @contextlib.contextmanager
    def context(self):
        """Context manager to capture time interval."""
        self._start = time.perf_counter()
        yield
        self._end = time.perf_counter()


def main() -> None:

    # Options
    iters = 1024 * 1024

    # Timer
    timer = Timer()

    # Dummy data
    str_list = ["lorem", "ipsum", "dolor", "sit", "amet", "consectetur", "adipiscing", "elit"]
    str_list = [str_list[i % len(str_list)] for i in range(iters)]
    str_dict = {s: len(s) for s in str_list}

    class PlainClass:
        def __init__(self) -> None:
            self.attr = 1

    class CustomGetattrSetattrClass:
        def __init__(self) -> None:
            self.attr = 1

        def __getattribute__(self, name):
            return super().__getattribute__(name)

        def __setattr__(self, name, val):
            super().__setattr__(name, val)

    # Timer overhead
    with timer.context():
        pass
    print(f"Timer overhead: {timer.time() * 1e9 / iters} ns/iter")

    # Range loop
    with timer.context():
        for _ in range(iters):
            pass
    print(f"Range loop: {timer.time() * 1e9 / iters} ns/iter")

    # List loop
    with timer.context():
        for _ in str_list:
            pass
    print(f"List loop: {timer.time() * 1e9 / iters} ns/iter")

    # Empty range+enumerate loop
    with timer.context():
        for i, j in enumerate(range(iters)):
            pass
    print(f"Range+enumerate loop: {timer.time() * 1e9 / iters} ns/iter")

    # Empty list+enumerate loop
    with timer.context():
        for i, s in enumerate(str_list):
            pass
    print(f"List+enumerate loop: {timer.time() * 1e9 / iters} ns/iter")

    # List reads
    with timer.context():
        for i in range(iters):
            str_list[i]
    print(f"List reads: {timer.time() * 1e9 / iters} ns/iter")

    # Dict reads
    with timer.context():
        for i in range(iters):
            str_dict[str_list[i]]
    print(f"Dict reads: {timer.time() * 1e9 / iters} ns/iter")

    # Dict get
    with timer.context():
        for i in range(iters):
            str_dict.get(str_list[i], None)
    print(f"Dict gets: {timer.time() * 1e9 / iters} ns/iter")

    # Dict writes
    with timer.context():
        for i in range(iters):
            str_dict[str_list[i]] = i
    print(f"Dict writes: {timer.time() * 1e9 / iters} ns/iter")

    # Dict membership
    with timer.context():
        for i in range(iters):
            str_list[i] in str_dict
    print(f"Dict membership: {timer.time() * 1e9 / iters} ns/iter")

    # Function call
    def func() -> None:
        pass

    with timer.context():
        for _ in range(iters):
            func()
    print(f"Function call: {timer.time() * 1e9 / iters} ns/iter")

    # Lambda call
    func = lambda: None
    with timer.context():
        for _ in range(iters):
            func()
    print(f"Lambda call: {timer.time() * 1e9 / iters} ns/iter")

    # Class attr read
    myobj = PlainClass()
    with timer.context():
        for _ in range(iters):
            _ = myobj.attr
    print(f"Class attr read: {timer.time() * 1e9 / iters} ns/iter")

    # Class attr write
    myobj = PlainClass()
    with timer.context():
        for i in range(iters):
            myobj.attr = i
    print(f"Class attr write: {timer.time() * 1e9 / iters} ns/iter")

    # getattr
    myobj = PlainClass()
    with timer.context():
        for _ in range(iters):
            getattr(myobj, "attr", None)
    print(f"getattr: {timer.time() * 1e9 / iters} ns/iter")

    # setattr
    myobj = PlainClass()
    with timer.context():
        for i in range(iters):
            setattr(myobj, "attr", i)
    print(f"setattr: {timer.time() * 1e9 / iters} ns/iter")

    # Class custom getattr
    myobj = CustomGetattrSetattrClass()
    with timer.context():
        for _ in range(iters):
            _ = myobj.attr
    print(f"Class custom getattr: {timer.time() * 1e9 / iters} ns/iter")

    # Class custom setattr
    myobj = CustomGetattrSetattrClass()
    with timer.context():
        for i in range(iters):
            myobj.attr = i
    print(f"Class custom setattr: {timer.time() * 1e9 / iters} ns/iter")


if __name__ == "__main__":
    main()

How much perf difference do you observe from fast_set_attr? I could see how it could save us ~1 us of overhead, but it would be good to make sure before making the code messier.
I don't want to comment too much on the perf results yet, since so far they all come from my machine and not a real cluster. That said, the anecdotal evidence is that a small test that just runs BF16 Linear layer forward for many iterations goes from 9.2 s to 7.7 s with the proposed code changes. fast_set_attr alone brought it to ~8.4 s.
I will test it properly and report the timings in the description of the PR.
Now, about introducing the separate function: since this is ultimately an optimization you came up with at some point, the machinery to avoid the expensive Module.__setattr__ for some parameters already existed. The problem I see is discoverability: if people do not study that code very carefully, they will not realize that they should not just do self.something = something. Therefore I think we should go a more explicit way and, in the __setattr__ of the TE module, error out with a message to either use fast_set_attr for the things we are sure are just small values (using the dict directly has some problems, by the way, since it e.g. bypasses properties), or use a new function, let's call it just set_attr, for anything where we need the full machinery.
I'd prefer not to ban self.something = something. I think readability and safety are more important for non-performance-critical things like initialization and checkpointing. It would be better to make this function an advanced internal implementation with a name like _fast_setattr.
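For reference, a minimal sketch of what such an internal helper could look like, assuming it is only ever used for plain Python attributes (not Parameters, sub-modules, or buffers); the class and name here are illustrative, not the PR's code:

from typing import Any

import torch


class _FastAttrSketch(torch.nn.Module):
    """Illustrative only: bypass torch.nn.Module.__setattr__ bookkeeping."""

    def _fast_setattr(self, name: str, value: Any) -> None:
        # nn.Module.__setattr__ checks for Parameters, Modules, and buffers
        # on every assignment; writing to __dict__ skips that work, but it
        # also skips properties and descriptors, so it is safe only for
        # plain values (the same caveat raised earlier in this thread).
        self.__dict__[name] = value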
How would we then make sure that this does not resurface in the future?
Went with explicit setattr calls and a warning that is issued when the regular __setattr__ path is used. That way users can still use regular attribute assignment if they want, but for internal development we make sure during testing that the warning does not trigger. To make the code less ugly, we only turn the warning on after the constructor has finished; that way we can still use the nice syntax during construction (where most of the occurrences are), since we do not care about the speed there.
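A minimal sketch of that gating pattern, assuming a boolean _initialized flag flipped at the end of the constructor; the names are illustrative and the PR's actual implementation appears in the diff further down:

import warnings
from typing import Any

import torch


class GatedSetattrModule(torch.nn.Module):
    """Illustrative only: warn on slow __setattr__ after construction."""

    def __init__(self) -> None:
        super().__init__()
        self.some_flag = True  # regular assignments are fine during construction
        self._initialized = True  # flip the gate last

    def fast_setattr(self, name: str, value: Any) -> None:
        # Plain-attribute fast path: skips torch.nn.Module bookkeeping.
        self.__dict__[name] = value

    def __setattr__(self, name: str, value: Any) -> None:
        if self.__dict__.get("_initialized", False):
            warnings.warn(
                f"Slow __setattr__ used for '{name}' after construction; "
                "prefer fast_setattr for plain values.",
                RuntimeWarning,
            )
        super().__setattr__(name, value)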
5eefe3e to 1c7d896 (compare)
948747b to c4e380f (compare)
/te-ci pytorch
Greptile Summary
This PR implements CPU performance optimizations for TransformerEngine PyTorch modules by reducing overhead from Python attribute assignments and context manager usage.
Key Changes:
Issues Discussed in Previous Comments:
Testing Approach:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant Linear as Linear Module
participant Base as TransformerEngineBaseModule
participant NVTX
User->>Linear: forward(inp)
Linear->>Base: prepare_forward(inp)
Base->>Base: fast_setattr() for attributes
Base->>Base: init_fp8_metadata()
Base->>NVTX: nvtx_range_push("Linear forward")
Base-->>Linear: return processed_inp
Note over Linear: Forward computation
Linear->>Linear: _get_weight_and_bias_tensors()
Linear->>Linear: _get_quantizers()
Linear->>Linear: linear_fn() - CUDA operations
Linear->>Base: end_forward()
Base->>Base: restore_fp8_meta_tensors (if needed)
Base->>NVTX: nvtx_range_pop()
Linear-->>User: return output
5 files reviewed, 5 comments
| inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor)) | ||
|
|
||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
|
|
||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
|
|
||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| linear_fn = _Linear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _Linear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| is_grad_enabled, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_name, | ||
| fp8_output, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn( | ||
| *autograd_ctx, | ||
| weight_tensor, | ||
| inp, | ||
| bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, | ||
| non_tensor_args, | ||
| ) | ||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| linear_fn = _Linear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _Linear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| is_grad_enabled, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_name, | ||
| fp8_output, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn( | ||
| *autograd_ctx, | ||
| weight_tensor, | ||
| inp, | ||
| bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, | ||
| non_tensor_args, | ||
| ) | ||
| self.end_forward() |
The refactoring from context manager to explicit prepare_forward() and end_forward() calls introduces an NVTX range stack leak risk. If an exception is raised anywhere between line 1402 (where prepare_forward() calls nvtx_range_push()) and line 1474 (where end_forward() calls nvtx_range_pop()), the NVTX range will never be popped from the stack.
The original context manager approach (with self.prepare_forward()) was safer because it guaranteed cleanup even when exceptions occurred. Without proper exception handling, NVTX ranges can accumulate on the stack, corrupting profiling data.
Consider wrapping the code between prepare_forward() and end_forward() in a try-finally block:
inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))
try:
    # ... all the processing code ...
    out = linear_fn(...)
finally:
    self.end_forward()

This same issue exists in layernorm_linear.py, layernorm_mlp.py, and grouped_linear.py.
| inp = self.prepare_forward( | ||
| inp, allow_non_contiguous=False # removed .contiguous from inside the layer | ||
| ) as inp: | ||
| ) | ||
|
|
||
| # Get concatenated weight and bias tensors | ||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
| # Get concatenated weight and bias tensors | ||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
|
|
||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
|
|
||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormLinear.forward | ||
| autograd_ctx = [None] | ||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.normalization, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_name, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, | ||
| weight_tensor, | ||
| bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, | ||
| non_tensor_args, | ||
| ) | ||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormLinear.forward | ||
| autograd_ctx = [None] | ||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.normalization, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_name, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, | ||
| weight_tensor, | ||
| bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, | ||
| non_tensor_args, | ||
| ) | ||
|
|
||
| self.end_forward() |
Same NVTX range stack leak vulnerability as in linear.py. If an exception occurs between prepare_forward() (line 1519) and end_forward() (line 1601), the NVTX range pushed in prepare_forward() will never be popped, corrupting the profiling stack.
The original context manager pattern was exception-safe. Consider using try-finally to ensure end_forward() is always called:
inp = self.prepare_forward(inp, allow_non_contiguous=False)
try:
    # ... processing code ...
    out = fwd_fn(...)
finally:
    self.end_forward()

| inp = self.prepare_forward(inp, num_gemms=2) | ||
|
|
||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, is_grad_enabled) | ||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, is_grad_enabled) | ||
|
|
||
| # Get quantizers | ||
| ( | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| # Get weight tensors | ||
| fc1_weight, fc2_weight = self._get_weight_tensors() | ||
| fc1_bias = self.fc1_bias if self.use_bias else None | ||
| fc2_bias = self.fc2_bias if self.use_bias else None | ||
| if not self.fp8: | ||
| if isinstance(fc1_weight, Float8Tensor): | ||
| fc1_weight = fc1_weight.dequantize() | ||
| if isinstance(fc2_weight, Float8Tensor): | ||
| fc2_weight = fc2_weight.dequantize() | ||
|
|
||
| # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode | ||
| if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): | ||
| self.bias_gelu_nvfusion = False | ||
| # Get quantizers | ||
| ( | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormMLP.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormMLP.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| self.bias_gelu_nvfusion and not self.fp8 and not debug, | ||
| self.set_parallel_mode, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.activation, | ||
| self.activation_params, | ||
| self.normalization, | ||
| self.ub_overlap_ag, | ||
| self.ub_overlap_rs, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.gemm_gelu_fusion and not debug, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.checkpoint, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, | ||
| fc1_weight, | ||
| fc1_bias, | ||
| fc2_weight, | ||
| fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None, | ||
| non_tensor_args, | ||
| ) | ||
| # Get weight tensors | ||
| fc1_weight, fc2_weight = self._get_weight_tensors() | ||
| fc1_bias = self.fc1_bias if self.use_bias else None | ||
| fc2_bias = self.fc2_bias if self.use_bias else None | ||
| if not self.fp8: | ||
| if isinstance(fc1_weight, Float8Tensor): | ||
| fc1_weight = fc1_weight.dequantize() | ||
| if isinstance(fc2_weight, Float8Tensor): | ||
| fc2_weight = fc2_weight.dequantize() | ||
|
|
||
| # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode | ||
| if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): | ||
| self.bias_gelu_nvfusion = False | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormMLP.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormMLP.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| self.bias_gelu_nvfusion and not self.fp8 and not debug, | ||
| self.set_parallel_mode, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.activation, | ||
| self.activation_params, | ||
| self.normalization, | ||
| self.ub_overlap_ag, | ||
| self.ub_overlap_rs, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.gemm_gelu_fusion and not debug, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.checkpoint, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, |
Same NVTX range stack leak vulnerability as in linear.py and layernorm_linear.py. If an exception occurs between prepare_forward() (line 2057) and end_forward() (line 2159), the NVTX range will leak.
Consider using try-finally to ensure cleanup:
inp = self.prepare_forward(inp, num_gemms=2)
try:
    # ... processing code ...
    out = fwd_fn(...)
finally:
    self.end_forward()

| inp = self.prepare_forward(inp, num_gemms=self.num_gemms) | ||
| weight_tensors = self._get_weight_tensors() | ||
| bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] | ||
|
|
||
| quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() | ||
| quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() | ||
|
|
||
| if debug: | ||
| if self.no_debug_features_active(list(chain(*quantizers))): | ||
| debug = False | ||
| quantizers = self._get_quantizers() | ||
| if debug: | ||
| if self.no_debug_features_active(list(chain(*quantizers))): | ||
| debug = False | ||
| quantizers = self._get_quantizers() | ||
|
|
||
| if isinstance(weight_tensors, QuantizedTensorStorage): | ||
| raise RuntimeError("FP8 weights are not supported in debug mode.") | ||
| if isinstance(weight_tensors, QuantizedTensorStorage): | ||
| raise RuntimeError("FP8 weights are not supported in debug mode.") | ||
|
|
||
| ( | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| ) = quantizers | ||
| ( | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| linear_fn = _GroupedLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _GroupedLinear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| m_splits, | ||
| self.apply_bias, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.sequence_parallel, | ||
| self.activation_dtype, | ||
| is_grad_enabled, | ||
| self, | ||
| None, # skip_fp8_weight_update | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) | ||
| if is_grad_enabled: | ||
| linear_fn = _GroupedLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _GroupedLinear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| m_splits, | ||
| self.apply_bias, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.sequence_parallel, | ||
| self.activation_dtype, | ||
| is_grad_enabled, | ||
| self, | ||
| None, # skip_fp8_weight_update | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) | ||
|
|
Same NVTX range stack leak vulnerability. If an exception occurs between prepare_forward() (line 793) and end_forward() (line 847), the NVTX range will leak.
Consider using try-finally to ensure cleanup:
inp = self.prepare_forward(inp, num_gemms=self.num_gemms)
try:
    # ... processing code ...
    out = linear_fn(...)
finally:
    self.end_forward()
Greptile Overview
Greptile Summary
This PR attempts to optimize CPU overhead by:
Critical Issues Found
1. NVTX Range Imbalance on Exceptions (HIGH SEVERITY)
The refactoring from context managers to manual prepare_forward()/end_forward() calls leaves the NVTX range unpopped if an exception is raised between the two calls.
2. setattr Doesn't Actually Optimize (HIGH SEVERITY)
The new __setattr__ override warns but still takes the slow torch.nn.Module path.
3. Multiple RuntimeWarning Violations (CRITICAL SEVERITY)
Six locations in base.py use direct attribute assignment after initialization (lines 965, 966, 1558, 1559, 1565, 1581, 1608). Since pytest.ini now treats RuntimeWarnings as errors, all tests will fail.
Positive Aspects
Recommendation
This PR cannot be merged in its current state due to the RuntimeWarning violations that will cause all tests to fail. The NVTX exception safety issue is also critical for production use.
Confidence Score: 0/5
Important Files Changed
File Analysis
Sequence Diagram
sequenceDiagram
participant User
participant Module as Linear/LayerNorm Module
participant Base as TransformerEngineBaseModule
participant NVTX as NVTX Range Stack
participant FP8 as FP8GlobalStateManager
User->>Module: forward(inp)
Module->>Base: prepare_forward(inp)
Base->>Base: init_fp8_metadata()
Base->>Base: get_amp_dtype()
Base->>FP8: Copy FP8 metadata (if recompute)
Base->>NVTX: nvtx_range_push("Module forward")
Note over NVTX: Range pushed - needs cleanup!
Base-->>Module: returns processed inp
alt Exception occurs during forward
Module->>Module: Compute quantizers/weights
Module--xModule: Exception raised!
Note over NVTX: ⚠️ BUG: nvtx_range_pop() never called!
Note over NVTX: Stack becomes imbalanced
else Normal execution
Module->>Module: Compute quantizers/weights
Module->>Module: Call linear_fn()
Module->>Base: end_forward()
Base->>FP8: Restore FP8 metadata (if recompute)
Base->>NVTX: nvtx_range_pop()
Note over NVTX: Range properly cleaned up
Base-->>Module: return
Module-->>User: return output
end
7 files reviewed, 7 comments
def end_forward(self):
    """
    Required to be called at the end of the forward function to properly handle
    DelayedScaling metadata handling and the NVTX ranges.
    """
    delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
    if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase():
        FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
    nvtx_range_pop()
The new non-context-manager approach to prepare_forward() and end_forward() creates a critical bug: if an exception is raised between the prepare_forward() call and the end_forward() call in any forward method, the NVTX range will never be popped, leading to an imbalanced NVTX stack.
The original context manager approach with with self.prepare_forward(...) as inp: guaranteed that cleanup would happen even on exceptions. The new approach breaks this guarantee.
Impact: All forward methods in Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear are affected. Any exception during forward (e.g., shape mismatch, CUDA OOM, assertion failure) will leave the NVTX range stack corrupted.
Recommendation: Either:
- Revert to the context manager approach (accepting the performance cost), or
- Wrap the forward logic in try/finally to ensure end_forward() is always called:
inp = self.prepare_forward(inp, ...)
try:
    # forward logic here
    out = linear_fn(...)
finally:
    self.end_forward()

def __setattr__(self, name: str, value: Any) -> None:
    if name in TransformerEngineBaseModule._fast_setattr_names:
        # torch.nn.Module has a custom __setattr__ that handles
        # modules, parameters, and buffers. This is unnecessary
        # overhead when setting plain attrs.
        self.__dict__[name] = value
    else:
        # Default case
        if "_initialized" in self.__dict__ and self._initialized:
            warnings.warn(
                """The default implementation of torch.nn.Module introduces significant CPU overhead
                when setting attributes and is therefore not recommended. Please use the explicit
                calls (fast_setattr for setting regular values and module_setattr for setting
                parameters, children modules and buffers).""",
                RuntimeWarning,
            )
        super().__setattr__(name, value)
The __setattr__ override doesn't actually optimize anything - it still calls super().__setattr__(name, value) after emitting the warning. This means every attribute assignment still goes through the slow PyTorch Module path, defeating the purpose of the optimization.
The warning is also emitted on EVERY attribute assignment after initialization, which would be extremely noisy and unhelpful for debugging.
Additionally, there are multiple places in the codebase that still use direct attribute assignment after initialization:
- Line 965: self.tp_group = tp_group in set_tensor_parallel_group()
- Line 966: self.tp_group_initialized = True
- Lines 1558, 1559, 1565, 1581, 1608: Various debug-related attribute assignments
Since pytest.ini now treats RuntimeWarnings as errors, all of these will cause test failures.
Recommendation: Either:
- Remove the __setattr__ override entirely and rely on developers remembering to use fast_setattr, or
- Update all the remaining direct assignments to use fast_setattr or module_setattr
|
8 files reviewed, 8 comments
nvtx_range_push(self.__class__.__name__ + " forward")
if not allow_non_contiguous and not inp.is_contiguous():
    inp = inp.contiguous()
return inp
The prepare_forward() method now uses nvtx_range_push() but the corresponding nvtx_range_pop() is in end_forward(). This creates a critical issue: if an exception occurs in the forward pass between prepare_forward() and end_forward() (e.g., shape mismatch, CUDA OOM, assertion failures), the NVTX range will never be popped, corrupting the NVTX profiling stack.
While the old context manager pattern guaranteed cleanup via __exit__, the new pattern requires manual calls. Consider wrapping the entire forward body in try-finally or document that modules must ensure end_forward() is called even on exceptions.
def __setattr__(self, name: str, value: Any) -> None:
    if name in TransformerEngineBaseModule._fast_setattr_names:
        # torch.nn.Module has a custom __setattr__ that handles
        # modules, parameters, and buffers. This is unnecessary
        # overhead when setting plain attrs.
        self.__dict__[name] = value
    else:
        # Default case
        if "_initialized" in self.__dict__ and self._initialized:
            warnings.warn(
                """The default implementation of torch.nn.Module introduces significant CPU overhead
                when setting attributes and is therefore not recommended. Please use the explicit
                calls (fast_setattr for setting regular values and module_setattr for setting
                parameters, children modules and buffers).""",
                RuntimeWarning,
            )
        super().__setattr__(name, value)
The __setattr__ override emits a RuntimeWarning every time it's called after initialization, but it still calls super().__setattr__() to perform the assignment. This means:
- The warning is emitted but the "slow" operation still happens - there's no actual optimization here
- With pytest.ini treating RuntimeWarning as error, any accidental direct attribute assignment will cause test failures
- The warning message suggests using fast_setattr but doesn't prevent the slow path
If the goal is CPU optimization, this should either:
- Raise an exception instead of warning (fail-fast), OR
- Actually use the fast path by default after initialization
The current implementation warns about performance but doesn't fix it.
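A minimal sketch of the fail-fast variant suggested above, assuming the same _initialized gate; this is one possible direction, not the code in this PR:

from typing import Any

import torch


class StrictSetattrModule(torch.nn.Module):
    """Illustrative only: reject slow __setattr__ after construction."""

    def __init__(self) -> None:
        super().__init__()
        self._initialized = True  # flip the gate at the end of __init__

    def fast_setattr(self, name: str, value: Any) -> None:
        # Intended for plain values only; bypasses nn.Module bookkeeping.
        self.__dict__[name] = value

    def __setattr__(self, name: str, value: Any) -> None:
        if self.__dict__.get("_initialized", False):
            raise AttributeError(
                f"Direct assignment to '{name}' after construction is disabled; "
                "use fast_setattr for plain values or an explicit module setter."
            )
        super().__setattr__(name, value)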
filterwarnings=
    error::RuntimeWarning
Treating RuntimeWarning as an error will cause test failures for any code that uses direct attribute assignment after module initialization. Several locations in base.py still use direct assignment (see other comments), which will fail with this configuration.
This is appropriate for enforcing the fast_setattr() pattern, but all existing direct assignments must be fixed first, or this change should be in a separate PR after fixing all assignment sites.
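If turning every RuntimeWarning into an error is too broad, the filter could also be scoped to just this warning; a sketch, assuming pytest's filterwarnings entries treat the message field as a regular expression matched against the start of the warning text:

[pytest]
filterwarnings =
    error:The default implementation of torch\.nn\.Module:RuntimeWarning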
Additional Comments (5)
This method is called after module initialization in user code, so it will fail. The assignments should use fast_setattr instead.
These should use fast_setattr as well.
6 files reviewed, 6 comments
| def prepare_forward( | ||
| self, | ||
| inp: torch.Tensor, | ||
| num_gemms: int = 1, | ||
| allow_non_contiguous: bool = False, | ||
| allow_different_data_and_param_types: bool = False, | ||
| ) -> Generator[torch.Tensor, None, None]: | ||
| """Checks and prep for FWD. | ||
| The context manager is needed because there isn't a way for a module to know | ||
| if it's the last FP8 module in the forward autocast. It is useful | ||
| to setup the forward aggregated amax reduction for every module | ||
| just in case. The autocast exit will pick up the most recent one. | ||
| """ | ||
| self.allow_different_data_and_param_types = allow_different_data_and_param_types | ||
| self.forwarded_at_least_once = True | ||
| ) -> torch.Tensor: | ||
| """Checks and prepares for FWD execution.""" | ||
| self.fast_setattr( | ||
| "allow_different_data_and_param_types", allow_different_data_and_param_types | ||
| ) | ||
| self.fast_setattr("forwarded_at_least_once", True) | ||
|
|
||
| # Activation recomputation is used and this is the second forward phase. | ||
| if self.fp8 and in_fp8_activation_recompute_phase(): | ||
| delayed_scaling_recipe = self.fp8_meta["recipe"].delayed() | ||
| FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) | ||
| else: | ||
| assert inp.is_cuda, "TransformerEngine needs CUDA." | ||
|
|
||
| if self.tp_size > 1: | ||
| assert self.tp_group_initialized, "TP group not initialized." | ||
|
|
||
| self.set_activation_dtype(inp) | ||
| self.init_fp8_metadata(num_gemms=num_gemms) | ||
| self._check_weight_tensor_recipe_correspondence() | ||
|
|
||
| delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() | ||
| if delayed_scaling_recipe: | ||
| if self.sequence_parallel: | ||
| assert self.fp8_meta["recipe"].reduce_amax, ( | ||
| "Amax reduction across tensor parallel group is " | ||
| "necessary when using sequence parallelism with FP8." | ||
| ) | ||
|
|
||
| if not FP8GlobalStateManager.fp8_graph_capturing(): | ||
| FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) | ||
|
|
||
| # Activation recomputation is used and this is the first forward phase. | ||
| if self.training and is_fp8_activation_recompute_enabled(): | ||
| FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) | ||
|
|
||
| with get_nvtx_range_context(self.__class__.__name__ + " forward"): | ||
| if not allow_non_contiguous and not inp.is_contiguous(): | ||
| inp = inp.contiguous() | ||
| yield inp | ||
| nvtx_range_push(self.__class__.__name__ + " forward") | ||
| if not allow_non_contiguous and not inp.is_contiguous(): | ||
| inp = inp.contiguous() | ||
| return inp | ||
|
|
||
| def end_forward(self): | ||
| """ | ||
| Required to be called at the end of the forward function to properly handle | ||
| DelayedScaling metadata handling and the NVTX ranges. | ||
| """ | ||
| delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() | ||
| if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase(): | ||
| FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) | ||
| nvtx_range_pop() |
CRITICAL: NVTX Stack Corruption on Exceptions
The refactoring from a context manager to manual prepare_forward()/end_forward() calls breaks exception safety. The context manager's __exit__ method guaranteed that cleanup code would run even if exceptions occurred. Now, if ANY exception is raised between these calls, nvtx_range_pop() will never be called, corrupting the NVTX profiling stack.
Exception scenarios that will corrupt the stack:
- Line 1058: assert inp.is_cuda - AssertionError if input is on CPU
- Line 1061: assert self.tp_group_initialized - AssertionError if TP not initialized
- Line 1070: assert self.fp8_meta["recipe"].reduce_amax - AssertionError for sequence parallel
Solution: Either:
- Wrap the entire forward body in try-finally to guarantee end_forward() is called, OR
- Revert to using the context manager (prepare_forward_ctx), OR
- Remove all assertions and exceptions from the code path between these calls (not realistic)
The attention module correctly uses prepare_forward_ctx (the context manager version) which avoids this issue.
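For comparison, a minimal sketch of what a context-manager wrapper over the two manual calls could look like, assuming prepare_forward and end_forward keep the signatures shown in this diff; the actual prepare_forward_ctx in the PR may differ:

import contextlib


class _ForwardCtxSketch:
    """Illustrative mixin only: exception-safe wrapper around the manual calls."""

    # The host module is expected to provide prepare_forward() and end_forward().
    @contextlib.contextmanager
    def prepare_forward_ctx(self, inp, **kwargs):
        inp = self.prepare_forward(inp, **kwargs)
        try:
            yield inp
        finally:
            # Runs even if the forward body raises, so nvtx_range_pop()
            # and the FP8 metadata restore always happen.
            self.end_forward()

Hot paths could keep the explicit calls, while the context manager stays available for non-performance-critical modules.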
| inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor)) | ||
|
|
||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
|
|
||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
|
|
||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| linear_fn = _Linear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _Linear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| is_grad_enabled, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_name, | ||
| fp8_output, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn( | ||
| *autograd_ctx, | ||
| weight_tensor, | ||
| inp, | ||
| bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, | ||
| non_tensor_args, | ||
| ) | ||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| linear_fn = _Linear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _Linear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| is_grad_enabled, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_name, | ||
| fp8_output, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn( | ||
| *autograd_ctx, | ||
| weight_tensor, | ||
| inp, | ||
| bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, | ||
| non_tensor_args, | ||
| ) | ||
| self.end_forward() |
CRITICAL: Exception safety issue - NVTX stack corruption risk
Between prepare_forward() (which calls nvtx_range_push) and end_forward() (which calls nvtx_range_pop), there are multiple operations that can raise exceptions:
- _get_weight_and_bias_tensors() - could fail with various errors
- _get_quantizers() / _get_debug_quantizers() - could raise exceptions
- linear_fn() - the actual forward computation can fail with CUDA errors, shape mismatches, OOM, etc.
If ANY exception occurs, end_forward() won't be called, leaving the NVTX stack corrupted.
Solution: Wrap in try-finally or use the context manager version (prepare_forward_ctx) like the attention module does.
| inp = self.prepare_forward( | ||
| inp, allow_non_contiguous=False # removed .contiguous from inside the layer | ||
| ) as inp: | ||
| ) | ||
|
|
||
| # Get concatenated weight and bias tensors | ||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
| # Get concatenated weight and bias tensors | ||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
|
|
||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
|
|
||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormLinear.forward | ||
| autograd_ctx = [None] | ||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.normalization, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_name, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, | ||
| weight_tensor, | ||
| bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, | ||
| non_tensor_args, | ||
| ) | ||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormLinear.forward | ||
| autograd_ctx = [None] | ||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.normalization, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_name, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, |
CRITICAL: Exception safety issue - NVTX stack corruption risk
Same issue as Linear module. Between prepare_forward() and end_forward(), multiple operations can raise exceptions, leaving the NVTX stack corrupted if they do.
Solution: Wrap in try-finally or use the context manager version (prepare_forward_ctx).
| inp = self.prepare_forward(inp, num_gemms=2) | ||
|
|
||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, is_grad_enabled) | ||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, is_grad_enabled) | ||
|
|
||
| # Get quantizers | ||
| ( | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| # Get weight tensors | ||
| fc1_weight, fc2_weight = self._get_weight_tensors() | ||
| fc1_bias = self.fc1_bias if self.use_bias else None | ||
| fc2_bias = self.fc2_bias if self.use_bias else None | ||
| if not self.fp8: | ||
| if isinstance(fc1_weight, Float8Tensor): | ||
| fc1_weight = fc1_weight.dequantize() | ||
| if isinstance(fc2_weight, Float8Tensor): | ||
| fc2_weight = fc2_weight.dequantize() | ||
|
|
||
| # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode | ||
| if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): | ||
| self.bias_gelu_nvfusion = False | ||
| # Get quantizers | ||
| ( | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormMLP.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormMLP.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| self.bias_gelu_nvfusion and not self.fp8 and not debug, | ||
| self.set_parallel_mode, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.activation, | ||
| self.activation_params, | ||
| self.normalization, | ||
| self.ub_overlap_ag, | ||
| self.ub_overlap_rs, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.gemm_gelu_fusion and not debug, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.checkpoint, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, | ||
| fc1_weight, | ||
| fc1_bias, | ||
| fc2_weight, | ||
| fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None, | ||
| non_tensor_args, | ||
| ) | ||
| # Get weight tensors | ||
| fc1_weight, fc2_weight = self._get_weight_tensors() | ||
| fc1_bias = self.fc1_bias if self.use_bias else None | ||
| fc2_bias = self.fc2_bias if self.use_bias else None | ||
| if not self.fp8: | ||
| if isinstance(fc1_weight, Float8Tensor): | ||
| fc1_weight = fc1_weight.dequantize() | ||
| if isinstance(fc2_weight, Float8Tensor): | ||
| fc2_weight = fc2_weight.dequantize() | ||
|
|
||
| # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode | ||
| if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): | ||
| self.fast_setattr("bias_gelu_nvfusion", False) | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormMLP.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormMLP.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| self.bias_gelu_nvfusion and not self.fp8 and not debug, | ||
| self.set_parallel_mode, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.activation, | ||
| self.activation_params, | ||
| self.normalization, | ||
| self.ub_overlap_ag, | ||
| self.ub_overlap_rs, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.gemm_gelu_fusion and not debug, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.checkpoint, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, | ||
| fc1_weight, | ||
| fc1_bias, | ||
| fc2_weight, | ||
| fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None, | ||
| non_tensor_args, | ||
| ) | ||
|
|
||
| self.end_forward() |
CRITICAL: Exception safety issue - NVTX stack corruption risk
Same exception safety issue. The forward computation between prepare_forward() and end_forward() can raise exceptions, corrupting the NVTX stack.
Solution: Wrap in try-finally or use the context manager version (prepare_forward_ctx).
| inp = self.prepare_forward(inp, num_gemms=self.num_gemms) | ||
| weight_tensors = self._get_weight_tensors() | ||
| bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] | ||
|
|
||
| quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() | ||
| quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() | ||
|
|
||
| if debug: | ||
| if self.no_debug_features_active(list(chain(*quantizers))): | ||
| debug = False | ||
| quantizers = self._get_quantizers() | ||
| if debug: | ||
| if self.no_debug_features_active(list(chain(*quantizers))): | ||
| debug = False | ||
| quantizers = self._get_quantizers() | ||
|
|
||
| if isinstance(weight_tensors, QuantizedTensorStorage): | ||
| raise RuntimeError("FP8 weights are not supported in debug mode.") | ||
| if isinstance(weight_tensors, QuantizedTensorStorage): | ||
| raise RuntimeError("FP8 weights are not supported in debug mode.") | ||
|
|
||
| ( | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| ) = quantizers | ||
| ( | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| linear_fn = _GroupedLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _GroupedLinear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| m_splits, | ||
| self.apply_bias, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.sequence_parallel, | ||
| self.activation_dtype, | ||
| is_grad_enabled, | ||
| self, | ||
| None, # skip_fp8_weight_update | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) | ||
| if is_grad_enabled: | ||
| linear_fn = _GroupedLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _GroupedLinear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| m_splits, | ||
| self.apply_bias, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.sequence_parallel, | ||
| self.activation_dtype, | ||
| is_grad_enabled, | ||
| self, | ||
| None, # skip_fp8_weight_update | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) | ||
|
|
||
| self.end_forward() |
CRITICAL: Exception safety issue - NVTX stack corruption risk
Same exception safety issue as other modules. Any exception between prepare_forward() and end_forward() will corrupt the NVTX stack.
Solution: Wrap in try-finally or use the context manager version (prepare_forward_ctx).
| def fast_setattr(self, name: str, value: Any) -> None: | ||
| """ | ||
| Fast version of the Module's set attribute function. | ||
| Should be used for regular attributes, but not properties nor parameters/buffers. | ||
| """ | ||
| self.__dict__[name] = value | ||
|
|
||
| def module_setattr(self, name: str, value: Any) -> None: | ||
| """ | ||
| Regular version of the Module's set attribute function. | ||
| Should be used only when the fast version cannot be used - for the properties, | ||
| parameters and buffers. | ||
| """ | ||
| super().__setattr__(name, value) | ||
|
|
||
| def __setattr__(self, name: str, value: Any) -> None: | ||
| if name in TransformerEngineBaseModule._fast_setattr_names: | ||
| # torch.nn.Module has a custom __setattr__ that handles | ||
| # modules, parameters, and buffers. This is unnecessary | ||
| # overhead when setting plain attrs. | ||
| self.__dict__[name] = value | ||
| else: | ||
| # Default case | ||
| super().__setattr__(name, value) | ||
| if "_initialized" in self.__dict__ and self._initialized: | ||
| warnings.warn( | ||
| """The default implementation of torch.nn.Module introduces significant CPU overhead | ||
| when setting attributes and is therefore not recommended. Please use the explicit | ||
| calls (fast_setattr for setting regular values and module_setattr for setting | ||
| parameters, children modules and buffers).""", | ||
| RuntimeWarning, | ||
| ) |
Design Issue: The setattr override doesn't actually optimize anything
The new __setattr__ implementation still calls super().__setattr__(), which means it still incurs the full PyTorch Module overhead. The warning is issued but the slow path is always taken.
To actually optimize, the code should:
- Skip the warning during __init__ (check if "_initialized" not in self.__dict__)
- After init, require explicit use of fast_setattr() or module_setattr()
- Consider making __setattr__ raise an error instead of a warning to enforce the pattern
Currently, this just adds overhead (checking _initialized + issuing the warning) on top of the existing slow path.
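For reference, a minimal sketch of the behaviour this comment asks for, assuming the _fast_setattr_names allow-list and the _initialized flag from this diff; the attribute names placed in the set here are made up for illustration.

import torch

class FastAttrModule(torch.nn.Module):
    # Plain (non-parameter, non-module) attributes that may skip the slow path.
    _fast_setattr_names = {"fp8", "activation_dtype", "bias_gelu_nvfusion"}

    def __init__(self) -> None:
        super().__init__()
        # Written directly so it does not recurse through __setattr__.
        self.__dict__["_initialized"] = True

    def __setattr__(self, name, value) -> None:
        if self.__dict__.get("_initialized") and name in self._fast_setattr_names:
            # Fast path: plain dict write, no torch.nn.Module bookkeeping.
            self.__dict__[name] = value
        else:
            # Parameters, buffers, submodules, and anything set during __init__.
            super().__setattr__(name, value)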
| def train(self, mode: bool = True) -> Self: | ||
| with warnings.catch_warnings(): | ||
| warnings.filterwarnings("ignore", ".*CPU overhead.*") | ||
| return super().train(mode) |
Workaround Needed: train() method silences the RuntimeWarning
This method filters out the "CPU overhead" warning because PyTorch's Module.train() internally uses __setattr__ to set the training flag on all child modules. This is a workaround for the fact that the base implementation can't be avoided in some cases.
This reveals a potential issue: other PyTorch internal methods that use __setattr__ will also trigger warnings. Consider:
- to() device transfers
- half(), float(), bfloat16() dtype conversions
- register_module(), register_buffer(), register_parameter()
These may all need similar workarounds.
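One possible shape for such a workaround, sketched as a mixin for torch.nn.Module subclasses; it is illustrative rather than the PR's code, and assumes `training` stays a plain instance attribute so a direct __dict__ write avoids both the warning and torch.nn.Module.__setattr__.

import torch

class FastTrainMixin:
    """Mixin for torch.nn.Module subclasses; assumes `training` is a plain attr."""

    def train(self, mode: bool = True):
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        for module in self.modules():  # self plus all descendants
            module.__dict__["training"] = mode
        return self

    def eval(self):
        return self.train(False)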
| def prepare_forward( | ||
| self, | ||
| inp: torch.Tensor, | ||
| num_gemms: int = 1, | ||
| allow_non_contiguous: bool = False, | ||
| allow_different_data_and_param_types: bool = False, | ||
| ) -> Generator[torch.Tensor, None, None]: | ||
| """Checks and prep for FWD. | ||
| The context manager is needed because there isn't a way for a module to know | ||
| if it's the last FP8 module in the forward autocast. It is useful | ||
| to setup the forward aggregated amax reduction for every module | ||
| just in case. The autocast exit will pick up the most recent one. | ||
| """ | ||
| self.allow_different_data_and_param_types = allow_different_data_and_param_types | ||
| self.forwarded_at_least_once = True | ||
| ) -> torch.Tensor: | ||
| """Checks and prepares for FWD execution.""" | ||
| self.fast_setattr( | ||
| "allow_different_data_and_param_types", allow_different_data_and_param_types | ||
| ) | ||
| self.fast_setattr("forwarded_at_least_once", True) | ||
|
|
||
| # Activation recomputation is used and this is the second forward phase. | ||
| if self.fp8 and in_fp8_activation_recompute_phase(): | ||
| delayed_scaling_recipe = self.fp8_meta["recipe"].delayed() | ||
| FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) | ||
| else: | ||
| assert inp.is_cuda, "TransformerEngine needs CUDA." | ||
|
|
||
| if self.tp_size > 1: | ||
| assert self.tp_group_initialized, "TP group not initialized." | ||
|
|
||
| self.set_activation_dtype(inp) | ||
| self.init_fp8_metadata(num_gemms=num_gemms) | ||
| self._check_weight_tensor_recipe_correspondence() | ||
|
|
||
| delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() | ||
| if delayed_scaling_recipe: | ||
| if self.sequence_parallel: | ||
| assert self.fp8_meta["recipe"].reduce_amax, ( | ||
| "Amax reduction across tensor parallel group is " | ||
| "necessary when using sequence parallelism with FP8." | ||
| ) | ||
|
|
||
| if not FP8GlobalStateManager.fp8_graph_capturing(): | ||
| FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) | ||
|
|
||
| # Activation recomputation is used and this is the first forward phase. | ||
| if self.training and is_fp8_activation_recompute_enabled(): | ||
| FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) | ||
|
|
||
| with get_nvtx_range_context(self.__class__.__name__ + " forward"): | ||
| if not allow_non_contiguous and not inp.is_contiguous(): | ||
| inp = inp.contiguous() | ||
| yield inp | ||
| nvtx_range_push(self.__class__.__name__ + " forward") | ||
| if not allow_non_contiguous and not inp.is_contiguous(): | ||
| inp = inp.contiguous() | ||
| return inp | ||
|
|
||
| def end_forward(self): | ||
| """ | ||
| Required to be called at the end of the forward function to properly handle | ||
| DelayedScaling metadata handling and the NVTX ranges. | ||
| """ | ||
| delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() | ||
| if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase(): | ||
| FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) | ||
| nvtx_range_pop() | ||
|
|
[P0] NVTX range imbalance on exceptions: The refactoring from context manager to explicit prepare_forward()/end_forward() calls breaks exception safety. If an exception occurs between prepare_forward() (which calls nvtx_range_push) and end_forward() (which calls nvtx_range_pop), the NVTX range stack will be corrupted.
This affects all forward methods in Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear modules. Any exception (shape mismatch, CUDA OOM, assertion failure, etc.) will leave an unclosed NVTX range.
The original context manager pattern guaranteed cleanup via __exit__. The new pattern requires all code paths to call end_forward(), but exceptions bypass this.
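For illustration, a thin wrapper with the guarantee described here; prepare_forward_ctx is only named in this review, so the body below is a guess at its shape rather than the actual API.

import contextlib

@contextlib.contextmanager
def prepare_forward_ctx(module, inp, **kwargs):
    inp = module.prepare_forward(inp, **kwargs)  # pushes the NVTX range
    try:
        yield inp
    finally:
        module.end_forward()  # pops the range even if the body raises

# Usage inside a forward method would then look like:
#     with prepare_forward_ctx(self, inp, num_gemms=2) as inp:
#         out = fwd_fn(*autograd_ctx, inp, ..., non_tensor_args)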
| def __setattr__(self, name: str, value: Any) -> None: | ||
| if name in TransformerEngineBaseModule._fast_setattr_names: | ||
| # torch.nn.Module has a custom __setattr__ that handles | ||
| # modules, parameters, and buffers. This is unnecessary | ||
| # overhead when setting plain attrs. | ||
| self.__dict__[name] = value | ||
| else: | ||
| # Default case | ||
| super().__setattr__(name, value) | ||
| if "_initialized" in self.__dict__ and self._initialized: | ||
| warnings.warn( | ||
| """The default implementation of torch.nn.Module introduces significant CPU overhead | ||
| when setting attributes and is therefore not recommended. Please use the explicit | ||
| calls (fast_setattr for setting regular values and module_setattr for setting | ||
| parameters, children modules and buffers).""", | ||
| RuntimeWarning, | ||
| ) | ||
| super().__setattr__(name, value) |
[P1] setattr warning defeats the optimization: The __setattr__ override emits a RuntimeWarning when _initialized=True, which pytest.ini treats as an error. However, __setattr__ still calls super().__setattr__(), meaning it doesn't actually prevent or optimize anything - it just warns.
If the goal is CPU optimization, the warning should either:
- Be accompanied by raising an exception to force using fast_setattr(), or
- Implement the fast path directly in __setattr__ when appropriate
Currently, any test that triggers the warning will fail, but production code will just emit warnings with no actual optimization benefit from this implementation.
Additional Comments (6)
And at line 1571: If the warnings are genuinely problematic here, the underlying issue should be fixed rather than suppressed.
|
4 files reviewed, 4 comments
| nvtx_range_push(self.__class__.__name__ + " forward") | ||
| if not allow_non_contiguous and not inp.is_contiguous(): | ||
| inp = inp.contiguous() | ||
| return inp | ||
|
|
||
| def end_forward(self): | ||
| """ | ||
| Required to be called at the end of the forward function to properly handle | ||
| DelayedScaling metadata handling and the NVTX ranges. | ||
| """ | ||
| delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() | ||
| if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase(): | ||
| FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) | ||
| nvtx_range_pop() |
CRITICAL: NVTX range imbalance on exceptions - will corrupt profiling stack
The refactoring from context manager (with get_nvtx_range_context(...)) to explicit nvtx_range_push() / nvtx_range_pop() calls breaks exception safety. If an exception occurs between prepare_forward() (which calls nvtx_range_push) and end_forward() (which calls nvtx_range_pop), the NVTX range stack will be permanently corrupted.
The original context manager pattern guaranteed cleanup via __exit__. The new pattern requires all code paths to call end_forward(), but exceptions bypass this.
This affects all forward methods in Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear modules. Any exception (shape mismatch, CUDA OOM, assertion failure, etc.) will leave an unclosed NVTX range, affecting all subsequent profiling.
Recommended fix: Either:
- Restore the context manager pattern, or
- Wrap all forward pass code in try/finally blocks to ensure nvtx_range_pop() is called (see the sketch below), or
- Document that forward passes must be wrapped in exception handlers that call end_forward() on errors
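A decorator is another way to keep the pairing unbreakable: own the NVTX push/pop entirely inside one wrapper around forward(). This is a sketch under the assumption that the range is the only cleanup needed; the delayed-scaling restore in end_forward() would still have to be handled separately.

import functools
import torch

def nvtx_forward_range(forward_fn):
    """Wrap a forward() so its NVTX range is always closed, even on exceptions."""

    @functools.wraps(forward_fn)
    def wrapper(self, *args, **kwargs):
        torch.cuda.nvtx.range_push(type(self).__name__ + " forward")
        try:
            return forward_fn(self, *args, **kwargs)
        finally:
            torch.cuda.nvtx.range_pop()

    return wrapper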
| def __setattr__(self, name: str, value: Any) -> None: | ||
| if name in TransformerEngineBaseModule._fast_setattr_names: | ||
| # torch.nn.Module has a custom __setattr__ that handles | ||
| # modules, parameters, and buffers. This is unnecessary | ||
| # overhead when setting plain attrs. | ||
| self.__dict__[name] = value | ||
| else: | ||
| # Default case | ||
| super().__setattr__(name, value) | ||
| if "_initialized" in self.__dict__ and self._initialized: | ||
| warnings.warn( | ||
| """The default implementation of torch.nn.Module introduces significant CPU overhead | ||
| when setting attributes and is therefore not recommended. Please use the explicit | ||
| calls (fast_setattr for setting regular values and module_setattr for setting | ||
| parameters, children modules and buffers).""", | ||
| RuntimeWarning, | ||
| ) | ||
| super().__setattr__(name, value) |
WARNING: setattr override defeats the optimization it's meant to enforce
The __setattr__ override emits a RuntimeWarning when _initialized=True, but then still calls super().__setattr__(name, value) which goes through PyTorch's slow path anyway. This means:
- You pay the cost of the warning check
- You still pay the full cost of PyTorch's __setattr__
- The warning is treated as an error by pytest, so any code path using it will fail tests
This doesn't actually optimize anything - it just breaks code that uses standard attribute assignment. The optimization only works when code explicitly calls fast_setattr().
Implications:
- Any code path (including external libraries or user code) that does module.attribute = value will fail in tests
- The train() method workaround (lines 1111-1114) shows this is already a problem
- Properties, descriptors, and any dynamic attribute access will trigger this warning
Recommended approach:
- Remove the warning from __setattr__ - it creates more problems than it solves
- Keep fast_setattr() as an opt-in optimization
- Add comments/documentation encouraging use of fast_setattr()
3 files reviewed, 3 comments
9 files reviewed, 9 comments
| ).is_fp8_ubuf(): | ||
| fp8_grad = True | ||
|
|
||
| with self.prepare_forward( | ||
| inp, | ||
| allow_non_contiguous=isinstance(inp, QuantizedTensor), | ||
| ) as inp: | ||
| inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor)) | ||
|
|
||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
|
|
||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
|
|
||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| linear_fn = _Linear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _Linear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| is_grad_enabled, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_name, | ||
| fp8_output, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn( | ||
| *autograd_ctx, | ||
| weight_tensor, | ||
| inp, | ||
| bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, | ||
| non_tensor_args, | ||
| ) | ||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| linear_fn = _Linear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _Linear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| is_grad_enabled, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_name, | ||
| fp8_output, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn( | ||
| *autograd_ctx, | ||
| weight_tensor, | ||
| inp, | ||
| bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, | ||
| non_tensor_args, | ||
| ) | ||
| self.end_forward() | ||
| if self.gemm_bias_unfused_add: | ||
| out = out + cast_if_needed(bias_tensor, self.activation_dtype) | ||
|
|
[P1] The refactoring from context manager to manual prepare_forward()/end_forward() calls breaks exception safety. If any exception occurs between these calls (e.g., shape mismatch, CUDA OOM, assertion failure in the autograd function), end_forward() is never called, leaving an unclosed NVTX range via nvtx_range_push() without matching nvtx_range_pop(). This corrupts the NVTX stack for profiling.
The original context manager pattern guaranteed cleanup via __exit__. Consider either wrapping in try-finally or reverting to the context manager pattern (prepare_forward_ctx).
| ).is_fp8_ubuf(): | ||
| fp8_grad = True | ||
|
|
||
| with self.prepare_forward( | ||
| inp = self.prepare_forward( | ||
| inp, allow_non_contiguous=False # removed .contiguous from inside the layer | ||
| ) as inp: | ||
| ) | ||
|
|
||
| # Get concatenated weight and bias tensors | ||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
| # Get concatenated weight and bias tensors | ||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
|
|
||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
|
|
||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormLinear.forward | ||
| autograd_ctx = [None] | ||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.normalization, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_name, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, | ||
| weight_tensor, | ||
| bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, | ||
| non_tensor_args, | ||
| ) | ||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormLinear.forward | ||
| autograd_ctx = [None] | ||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.normalization, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_name, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, | ||
| weight_tensor, | ||
| bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, | ||
| non_tensor_args, | ||
| ) | ||
|
|
||
| self.end_forward() | ||
|
|
||
| if self.return_layernorm_output: |
[P1] Same NVTX range imbalance issue as Linear. If an exception occurs between prepare_forward() (which calls nvtx_range_push) and end_forward() (which calls nvtx_range_pop), the NVTX stack will be corrupted. The context manager automatically handled this cleanup.
| if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf(): | ||
| fp8_output = True | ||
|
|
||
| with self.prepare_forward(inp, num_gemms=2) as inp: | ||
| inp = self.prepare_forward(inp, num_gemms=2) | ||
|
|
||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, is_grad_enabled) | ||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, is_grad_enabled) | ||
|
|
||
| # Get quantizers | ||
| ( | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| # Get weight tensors | ||
| fc1_weight, fc2_weight = self._get_weight_tensors() | ||
| fc1_bias = self.fc1_bias if self.use_bias else None | ||
| fc2_bias = self.fc2_bias if self.use_bias else None | ||
| if not self.fp8: | ||
| if isinstance(fc1_weight, Float8Tensor): | ||
| fc1_weight = fc1_weight.dequantize() | ||
| if isinstance(fc2_weight, Float8Tensor): | ||
| fc2_weight = fc2_weight.dequantize() | ||
|
|
||
| # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode | ||
| if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): | ||
| self.bias_gelu_nvfusion = False | ||
| # Get quantizers | ||
| ( | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormMLP.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormMLP.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| self.bias_gelu_nvfusion and not self.fp8 and not debug, | ||
| self.set_parallel_mode, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.activation, | ||
| self.activation_params, | ||
| self.normalization, | ||
| self.ub_overlap_ag, | ||
| self.ub_overlap_rs, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.gemm_gelu_fusion and not debug, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.checkpoint, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, | ||
| fc1_weight, | ||
| fc1_bias, | ||
| fc2_weight, | ||
| fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None, | ||
| non_tensor_args, | ||
| ) | ||
| # Get weight tensors | ||
| fc1_weight, fc2_weight = self._get_weight_tensors() | ||
| fc1_bias = self.fc1_bias if self.use_bias else None | ||
| fc2_bias = self.fc2_bias if self.use_bias else None | ||
| if not self.fp8: | ||
| if isinstance(fc1_weight, Float8Tensor): | ||
| fc1_weight = fc1_weight.dequantize() | ||
| if isinstance(fc2_weight, Float8Tensor): | ||
| fc2_weight = fc2_weight.dequantize() | ||
|
|
||
| # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode | ||
| if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): | ||
| self.fast_setattr("bias_gelu_nvfusion", False) | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormMLP.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormMLP.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| self.bias_gelu_nvfusion and not self.fp8 and not debug, | ||
| self.set_parallel_mode, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.activation, | ||
| self.activation_params, | ||
| self.normalization, | ||
| self.ub_overlap_ag, | ||
| self.ub_overlap_rs, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.gemm_gelu_fusion and not debug, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.checkpoint, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, | ||
| fc1_weight, | ||
| fc1_bias, | ||
| fc2_weight, | ||
| fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None, | ||
| non_tensor_args, | ||
| ) | ||
|
|
||
| self.end_forward() |
[P1] Same NVTX range imbalance issue. Exception between prepare_forward() and end_forward() will leave unclosed NVTX range, corrupting profiling stack.
|
|
||
| is_grad_enabled = torch.is_grad_enabled() | ||
|
|
||
| with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: | ||
| weight_tensors = self._get_weight_tensors() | ||
| bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] | ||
| inp = self.prepare_forward(inp, num_gemms=self.num_gemms) | ||
| weight_tensors = self._get_weight_tensors() | ||
| bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] | ||
|
|
||
| quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() | ||
| quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() | ||
|
|
||
| if debug: | ||
| if self.no_debug_features_active(list(chain(*quantizers))): | ||
| debug = False | ||
| quantizers = self._get_quantizers() | ||
| if debug: | ||
| if self.no_debug_features_active(list(chain(*quantizers))): | ||
| debug = False | ||
| quantizers = self._get_quantizers() | ||
|
|
||
| if isinstance(weight_tensors, QuantizedTensorStorage): | ||
| raise RuntimeError("FP8 weights are not supported in debug mode.") | ||
| if isinstance(weight_tensors, QuantizedTensorStorage): | ||
| raise RuntimeError("FP8 weights are not supported in debug mode.") | ||
|
|
||
| ( | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| ) = quantizers | ||
| ( | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| linear_fn = _GroupedLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _GroupedLinear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| m_splits, | ||
| self.apply_bias, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.sequence_parallel, | ||
| self.activation_dtype, | ||
| is_grad_enabled, | ||
| self, | ||
| None, # skip_fp8_weight_update | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) | ||
| if is_grad_enabled: | ||
| linear_fn = _GroupedLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _GroupedLinear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| m_splits, | ||
| self.apply_bias, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.sequence_parallel, | ||
| self.activation_dtype, | ||
| is_grad_enabled, | ||
| self, | ||
| None, # skip_fp8_weight_update | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) | ||
|
|
[P1] Same NVTX range imbalance issue. Exception between prepare_forward() and end_forward() will leave unclosed NVTX range.
| def train(self, mode: bool = True) -> Self: | ||
| with warnings.catch_warnings(): | ||
| warnings.filterwarnings("ignore", ".*CPU overhead.*") | ||
| return super().train(mode) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[P2] The train() method suppresses the CPU overhead warning by filtering it out. However, this means PyTorch's __setattr__ is still being called for all the child module state changes during train()/eval() mode switching. This defeats the optimization goal since the expensive __setattr__ path is still taken - it just doesn't warn about it.
The proper fix would be to override train() to use fast_setattr for the training attribute on all child modules, rather than suppressing the warning.
| def __setattr__(self, name: str, value: Any) -> None: | ||
| if name in TransformerEngineBaseModule._fast_setattr_names: | ||
| # torch.nn.Module has a custom __setattr__ that handles | ||
| # modules, parameters, and buffers. This is unnecessary | ||
| # overhead when setting plain attrs. | ||
| self.__dict__[name] = value | ||
| else: | ||
| # Default case | ||
| super().__setattr__(name, value) | ||
| if "_initialized" in self.__dict__ and self._initialized: | ||
| warnings.warn( | ||
| """The default implementation of torch.nn.Module introduces significant CPU overhead | ||
| when setting attributes and is therefore not recommended. Please use the explicit | ||
| calls (fast_setattr for setting regular values and module_setattr for setting | ||
| parameters, children modules and buffers).""", | ||
| RuntimeWarning, | ||
| ) | ||
| super().__setattr__(name, value) |
[P2] The __setattr__ override still calls super().__setattr__(name, value) after emitting the warning, meaning the expensive PyTorch __setattr__ is always executed. This doesn't optimize anything - it only adds overhead by checking _initialized and potentially emitting a warning before doing the expensive operation anyway.
If the goal is CPU optimization, the __setattr__ should use fast_setattr for appropriate attributes instead of just warning. The current implementation only serves as a detection mechanism, not an optimization.
7 files reviewed, 7 comments
| inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor)) | ||
|
|
||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
|
|
||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
|
|
||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| linear_fn = _Linear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _Linear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| is_grad_enabled, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_name, | ||
| fp8_output, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn( | ||
| *autograd_ctx, | ||
| weight_tensor, | ||
| inp, | ||
| bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, | ||
| non_tensor_args, | ||
| ) | ||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| linear_fn = _Linear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _Linear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| is_grad_enabled, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_name, | ||
| fp8_output, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn( | ||
| *autograd_ctx, | ||
| weight_tensor, | ||
| inp, | ||
| bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, | ||
| non_tensor_args, | ||
| ) | ||
| self.end_forward() |
[P1] NVTX range imbalance on exceptions. The refactoring from context manager to manual prepare_forward()/end_forward() breaks exception safety. If any exception occurs between line 1402 (prepare_forward() which calls nvtx_range_push) and line 1474 (end_forward() which calls nvtx_range_pop), the NVTX range stack will be corrupted. Possible exceptions include shape mismatches, CUDA OOM, assertion failures (e.g., line 1061 in base.py: assert self.tp_group_initialized), or errors in quantizer initialization. The original context manager guaranteed cleanup via __exit__, but the new pattern has no exception handling.
| inp = self.prepare_forward( | ||
| inp, allow_non_contiguous=False # removed .contiguous from inside the layer | ||
| ) as inp: | ||
| ) | ||
|
|
||
| # Get concatenated weight and bias tensors | ||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
| # Get concatenated weight and bias tensors | ||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
|
|
||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
|
|
||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormLinear.forward | ||
| autograd_ctx = [None] | ||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.normalization, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_name, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, | ||
| weight_tensor, | ||
| bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, | ||
| non_tensor_args, | ||
| ) | ||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormLinear.forward | ||
| autograd_ctx = [None] | ||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.normalization, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_name, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, |
[P1] NVTX range imbalance on exceptions. Same issue as Linear module - if any exception occurs between prepare_forward() (line 1519, calls nvtx_range_push) and end_forward() (line 1595, calls nvtx_range_pop), the NVTX stack is corrupted. Possible exceptions include shape mismatches, CUDA errors, or assertion failures.
| inp = self.prepare_forward(inp, num_gemms=self.num_gemms) | ||
| weight_tensors = self._get_weight_tensors() | ||
| bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] | ||
|
|
||
| quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() | ||
| quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() | ||
|
|
||
| if debug: | ||
| if self.no_debug_features_active(list(chain(*quantizers))): | ||
| debug = False | ||
| quantizers = self._get_quantizers() | ||
| if debug: | ||
| if self.no_debug_features_active(list(chain(*quantizers))): | ||
| debug = False | ||
| quantizers = self._get_quantizers() | ||
|
|
||
| if isinstance(weight_tensors, QuantizedTensorStorage): | ||
| raise RuntimeError("FP8 weights are not supported in debug mode.") | ||
| if isinstance(weight_tensors, QuantizedTensorStorage): | ||
| raise RuntimeError("FP8 weights are not supported in debug mode.") | ||
|
|
||
| ( | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| ) = quantizers | ||
| ( | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| linear_fn = _GroupedLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _GroupedLinear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| m_splits, | ||
| self.apply_bias, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.sequence_parallel, | ||
| self.activation_dtype, | ||
| is_grad_enabled, | ||
| self, | ||
| None, # skip_fp8_weight_update | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) | ||
| if is_grad_enabled: | ||
| linear_fn = _GroupedLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _GroupedLinear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| m_splits, | ||
| self.apply_bias, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.sequence_parallel, | ||
| self.activation_dtype, | ||
| is_grad_enabled, | ||
| self, | ||
| None, # skip_fp8_weight_update | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) | ||
|
|
[P1] NVTX range imbalance on exceptions. Same issue as other modules - if any exception occurs between prepare_forward() (line 793, calls nvtx_range_push) and end_forward() (line 847, calls nvtx_range_pop), the NVTX stack is corrupted.
| inp = self.prepare_forward(inp, num_gemms=2) | ||
|
|
||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, is_grad_enabled) | ||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, is_grad_enabled) | ||
|
|
||
| # Get quantizers | ||
| ( | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| # Get weight tensors | ||
| fc1_weight, fc2_weight = self._get_weight_tensors() | ||
| fc1_bias = self.fc1_bias if self.use_bias else None | ||
| fc2_bias = self.fc2_bias if self.use_bias else None | ||
| if not self.fp8: | ||
| if isinstance(fc1_weight, Float8Tensor): | ||
| fc1_weight = fc1_weight.dequantize() | ||
| if isinstance(fc2_weight, Float8Tensor): | ||
| fc2_weight = fc2_weight.dequantize() | ||
|
|
||
| # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode | ||
| if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): | ||
| self.bias_gelu_nvfusion = False | ||
| # Get quantizers | ||
| ( | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormMLP.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormMLP.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| self.bias_gelu_nvfusion and not self.fp8 and not debug, | ||
| self.set_parallel_mode, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.activation, | ||
| self.activation_params, | ||
| self.normalization, | ||
| self.ub_overlap_ag, | ||
| self.ub_overlap_rs, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.gemm_gelu_fusion and not debug, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.checkpoint, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, | ||
| fc1_weight, | ||
| fc1_bias, | ||
| fc2_weight, | ||
| fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None, | ||
| non_tensor_args, | ||
| ) | ||
| # Get weight tensors | ||
| fc1_weight, fc2_weight = self._get_weight_tensors() | ||
| fc1_bias = self.fc1_bias if self.use_bias else None | ||
| fc2_bias = self.fc2_bias if self.use_bias else None | ||
| if not self.fp8: | ||
| if isinstance(fc1_weight, Float8Tensor): | ||
| fc1_weight = fc1_weight.dequantize() | ||
| if isinstance(fc2_weight, Float8Tensor): | ||
| fc2_weight = fc2_weight.dequantize() | ||
|
|
||
| # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode | ||
| if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): | ||
| self.fast_setattr("bias_gelu_nvfusion", False) | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormMLP.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormMLP.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| self.bias_gelu_nvfusion and not self.fp8 and not debug, | ||
| self.set_parallel_mode, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.activation, | ||
| self.activation_params, | ||
| self.normalization, | ||
| self.ub_overlap_ag, | ||
| self.ub_overlap_rs, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.gemm_gelu_fusion and not debug, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.checkpoint, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, | ||
| fc1_weight, | ||
| fc1_bias, | ||
| fc2_weight, | ||
| fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None, | ||
| non_tensor_args, | ||
| ) | ||
|
|
||
| self.end_forward() |
[P1] NVTX range imbalance on exceptions. Same issue as other modules - if any exception occurs between prepare_forward() (line 2057, calls nvtx_range_push) and end_forward() (line 2167, calls nvtx_range_pop), the NVTX stack is corrupted.
Additional Comments (3)
|
8 files reviewed, 8 comments
| def prepare_forward( | ||
| self, | ||
| inp: torch.Tensor, | ||
| num_gemms: int = 1, | ||
| allow_non_contiguous: bool = False, | ||
| allow_different_data_and_param_types: bool = False, | ||
| ) -> Generator[torch.Tensor, None, None]: | ||
| """Checks and prep for FWD. | ||
| The context manager is needed because there isn't a way for a module to know | ||
| if it's the last FP8 module in the forward autocast. It is useful | ||
| to setup the forward aggregated amax reduction for every module | ||
| just in case. The autocast exit will pick up the most recent one. | ||
| """ | ||
| self.allow_different_data_and_param_types = allow_different_data_and_param_types | ||
| self.forwarded_at_least_once = True | ||
| ) -> torch.Tensor: | ||
| """Checks and prepares for FWD execution.""" | ||
| self.fast_setattr( | ||
| "allow_different_data_and_param_types", allow_different_data_and_param_types | ||
| ) | ||
| self.fast_setattr("forwarded_at_least_once", True) | ||
|
|
||
| # Activation recomputation is used and this is the second forward phase. | ||
| if self.fp8 and in_fp8_activation_recompute_phase(): | ||
| delayed_scaling_recipe = self.fp8_meta["recipe"].delayed() | ||
| FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) | ||
| else: | ||
| assert inp.is_cuda, "TransformerEngine needs CUDA." | ||
|
|
||
| if self.tp_size > 1: | ||
| assert self.tp_group_initialized, "TP group not initialized." | ||
|
|
||
| self.set_activation_dtype(inp) | ||
| self.init_fp8_metadata(num_gemms=num_gemms) | ||
| self._check_weight_tensor_recipe_correspondence() | ||
|
|
||
| delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() | ||
| if delayed_scaling_recipe: | ||
| if self.sequence_parallel: | ||
| assert self.fp8_meta["recipe"].reduce_amax, ( | ||
| "Amax reduction across tensor parallel group is " | ||
| "necessary when using sequence parallelism with FP8." | ||
| ) | ||
|
|
||
| if not FP8GlobalStateManager.fp8_graph_capturing(): | ||
| FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) | ||
|
|
||
| # Activation recomputation is used and this is the first forward phase. | ||
| if self.training and is_fp8_activation_recompute_enabled(): | ||
| FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) | ||
|
|
||
| with get_nvtx_range_context(self.__class__.__name__ + " forward"): | ||
| if not allow_non_contiguous and not inp.is_contiguous(): | ||
| inp = inp.contiguous() | ||
| yield inp | ||
| nvtx_range_push(self.__class__.__name__ + " forward") | ||
| if not allow_non_contiguous and not inp.is_contiguous(): | ||
| inp = inp.contiguous() | ||
| return inp | ||
|
|
||
| def end_forward(self): | ||
| """ | ||
| Required to be called at the end of the forward function to properly handle | ||
| DelayedScaling metadata handling and the NVTX ranges. | ||
| """ | ||
| delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() | ||
| if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase(): | ||
| FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) | ||
| nvtx_range_pop() | ||
|
|
CRITICAL: NVTX range imbalance on exceptions - breaks profiling
The refactoring from a context manager to manual prepare_forward()/end_forward() calls breaks exception safety. The original context manager pattern guaranteed that nvtx_range_pop() would be called via __exit__ even if an exception occurred.
With the new pattern:
- prepare_forward() calls nvtx_range_push() on line 1082
- end_forward() calls nvtx_range_pop() on line 1095
- If ANY exception occurs between these calls, nvtx_range_pop() is never executed
Impact: This affects Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear modules. Any exception (shape mismatch, CUDA OOM, assertion failure, invalid tensor, etc.) will leave an unclosed NVTX range, corrupting the profiling stack.
Why this happened: The PR author mentions they avoided touching attention because it has multiple return statements and they were worried about missing something. The same issue exists here - the forward methods have many potential exception points.
Recommendation: Either:
- Keep using a context manager (prepare_forward_ctx) like the attention module does, OR
- Wrap the forward logic in try/finally to ensure end_forward() is always called (see the sketch below), OR
- Revert this optimization for modules with complex forward methods
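For illustration, a minimal sketch of the try/finally option, assuming only the prepare_forward()/end_forward() methods added in this PR; the _forward_impl helper is hypothetical and stands in for the existing forward body:
def forward(self, inp, *args, **kwargs):
    # prepare_forward() pushes the NVTX range and sets up FP8 state.
    inp = self.prepare_forward(inp)
    try:
        # Stand-in for the existing forward body (quantizer setup, GEMM call, ...).
        out = self._forward_impl(inp, *args, **kwargs)
    finally:
        # Runs on both the normal and the exception path, so nvtx_range_pop()
        # is always matched with the earlier nvtx_range_push().
        self.end_forward()
    return out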
| def __setattr__(self, name: str, value: Any) -> None: | ||
| if name in TransformerEngineBaseModule._fast_setattr_names: | ||
| # torch.nn.Module has a custom __setattr__ that handles | ||
| # modules, parameters, and buffers. This is unnecessary | ||
| # overhead when setting plain attrs. | ||
| self.__dict__[name] = value | ||
| else: | ||
| # Default case | ||
| super().__setattr__(name, value) | ||
| if "_initialized" in self.__dict__ and self._initialized: | ||
| warnings.warn( | ||
| """The default implementation of torch.nn.Module introduces significant CPU overhead | ||
| when setting attributes and is therefore not recommended. Please use the explicit | ||
| calls (fast_setattr for setting regular values and module_setattr for setting | ||
| parameters, children modules and buffers).""", | ||
| RuntimeWarning, | ||
| ) | ||
| super().__setattr__(name, value) |
CRITICAL: setattr override defeats the optimization it's trying to enable
This __setattr__ override is intended to warn developers when they use the slow path, but it actually creates the performance overhead that the PR aims to eliminate:
- Every direct attribute assignment still triggers __setattr__
- __setattr__ checks "_initialized" in self.__dict__, which is a dictionary lookup
- Then it checks self._initialized, another dictionary lookup
- Then it calls warnings.warn(), which does string formatting, stack inspection, and filtering
- Finally it calls super().__setattr__(), the slow PyTorch path
The problem: fast_setattr() itself bypasses __setattr__, since it writes directly into self.__dict__, so correctly converted call sites are fine. But any code that still uses direct assignment (like the bugs found in this PR) will pay a huge penalty.
Why this is worse:
- During __init__, every self.x = y now does 2 dict lookups + warning machinery + the slow setattr
- The warning itself is expensive (stack inspection, string formatting)
- With pytest.ini treating warnings as errors, this will abort tests
The train() method workaround (lines 1111-1114) shows this was already causing problems - it filters out the warning, but the overhead remains.
Recommendation: Remove the __setattr__ override entirely, or only enable it in debug/development mode with an environment variable (a sketch of that approach follows below). Use static analysis tools (linters, mypy plugins) to catch incorrect usage instead of runtime checks.
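A minimal sketch of the environment-variable idea, purely illustrative: the NVTE_DEBUG_SETATTR flag and the FastAttrModule class below are hypothetical names, not existing Transformer Engine ones, and only mirror the behaviour described above:
import os
import warnings
import torch

_DEBUG_SETATTR = os.getenv("NVTE_DEBUG_SETATTR", "0") == "1"  # hypothetical opt-in flag

class FastAttrModule(torch.nn.Module):
    def fast_setattr(self, name, value):
        # Bypass torch.nn.Module.__setattr__ for plain attributes.
        self.__dict__[name] = value

    if _DEBUG_SETATTR:
        # Only defined when the flag is set, so production runs keep the
        # default __setattr__ with no extra dict lookups or warning machinery.
        def __setattr__(self, name, value):
            if self.__dict__.get("_initialized", False):
                warnings.warn(
                    f"Slow torch.nn.Module.__setattr__ used for {name!r}; "
                    "prefer fast_setattr()/module_setattr()",
                    RuntimeWarning,
                )
            super().__setattr__(name, value)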
| inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor)) | ||
|
|
||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
|
|
||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
|
|
||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| linear_fn = _Linear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _Linear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| is_grad_enabled, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_name, | ||
| fp8_output, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn( | ||
| *autograd_ctx, | ||
| weight_tensor, | ||
| inp, | ||
| bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, | ||
| non_tensor_args, | ||
| ) | ||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| linear_fn = _Linear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _Linear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| is_grad_enabled, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_name, | ||
| fp8_output, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn( | ||
| *autograd_ctx, | ||
| weight_tensor, | ||
| inp, | ||
| bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, | ||
| non_tensor_args, | ||
| ) | ||
| self.end_forward() |
CRITICAL: NVTX range will not be closed if exception occurs
The manual prepare_forward()/end_forward() pattern breaks exception safety. If ANY exception occurs between line 1402 (prepare_forward() which calls nvtx_range_push) and line 1474 (end_forward() which calls nvtx_range_pop), the NVTX profiling stack will be corrupted.
Potential exception points in this code:
- Line 1404: _get_weight_and_bias_tensors() could fail
- Lines 1406-1414: quantizer construction could fail
- Line 1467: linear_fn() could raise exceptions (shape mismatches, CUDA OOM, invalid tensors, assertion failures, etc.)
Example failure scenario:
inp = self.prepare_forward(...) # Pushes NVTX range
# ... setup code ...
out = linear_fn(...) # Raises RuntimeError: shape mismatch
# end_forward() is NEVER called, NVTX range never popped
# All subsequent NVTX ranges are corrupted
Fix: Wrap in try/finally:
inp = self.prepare_forward(...)
try:
# ... all forward logic ...
finally:
self.end_forward()
Or use the context manager pattern that the attention module correctly uses (prepare_forward_ctx); a sketch of such a wrapper follows below.
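For reference, such a wrapper can be a thin context manager over the two new methods. This is a sketch under the assumption that prepare_forward()/end_forward() behave as described above; it does not claim to match the actual prepare_forward_ctx implementation in the attention module:
import contextlib

@contextlib.contextmanager
def prepare_forward_ctx(module, inp, **kwargs):
    # Yields the prepared input; end_forward() (and its nvtx_range_pop) runs
    # even if the body of the with-block raises.
    inp = module.prepare_forward(inp, **kwargs)
    try:
        yield inp
    finally:
        module.end_forward()
A forward method would then wrap its body in "with prepare_forward_ctx(self, inp) as inp:" and get the old exception safety back.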
Additional Comments (5)
5 files reviewed, 5 comments
Additional Comments (5)
|
10 files reviewed, 10 comments
| def __setattr__(self, name: str, value: Any) -> None: | ||
| if name in TransformerEngineBaseModule._fast_setattr_names: | ||
| # torch.nn.Module has a custom __setattr__ that handles | ||
| # modules, parameters, and buffers. This is unnecessary | ||
| # overhead when setting plain attrs. | ||
| self.__dict__[name] = value | ||
| else: | ||
| # Default case | ||
| super().__setattr__(name, value) | ||
| if "_initialized" in self.__dict__ and self._initialized: | ||
| warnings.warn( | ||
| """The default implementation of torch.nn.Module introduces significant CPU overhead | ||
| when setting attributes and is therefore not recommended. Please use the explicit | ||
| calls (fast_setattr for setting regular values and module_setattr for setting | ||
| parameters, children modules and buffers).""", | ||
| RuntimeWarning, | ||
| ) | ||
| super().__setattr__(name, value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MAJOR ISSUE: The __setattr__ warning defeats the optimization purpose
The __setattr__ override emits a RuntimeWarning whenever it's called after initialization. However, emitting warnings has significant CPU overhead - it involves stack trace inspection, formatting, filtering, etc. This defeats the entire purpose of the optimization!
The warning will fire every time the slow path is used, which is precisely when you're trying to avoid overhead. This makes the "optimization" actually slower than before in cases where the slow path is accidentally used.
Consider one of these alternatives:
- Make __setattr__ raise an exception immediately (fail-fast during development; see the sketch below)
- Remove the warning entirely and rely on code review to catch issues
- Add a debug mode that can be disabled in production
The current implementation is problematic because:
- Warning overhead during runtime defeats the optimization
- pytest treats it as an error anyway, so it's not really a "warning"
- It will cause legitimate code paths to fail (see direct assignment issues in this file)
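A sketch of the fail-fast alternative; StrictAttrModule and its _fast_attr_names set are illustrative stand-ins, not the PR's actual base module:
import torch

class StrictAttrModule(torch.nn.Module):
    # Illustrative subset of attribute names allowed on the fast path.
    _fast_attr_names = {"fp8", "activation_dtype"}

    def __init__(self):
        super().__init__()
        # Written via __dict__ so it does not go through __setattr__.
        self.__dict__["_initialized"] = True

    def fast_setattr(self, name, value):
        self.__dict__[name] = value

    def __setattr__(self, name, value):
        if name in StrictAttrModule._fast_attr_names:
            self.__dict__[name] = value
        elif self.__dict__.get("_initialized", False):
            # Fail fast during development instead of warning on every call.
            raise AttributeError(
                f"Use fast_setattr()/module_setattr() instead of direct assignment for {name!r}"
            )
        else:
            # Still inside __init__: defer to the normal torch.nn.Module path.
            super().__setattr__(name, value)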
| inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor)) | ||
|
|
||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
|
|
||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
|
|
||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| linear_fn = _Linear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _Linear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| is_grad_enabled, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_name, | ||
| fp8_output, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn( | ||
| *autograd_ctx, | ||
| weight_tensor, | ||
| inp, | ||
| bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, | ||
| non_tensor_args, | ||
| ) | ||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| linear_fn = _Linear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _Linear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| is_grad_enabled, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_name, | ||
| fp8_output, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn( | ||
| *autograd_ctx, | ||
| weight_tensor, | ||
| inp, | ||
| bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, | ||
| non_tensor_args, | ||
| ) | ||
| self.end_forward() |
CRITICAL: NVTX range imbalance on exceptions will corrupt profiling
The refactoring from context manager to manual prepare_forward()/end_forward() calls breaks exception safety. prepare_forward() calls nvtx_range_push() (line 1082 in base.py), and end_forward() calls nvtx_range_pop() (line 1095).
If any exception occurs between lines 1402-1473 (before end_forward() is called on line 1474), the NVTX range stack will be corrupted because nvtx_range_pop() is never called. This includes:
- Shape mismatches in tensors
- CUDA out of memory errors
- Assertion failures in quantizers
- Type errors in _Linear.apply/forward
- Any exception from user code
The original context manager pattern guaranteed cleanup via __exit__. The new pattern requires wrapping in try-finally:
| inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor)) | |
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | |
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | |
| quantizers = ( | |
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | |
| if not debug | |
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | |
| ) | |
| if debug: | |
| if self.no_debug_features_active(quantizers): | |
| debug = False | |
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | |
| quantizers = ( | |
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | |
| if not debug | |
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | |
| ) | |
| if debug: | |
| if self.no_debug_features_active(quantizers): | |
| debug = False | |
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | |
| ( | |
| input_quantizer, | |
| weight_quantizer, | |
| output_quantizer, | |
| grad_input_quantizer, | |
| grad_weight_quantizer, | |
| grad_output_quantizer, | |
| ) = quantizers | |
| if is_grad_enabled: | |
| linear_fn = _Linear.apply | |
| autograd_ctx = [] | |
| else: | |
| linear_fn = _Linear.forward | |
| autograd_ctx = [None] | |
| non_tensor_args = ( | |
| is_first_microbatch, | |
| self.fp8, | |
| self.fp8_calibration, | |
| self.wgrad_store, | |
| input_quantizer, | |
| weight_quantizer, | |
| output_quantizer, | |
| grad_input_quantizer, | |
| grad_weight_quantizer, | |
| grad_output_quantizer, | |
| self.fuse_wgrad_accumulation, | |
| is_cpu_offload_enabled(), | |
| self.tp_group, | |
| self.tp_size, | |
| self.sequence_parallel, | |
| self.tp_size > 1, | |
| self.activation_dtype, | |
| self.parallel_mode, | |
| is_grad_enabled, | |
| self.ub_overlap_rs_fprop, | |
| self.ub_overlap_ag_dgrad, | |
| self.ub_overlap_ag_fprop, | |
| self.ub_overlap_rs_dgrad, | |
| self.ub_bulk_dgrad, | |
| self.ub_bulk_wgrad, | |
| self.ub_name, | |
| fp8_output, | |
| self.fsdp_group, | |
| self, | |
| skip_fp8_weight_update, | |
| self.symmetric_ar_type, | |
| self.save_original_input, | |
| debug, | |
| ) | |
| out = linear_fn( | |
| *autograd_ctx, | |
| weight_tensor, | |
| inp, | |
| bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, | |
| non_tensor_args, | |
| ) | |
| ( | |
| input_quantizer, | |
| weight_quantizer, | |
| output_quantizer, | |
| grad_input_quantizer, | |
| grad_weight_quantizer, | |
| grad_output_quantizer, | |
| ) = quantizers | |
| if is_grad_enabled: | |
| linear_fn = _Linear.apply | |
| autograd_ctx = [] | |
| else: | |
| linear_fn = _Linear.forward | |
| autograd_ctx = [None] | |
| non_tensor_args = ( | |
| is_first_microbatch, | |
| self.fp8, | |
| self.fp8_calibration, | |
| self.wgrad_store, | |
| input_quantizer, | |
| weight_quantizer, | |
| output_quantizer, | |
| grad_input_quantizer, | |
| grad_weight_quantizer, | |
| grad_output_quantizer, | |
| self.fuse_wgrad_accumulation, | |
| is_cpu_offload_enabled(), | |
| self.tp_group, | |
| self.tp_size, | |
| self.sequence_parallel, | |
| self.tp_size > 1, | |
| self.activation_dtype, | |
| self.parallel_mode, | |
| is_grad_enabled, | |
| self.ub_overlap_rs_fprop, | |
| self.ub_overlap_ag_dgrad, | |
| self.ub_overlap_ag_fprop, | |
| self.ub_overlap_rs_dgrad, | |
| self.ub_bulk_dgrad, | |
| self.ub_bulk_wgrad, | |
| self.ub_name, | |
| fp8_output, | |
| self.fsdp_group, | |
| self, | |
| skip_fp8_weight_update, | |
| self.symmetric_ar_type, | |
| self.save_original_input, | |
| debug, | |
| ) | |
| out = linear_fn( | |
| *autograd_ctx, | |
| weight_tensor, | |
| inp, | |
| bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, | |
| non_tensor_args, | |
| ) | |
| self.end_forward() | |
| inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor)) | |
| try: | |
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | |
| quantizers = ( | |
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | |
| if not debug | |
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | |
| ) | |
| if debug: | |
| if self.no_debug_features_active(quantizers): | |
| debug = False | |
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | |
| ( | |
| input_quantizer, | |
| weight_quantizer, | |
| output_quantizer, | |
| grad_input_quantizer, | |
| grad_weight_quantizer, | |
| grad_output_quantizer, | |
| ) = quantizers | |
| if is_grad_enabled: | |
| linear_fn = _Linear.apply | |
| autograd_ctx = [] | |
| else: | |
| linear_fn = _Linear.forward | |
| autograd_ctx = [None] | |
| non_tensor_args = ( | |
| is_first_microbatch, | |
| self.fp8, | |
| self.fp8_calibration, | |
| self.wgrad_store, | |
| input_quantizer, | |
| weight_quantizer, | |
| output_quantizer, | |
| grad_input_quantizer, | |
| grad_weight_quantizer, | |
| grad_output_quantizer, | |
| self.fuse_wgrad_accumulation, | |
| is_cpu_offload_enabled(), | |
| self.tp_group, | |
| self.tp_size, | |
| self.sequence_parallel, | |
| self.tp_size > 1, | |
| self.activation_dtype, | |
| self.parallel_mode, | |
| is_grad_enabled, | |
| self.ub_overlap_rs_fprop, | |
| self.ub_overlap_ag_dgrad, | |
| self.ub_overlap_ag_fprop, | |
| self.ub_overlap_rs_dgrad, | |
| self.ub_bulk_dgrad, | |
| self.ub_bulk_wgrad, | |
| self.ub_name, | |
| fp8_output, | |
| self.fsdp_group, | |
| self, | |
| skip_fp8_weight_update, | |
| self.symmetric_ar_type, | |
| self.save_original_input, | |
| debug, | |
| ) | |
| out = linear_fn( | |
| *autograd_ctx, | |
| weight_tensor, | |
| inp, | |
| bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, | |
| non_tensor_args, | |
| ) | |
| finally: | |
| self.end_forward() |
This same issue affects LayerNormLinear, LayerNormMLP, and GroupedLinear.
| inp = self.prepare_forward( | ||
| inp, allow_non_contiguous=False # removed .contiguous from inside the layer | ||
| ) as inp: | ||
| ) | ||
|
|
||
| # Get concatenated weight and bias tensors | ||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
| # Get concatenated weight and bias tensors | ||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
|
|
||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
|
|
||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormLinear.forward | ||
| autograd_ctx = [None] | ||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.normalization, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_name, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, | ||
| weight_tensor, | ||
| bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, | ||
| non_tensor_args, | ||
| ) | ||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormLinear.forward | ||
| autograd_ctx = [None] | ||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.normalization, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_name, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, | ||
| weight_tensor, | ||
| bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, | ||
| non_tensor_args, | ||
| ) | ||
|
|
CRITICAL: NVTX range imbalance on exceptions
Same exception safety issue as in Linear module. If any exception occurs between prepare_forward() and end_forward(), the NVTX range stack will be corrupted. This code needs try-finally wrapping:
| inp = self.prepare_forward( | |
| inp, allow_non_contiguous=False # removed .contiguous from inside the layer | |
| ) as inp: | |
| ) | |
| # Get concatenated weight and bias tensors | |
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | |
| # Get concatenated weight and bias tensors | |
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | |
| quantizers = ( | |
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | |
| if not debug | |
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | |
| ) | |
| if debug: | |
| if self.no_debug_features_active(quantizers): | |
| debug = False | |
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | |
| quantizers = ( | |
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | |
| if not debug | |
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | |
| ) | |
| if debug: | |
| if self.no_debug_features_active(quantizers): | |
| debug = False | |
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | |
| ( | |
| input_quantizer, | |
| weight_quantizer, | |
| output_quantizer, | |
| grad_input_quantizer, | |
| grad_weight_quantizer, | |
| grad_output_quantizer, | |
| ) = quantizers | |
| if is_grad_enabled: | |
| fwd_fn = _LayerNormLinear.apply | |
| autograd_ctx = [] | |
| else: | |
| fwd_fn = _LayerNormLinear.forward | |
| autograd_ctx = [None] | |
| non_tensor_args = ( | |
| self.eps, | |
| is_first_microbatch, | |
| self.fp8, | |
| self.fp8_calibration, | |
| self.wgrad_store, | |
| self.fuse_wgrad_accumulation, | |
| input_quantizer, | |
| weight_quantizer, | |
| output_quantizer, | |
| grad_input_quantizer, | |
| grad_weight_quantizer, | |
| grad_output_quantizer, | |
| is_cpu_offload_enabled(), | |
| self.tp_group, | |
| self.tp_size, | |
| self.sequence_parallel, | |
| self.tp_size > 1, | |
| self.activation_dtype, | |
| self.parallel_mode, | |
| self.return_layernorm_output, | |
| self.return_layernorm_output_gathered, | |
| is_grad_enabled, | |
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | |
| self.bwd_ln_sm_margin, | |
| self.zero_centered_gamma, | |
| self.normalization, | |
| self.ub_overlap_ag_fprop, | |
| self.ub_overlap_rs_fprop, | |
| self.ub_overlap_ag_dgrad, | |
| self.ub_overlap_rs_dgrad, | |
| self.ub_bulk_wgrad, | |
| self.ub_bulk_dgrad, | |
| self.ub_name, | |
| self.fsdp_group, | |
| self, | |
| skip_fp8_weight_update, | |
| self.symmetric_ar_type, | |
| debug, | |
| ) | |
| out = fwd_fn( | |
| *autograd_ctx, | |
| inp, | |
| self.layer_norm_weight, | |
| self.layer_norm_bias, | |
| weight_tensor, | |
| bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, | |
| non_tensor_args, | |
| ) | |
| ( | |
| input_quantizer, | |
| weight_quantizer, | |
| output_quantizer, | |
| grad_input_quantizer, | |
| grad_weight_quantizer, | |
| grad_output_quantizer, | |
| ) = quantizers | |
| if is_grad_enabled: | |
| fwd_fn = _LayerNormLinear.apply | |
| autograd_ctx = [] | |
| else: | |
| fwd_fn = _LayerNormLinear.forward | |
| autograd_ctx = [None] | |
| non_tensor_args = ( | |
| self.eps, | |
| is_first_microbatch, | |
| self.fp8, | |
| self.fp8_calibration, | |
| self.wgrad_store, | |
| self.fuse_wgrad_accumulation, | |
| input_quantizer, | |
| weight_quantizer, | |
| output_quantizer, | |
| grad_input_quantizer, | |
| grad_weight_quantizer, | |
| grad_output_quantizer, | |
| is_cpu_offload_enabled(), | |
| self.tp_group, | |
| self.tp_size, | |
| self.sequence_parallel, | |
| self.tp_size > 1, | |
| self.activation_dtype, | |
| self.parallel_mode, | |
| self.return_layernorm_output, | |
| self.return_layernorm_output_gathered, | |
| is_grad_enabled, | |
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | |
| self.bwd_ln_sm_margin, | |
| self.zero_centered_gamma, | |
| self.normalization, | |
| self.ub_overlap_ag_fprop, | |
| self.ub_overlap_rs_fprop, | |
| self.ub_overlap_ag_dgrad, | |
| self.ub_overlap_rs_dgrad, | |
| self.ub_bulk_wgrad, | |
| self.ub_bulk_dgrad, | |
| self.ub_name, | |
| self.fsdp_group, | |
| self, | |
| skip_fp8_weight_update, | |
| self.symmetric_ar_type, | |
| debug, | |
| ) | |
| out = fwd_fn( | |
| *autograd_ctx, | |
| inp, | |
| self.layer_norm_weight, | |
| self.layer_norm_bias, | |
| weight_tensor, | |
| bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, | |
| non_tensor_args, | |
| ) | |
| inp = self.prepare_forward( | |
| inp, allow_non_contiguous=False # removed .contiguous from inside the layer | |
| ) | |
| try: | |
| # Get concatenated weight and bias tensors | |
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | |
| # ... rest of the forward code ... | |
| finally: | |
| self.end_forward() |
| inp = self.prepare_forward(inp, num_gemms=2) | ||
|
|
||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, is_grad_enabled) | ||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, is_grad_enabled) | ||
|
|
||
| # Get quantizers | ||
| ( | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| # Get weight tensors | ||
| fc1_weight, fc2_weight = self._get_weight_tensors() | ||
| fc1_bias = self.fc1_bias if self.use_bias else None | ||
| fc2_bias = self.fc2_bias if self.use_bias else None | ||
| if not self.fp8: | ||
| if isinstance(fc1_weight, Float8Tensor): | ||
| fc1_weight = fc1_weight.dequantize() | ||
| if isinstance(fc2_weight, Float8Tensor): | ||
| fc2_weight = fc2_weight.dequantize() | ||
|
|
||
| # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode | ||
| if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): | ||
| self.bias_gelu_nvfusion = False | ||
| # Get quantizers | ||
| ( | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormMLP.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormMLP.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| self.bias_gelu_nvfusion and not self.fp8 and not debug, | ||
| self.set_parallel_mode, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.activation, | ||
| self.activation_params, | ||
| self.normalization, | ||
| self.ub_overlap_ag, | ||
| self.ub_overlap_rs, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.gemm_gelu_fusion and not debug, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.checkpoint, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, | ||
| fc1_weight, | ||
| fc1_bias, | ||
| fc2_weight, | ||
| fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None, | ||
| non_tensor_args, | ||
| ) | ||
| # Get weight tensors | ||
| fc1_weight, fc2_weight = self._get_weight_tensors() | ||
| fc1_bias = self.fc1_bias if self.use_bias else None | ||
| fc2_bias = self.fc2_bias if self.use_bias else None | ||
| if not self.fp8: | ||
| if isinstance(fc1_weight, Float8Tensor): | ||
| fc1_weight = fc1_weight.dequantize() | ||
| if isinstance(fc2_weight, Float8Tensor): | ||
| fc2_weight = fc2_weight.dequantize() | ||
|
|
||
| # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode | ||
| if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): | ||
| self.fast_setattr("bias_gelu_nvfusion", False) | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormMLP.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormMLP.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| self.bias_gelu_nvfusion and not self.fp8 and not debug, | ||
| self.set_parallel_mode, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.activation, | ||
| self.activation_params, | ||
| self.normalization, | ||
| self.ub_overlap_ag, | ||
| self.ub_overlap_rs, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.gemm_gelu_fusion and not debug, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.checkpoint, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, | ||
| fc1_weight, |
CRITICAL: NVTX range imbalance on exceptions
Same exception safety issue as in Linear module. The code between prepare_forward() and end_forward() needs try-finally wrapping to ensure NVTX ranges are properly closed even when exceptions occur.
| inp = self.prepare_forward(inp, num_gemms=self.num_gemms) | ||
| weight_tensors = self._get_weight_tensors() | ||
| bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] | ||
|
|
||
| quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() | ||
| quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() | ||
|
|
||
| if debug: | ||
| if self.no_debug_features_active(list(chain(*quantizers))): | ||
| debug = False | ||
| quantizers = self._get_quantizers() | ||
| if debug: | ||
| if self.no_debug_features_active(list(chain(*quantizers))): | ||
| debug = False | ||
| quantizers = self._get_quantizers() | ||
|
|
||
| if isinstance(weight_tensors, QuantizedTensorStorage): | ||
| raise RuntimeError("FP8 weights are not supported in debug mode.") | ||
| if isinstance(weight_tensors, QuantizedTensorStorage): | ||
| raise RuntimeError("FP8 weights are not supported in debug mode.") | ||
|
|
||
| ( | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| ) = quantizers | ||
| ( | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| linear_fn = _GroupedLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _GroupedLinear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| m_splits, | ||
| self.apply_bias, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.sequence_parallel, | ||
| self.activation_dtype, | ||
| is_grad_enabled, | ||
| self, | ||
| None, # skip_fp8_weight_update | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) | ||
| if is_grad_enabled: | ||
| linear_fn = _GroupedLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _GroupedLinear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| m_splits, | ||
| self.apply_bias, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.sequence_parallel, | ||
| self.activation_dtype, | ||
| is_grad_enabled, | ||
| self, | ||
| None, # skip_fp8_weight_update | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) | ||
|
|
CRITICAL: NVTX range imbalance on exceptions
Same exception safety issue as in Linear module. The code between prepare_forward() and end_forward() needs try-finally wrapping to ensure NVTX ranges are properly closed even when exceptions occur.
Additional Comments (5)
7 files reviewed, 7 comments
| inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor)) | ||
|
|
||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
|
|
||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
|
|
||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| linear_fn = _Linear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _Linear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| is_grad_enabled, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_name, | ||
| fp8_output, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn( | ||
| *autograd_ctx, | ||
| weight_tensor, | ||
| inp, | ||
| bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, | ||
| non_tensor_args, | ||
| ) | ||
| self.end_forward() |
CRITICAL: NVTX range imbalance on exceptions
The refactoring from a context manager to manual prepare_forward()/end_forward() calls breaks exception safety. If an exception occurs between prepare_forward() (line 1402, which calls nvtx_range_push) and end_forward() (line 1474, which calls nvtx_range_pop), the NVTX range stack will be corrupted.
Exceptions can occur in many places:
- `_get_weight_and_bias_tensors()` (line 1404)
- `_get_quantizers()` or `_get_debug_quantizers()` (lines 1406-1414)
- `linear_fn()` (line 1467) - shape mismatches, CUDA OOM, assertions, etc.
The original context manager pattern guaranteed cleanup via __exit__. The new pattern requires all code paths to call end_forward(), but exceptions bypass this, leaving an unclosed NVTX range that corrupts the stack for all subsequent operations.
Consider using a try-finally block or reverting to the context manager pattern for this module.
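As a reference for this suggestion, a minimal self-contained sketch of the try/finally shape is below. It is illustrative only: the class, the `_run_gemm` helper, and the print statements stand in for the real module code; only the `prepare_forward()`/`end_forward()` pairing mirrors the pattern under discussion.

```python
class _NvtxSafeForwardSketch:
    """Illustrative only: keeps the NVTX push/pop balanced with try/finally."""

    def prepare_forward(self, inp):
        # stands in for the real prepare_forward(), which pushes an NVTX range
        print("nvtx_range_push")
        return inp

    def end_forward(self):
        # stands in for the real end_forward(), which pops the NVTX range
        print("nvtx_range_pop")

    def _run_gemm(self, inp):
        # placeholder for quantizer selection, weight/bias fetch, and the
        # linear_fn()/fwd_fn() call; any of these may raise
        return inp

    def forward(self, inp):
        inp = self.prepare_forward(inp)
        try:
            return self._run_gemm(inp)
        finally:
            # runs on both success and failure, so the range is always popped
            self.end_forward()
```

With this shape, an exception escaping `_run_gemm` still pops the range, so later NVTX annotations stay aligned.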
| inp = self.prepare_forward( | ||
| inp, allow_non_contiguous=False # removed .contiguous from inside the layer | ||
| ) as inp: | ||
| ) | ||
|
|
||
| # Get concatenated weight and bias tensors | ||
| weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() | ||
|
|
||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) | ||
|
|
||
| ( | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormLinear.forward | ||
| autograd_ctx = [None] | ||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| input_quantizer, | ||
| weight_quantizer, | ||
| output_quantizer, | ||
| grad_input_quantizer, | ||
| grad_weight_quantizer, | ||
| grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.parallel_mode, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.normalization, | ||
| self.ub_overlap_ag_fprop, | ||
| self.ub_overlap_rs_fprop, | ||
| self.ub_overlap_ag_dgrad, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_name, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, | ||
| weight_tensor, | ||
| bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, | ||
| non_tensor_args, | ||
| ) | ||
CRITICAL: NVTX range imbalance on exceptions
Same issue as in Linear module: the refactoring from a context manager to manual prepare_forward()/end_forward() calls breaks exception safety. If an exception occurs between prepare_forward() (line 1519) and end_forward() (line 1596), the NVTX range stack will be corrupted.
Exceptions can occur in:
- `_get_weight_and_bias_tensors()` (line 1522)
- `_get_quantizers()` or `_get_debug_quantizers()` (lines 1524-1532)
- `fwd_fn()` (line 1586) - shape mismatches, CUDA OOM, assertions, etc.
Consider using a try-finally block or reverting to the context manager pattern.
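The context-manager alternative mentioned in these comments could look roughly like the sketch below. This is an assumption about shape rather than the module's actual code; it only reuses the `prepare_forward`/`end_forward` names from the discussion, and the commented call site is hypothetical.

```python
import contextlib


@contextlib.contextmanager
def forward_scope(module, inp):
    """Illustrative only: the finally clause plays the role of __exit__."""
    inp = module.prepare_forward(inp)
    try:
        yield inp
    finally:
        module.end_forward()


# Hypothetical call site: the body may raise without unbalancing NVTX ranges.
# with forward_scope(self, inp) as inp:
#     out = fwd_fn(*autograd_ctx, inp, non_tensor_args)
```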
| inp = self.prepare_forward(inp, num_gemms=2) | ||
|
|
||
| quantizers = ( | ||
| self._get_quantizers(fp8_output, is_grad_enabled) | ||
| if not debug | ||
| else self._get_debug_quantizers(fp8_output, is_grad_enabled) | ||
| ) | ||
| if debug: | ||
| if self.no_debug_features_active(quantizers): | ||
| debug = False | ||
| quantizers = self._get_quantizers(fp8_output, is_grad_enabled) | ||
|
|
||
| # Get quantizers | ||
| ( | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| ) = quantizers | ||
|
|
||
| # Get weight tensors | ||
| fc1_weight, fc2_weight = self._get_weight_tensors() | ||
| fc1_bias = self.fc1_bias if self.use_bias else None | ||
| fc2_bias = self.fc2_bias if self.use_bias else None | ||
| if not self.fp8: | ||
| if isinstance(fc1_weight, Float8Tensor): | ||
| fc1_weight = fc1_weight.dequantize() | ||
| if isinstance(fc2_weight, Float8Tensor): | ||
| fc2_weight = fc2_weight.dequantize() | ||
|
|
||
| # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode | ||
| if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): | ||
| self.bias_gelu_nvfusion = False | ||
| self.fast_setattr("bias_gelu_nvfusion", False) | ||
| if is_grad_enabled: | ||
| fwd_fn = _LayerNormMLP.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| fwd_fn = _LayerNormMLP.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| self.eps, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| self.fuse_wgrad_accumulation, | ||
| fc1_input_quantizer, | ||
| fc1_weight_quantizer, | ||
| fc1_output_quantizer, | ||
| fc1_grad_input_quantizer, | ||
| fc1_grad_weight_quantizer, | ||
| fc1_grad_output_quantizer, | ||
| fc2_input_quantizer, | ||
| fc2_weight_quantizer, | ||
| fc2_output_quantizer, | ||
| fc2_grad_input_quantizer, | ||
| fc2_grad_weight_quantizer, | ||
| fc2_grad_output_quantizer, | ||
| is_cpu_offload_enabled(), | ||
| self.tp_group, | ||
| self.tp_size, | ||
| self.sequence_parallel, | ||
| self.tp_size > 1, | ||
| self.activation_dtype, | ||
| self.return_layernorm_output, | ||
| self.return_layernorm_output_gathered, | ||
| self.bias_gelu_nvfusion and not self.fp8 and not debug, | ||
| self.set_parallel_mode, | ||
| is_grad_enabled, | ||
| self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, | ||
| self.bwd_ln_sm_margin, | ||
| self.zero_centered_gamma, | ||
| self.activation, | ||
| self.activation_params, | ||
| self.normalization, | ||
| self.ub_overlap_ag, | ||
| self.ub_overlap_rs, | ||
| self.ub_overlap_rs_dgrad, | ||
| self.ub_bulk_dgrad, | ||
| self.ub_bulk_wgrad, | ||
| self.gemm_gelu_fusion and not debug, | ||
| self.fsdp_group, | ||
| self, | ||
| skip_fp8_weight_update, | ||
| self.symmetric_ar_type, | ||
| self.checkpoint, | ||
| debug, | ||
| ) | ||
| out = fwd_fn( | ||
| *autograd_ctx, | ||
| inp, | ||
| self.layer_norm_weight, | ||
| self.layer_norm_bias, | ||
| fc1_weight, | ||
| fc1_bias, | ||
| fc2_weight, | ||
| fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None, | ||
| non_tensor_args, | ||
| ) | ||
CRITICAL: NVTX range imbalance on exceptions
Same issue as in Linear module: the refactoring from a context manager to manual prepare_forward()/end_forward() calls breaks exception safety. If an exception occurs between prepare_forward() (line 2057) and end_forward() (line 2157), the NVTX range stack will be corrupted.
Exceptions can occur in:
- `_get_quantizers()` or `_get_debug_quantizers()` (lines 2059-2067)
- `_get_weight_tensors()` (line 2075)
- `fwd_fn()` (line 2145) - shape mismatches, CUDA OOM, assertions, etc.
Consider using a try-finally block or reverting to the context manager pattern.
| inp = self.prepare_forward(inp, num_gemms=self.num_gemms) | ||
| weight_tensors = self._get_weight_tensors() | ||
| bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] | ||
|
|
||
| quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() | ||
|
|
||
| if debug: | ||
| if self.no_debug_features_active(list(chain(*quantizers))): | ||
| debug = False | ||
| quantizers = self._get_quantizers() | ||
|
|
||
| if isinstance(weight_tensors, QuantizedTensorStorage): | ||
| raise RuntimeError("FP8 weights are not supported in debug mode.") | ||
|
|
||
| ( | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| ) = quantizers | ||
|
|
||
| if is_grad_enabled: | ||
| linear_fn = _GroupedLinear.apply | ||
| autograd_ctx = [] | ||
| else: | ||
| linear_fn = _GroupedLinear.forward | ||
| autograd_ctx = [None] | ||
|
|
||
| non_tensor_args = ( | ||
| m_splits, | ||
| self.apply_bias, | ||
| is_first_microbatch, | ||
| self.fp8, | ||
| self.fp8_calibration, | ||
| self.wgrad_store, | ||
| input_quantizers, | ||
| weight_quantizers, | ||
| output_quantizers, | ||
| grad_input_quantizers, | ||
| grad_weight_quantizers, | ||
| grad_output_quantizers, | ||
| self.fuse_wgrad_accumulation, | ||
| is_cpu_offload_enabled(), | ||
| self.sequence_parallel, | ||
| self.activation_dtype, | ||
| is_grad_enabled, | ||
| self, | ||
| None, # skip_fp8_weight_update | ||
| self.save_original_input, | ||
| debug, | ||
| ) | ||
| out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) | ||
|
|
CRITICAL: NVTX range imbalance on exceptions
Same issue as in Linear module: the refactoring from a context manager to manual prepare_forward()/end_forward() calls breaks exception safety. If an exception occurs between prepare_forward() (line 793) and end_forward() (line 847), the NVTX range stack will be corrupted.
Exceptions can occur in:
- `_get_weight_tensors()` (line 794)
- `_get_quantizers()` or `_get_debug_quantizers()` (lines 797-805)
- `linear_fn()` (line 846) - shape mismatches, CUDA OOM, assertions, etc.
Consider using a try-finally block or reverting to the context manager pattern.
Additional Comments (3)
(Three comment previews are truncated in this view; again they flag direct attribute assignments that should go through `fast_setattr`.)
Signed-off-by: Przemek Tredak <[email protected]>
f96abbd to 2fb6ee3 (Compare)
Additional Comments (1)
- transformer_engine/pytorch/module/base.py, line 1570 (link) - syntax: CRITICAL: RuntimeWarning will cause test failures
  This direct attribute assignment will also trigger the RuntimeWarning.
17 files reviewed, 1 comment
Signed-off-by: Przemek Tredak <[email protected]>
for more information, see https://pre-commit.ci
Greptile's behavior is changing! From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section. This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".
Signed-off-by: Przemek Tredak <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Przemek Tredak <[email protected]>
Additional Comments (1)
- transformer_engine/pytorch/module/base.py, line 967-968 (link) - logic: `set_tensor_parallel_group` uses direct attribute assignment instead of `fast_setattr`. This public method is documented to be called after module initialization, which will trigger RuntimeWarning (now treated as an error by pytest.ini).
19 files reviewed, 1 comment
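To make the suggested fix concrete, a hedged sketch of the pattern these comments point at follows. The warning-emitting `__setattr__`, the `_initialized` flag, and the use of `object.__setattr__` are illustration-only assumptions; only the `fast_setattr` name and the `set_tensor_parallel_group` call site come from the review above.

```python
import warnings


class _ModuleSketch:
    """Illustrative only: warns on attribute writes after __init__ and offers
    a fast_setattr-style escape hatch, as the review comments describe."""

    def __init__(self):
        object.__setattr__(self, "_initialized", False)
        self.tp_group = None  # goes through __setattr__, no warning during init
        object.__setattr__(self, "_initialized", True)

    def __setattr__(self, name, value):
        if getattr(self, "_initialized", False):
            # with warnings promoted to errors (e.g. via pytest.ini), this fails tests
            warnings.warn(f"setting {name} after __init__", RuntimeWarning)
        object.__setattr__(self, name, value)

    def fast_setattr(self, name, value):
        # bypasses the overridden __setattr__, so no RuntimeWarning is emitted
        object.__setattr__(self, name, value)

    def set_tensor_parallel_group(self, tp_group):
        # a bare `self.tp_group = tp_group` would warn here; routing it through
        # fast_setattr keeps post-init updates silent
        self.fast_setattr("tp_group", tp_group)
```

With this sketch, `m.set_tensor_parallel_group(group)` stays silent after construction, while a direct `m.tp_group = group` emits the RuntimeWarning and, under an error filter, fails the test.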
Signed-off-by: Przemek Tredak <[email protected]>
Description
This PR includes a few performance optimizations targeting CPU overhead. The code, perf numbers, etc. are WIP. The code gets kind of ugly though :-(.
For the prepare_forward changes I did not touch attention (@cyanguwa FYI) since it has multiple exit points from the forward and I was worried that I would miss something there - it would be great if we could refactor that part first to have a single return statement instead.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: