[PyTorch] Remove unnecessary save of weights #2549
base: main
Conversation
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
Greptile Summary

This PR refactors weight tensor handling across four autograd Function classes (in linear.py, layernorm_linear.py, grouped_linear.py, and layernorm_mlp.py) so that the backward pass reaches the original weight Python objects through weak references instead of objects saved directly on ctx.

Motivation: MCore's fused wgrad accumulation needs to set attributes such as grad_added_to_main_grad on the weight's original Python object during backward, and saving that object on ctx conflicts with non-TE CPU offloading and FSDP.

Implementation Pattern:

Forward Pass: check fuse_wgrad_accumulation and requires_grad, create weakref.ref(weight), store origin_weight_ref and the origin_weight_overwrites_main_grad flag on ctx, and pass only the tensor data to save_for_backward.

Backward Pass: retrieve and dereference origin_weight_ref, assert the original weight is still alive, set main_grad and (if present) grad_added_to_main_grad on it, and return a dummy wgrad since the real gradient is accumulated into main_grad.

Key Changes: weight objects are no longer saved on ctx; only weak references and flags are kept alongside the tensor data saved via save_for_backward.

Correctness: All attribute access uses the correct object type.

Confidence Score: 5/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant FWD as Forward Pass
    participant CTX as AutogradContext
    participant WEAKREF as WeakReference
    participant WEIGHT as Weight Parameter
    participant BWD as Backward Pass
    participant MCORE as MCore DDP

    FWD->>WEIGHT: Check fuse_wgrad_accumulation & requires_grad
    FWD->>WEAKREF: Create weakref.ref(weight)
    FWD->>CTX: Save origin_weight_ref
    FWD->>CTX: Save origin_weight_overwrites_main_grad flag
    FWD->>CTX: save_for_backward(weightmat) [tensor data only]
    Note over FWD,BWD: Forward/Backward boundary
    BWD->>CTX: Retrieve origin_weight_ref
    BWD->>WEAKREF: Dereference weakref()
    WEAKREF-->>BWD: Return original Python object
    BWD->>BWD: Assert origin_weight is not None
    BWD->>WEIGHT: Set origin_weight.main_grad
    BWD->>MCORE: Check hasattr(origin_weight, "grad_added_to_main_grad")
    MCORE-->>BWD: Attribute exists
    BWD->>WEIGHT: Set origin_weight.grad_added_to_main_grad = True
    BWD->>BWD: Return dummy wgrad (gradient handled via main_grad)
```
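For readers unfamiliar with the pattern, here is a minimal sketch of the flow in the diagram, assuming a toy GEMM and an MCore-style `main_grad` buffer already attached to the parameter. The class name `_LinearSketch` and its argument list are illustrative, not the actual TE code.

```python
import weakref

import torch


class _LinearSketch(torch.autograd.Function):
    """Toy stand-in for the TE autograd Functions touched by this PR."""

    @staticmethod
    def forward(ctx, inp, weight, fuse_wgrad_accumulation):
        ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        ctx.origin_weight_ref = None
        if fuse_wgrad_accumulation and weight.requires_grad:
            # Keep only a weak reference to the weight's Python object ...
            ctx.origin_weight_ref = weakref.ref(weight)
            ctx.origin_weight_overwrites_main_grad = getattr(
                weight, "overwrite_main_grad", False
            )
        # ... and pass the tensor data itself through save_for_backward.
        ctx.save_for_backward(inp, weight.detach())
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        inp, weight_data = ctx.saved_tensors
        dgrad = grad_output @ weight_data
        wgrad = grad_output.t() @ inp
        if ctx.fuse_wgrad_accumulation and ctx.origin_weight_ref is not None:
            # Dereference the weakref to reach the original Parameter object.
            origin_weight = ctx.origin_weight_ref()
            assert origin_weight is not None, "weight freed before backward"
            origin_weight.main_grad.add_(wgrad)
            if hasattr(origin_weight, "grad_added_to_main_grad"):
                origin_weight.grad_added_to_main_grad = True
            # The real gradient lives in main_grad; return a dummy wgrad.
            wgrad = torch.zeros_like(weight_data)
        return dgrad, wgrad, None
```

The dummy wgrad is returned only because autograd expects one gradient per differentiable input; when fused accumulation is on, MCore reads the gradient from `main_grad` instead.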
1 file reviewed, 1 comment
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
3 files reviewed, 3 comments
7 files reviewed, 7 comments
| "out_dtype": ( | ||
| origin_fc2_weight.main_grad.dtype | ||
| fc2_weight_main_grad.dtype | ||
| if ctx.fuse_wgrad_accumulation | ||
| else ctx.activation_dtype | ||
| ), |
When ctx.fuse_wgrad_accumulation is True but the weakref returns None (line 996-999), fc2_weight_main_grad remains None (check fails at line 1004). Accessing .dtype on None will cause an AttributeError. Need to add a safety check or ensure fc2_weight_main_grad has a fallback value.
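One possible shape of such a guard, assuming the `ctx` attribute names suggested by the diagram above (`origin_fc2_weight_ref` is illustrative; `fc2_weight_requires_grad` appears in the review discussion):

```python
# Hypothetical guard: resolve the weakref once and fail loudly if the
# original fc2 weight object is gone before main_grad is touched.
fc2_weight_main_grad = None
if ctx.fuse_wgrad_accumulation and ctx.fc2_weight_requires_grad:
    origin_fc2_weight = ctx.origin_fc2_weight_ref()
    if origin_fc2_weight is None:
        raise RuntimeError("fc2 weight was garbage-collected before backward")
    fc2_weight_main_grad = origin_fc2_weight.main_grad
```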
Additional Comments (2)
Should use:
Should use:
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
2 files reviewed, 2 comments
```diff
 fc2_wgrad_gemm_kwargs = {
     "out_dtype": (
-        origin_fc2_weight.main_grad.dtype
+        fc2_weight_main_grad.dtype
         if ctx.fuse_wgrad_accumulation
         else ctx.activation_dtype
     ),
     "quantization_params": ctx.fc2_grad_weight_quantizer,  # wgrad in high precision
     "accumulate": (
         accumulate_wgrad_into_param_main_grad
-        if not getattr(fc1_weight, "overwrite_main_grad", False)
+        if not getattr(ctx, "fc2_weight_overwrites_main_grad", False)
         else False
     ),
     "layout": "NT",
-    "out": origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+    "out": fc2_weight_main_grad if ctx.fuse_wgrad_accumulation else None,
```
[P0] fc2_weight_main_grad is None when fuse_wgrad_accumulation=False, causing AttributeError on .dtype access. This block executes when ctx.fc2_weight_requires_grad=True (line 1178), but fc2_weight_main_grad is only initialized when both ctx.fuse_wgrad_accumulation=True AND ctx.fc2_weight_requires_grad=True (lines 993-1018). Wrap this kwargs dict in if ctx.fuse_wgrad_accumulation: or use ternary to handle None case.
But .dtype is invoked only if ctx.fuse_wgrad_accumulation is True ...
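For reference, Python's conditional expression evaluates only the selected branch, so the `.dtype` access is skipped entirely when the flag is False. A toy example (not the PR's code):

```python
fc2_weight_main_grad = None            # the case flagged by the review
fuse_wgrad_accumulation = False
activation_dtype = "bf16"

# Only the else-branch is evaluated, so no AttributeError is raised.
out_dtype = (
    fc2_weight_main_grad.dtype
    if fuse_wgrad_accumulation
    else activation_dtype
)
assert out_dtype == "bf16"
```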
No files reviewed, no comments
Signed-off-by: Pawel Gadzinski <[email protected]>
/te-ci pytorch L1
No files reviewed, no comments
Description
MCore's fused wgrad accumulation feature requires setting the grad_added_to_main_grad attribute on the weight's Python object. This means the original Python object must be accessible and modifiable during the backward pass.

Currently, weights are saved via save_for_backward, with the assumption that no hooks substitute them with different tensors (e.g., during CPU offload/reload). For CPU offloading, we work around this by saving weights directly on ctx. However, this approach is incompatible with non-TE CPU offloading scenarios and potentially conflicts with FSDP, which also manages weight tensors.

This PR addresses these issues by saving weak references to weights for the backward pass instead. When modifications to the original Python object are needed (e.g., setting grad_added_to_main_grad), the weakref is dereferenced and the modification is applied. This is done conditionally, only when MCore FSDP or MCore fused wgrad accumulation is enabled.
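For context, a toy example of the weakref behavior this relies on: dereferencing returns the original Parameter (so attribute writes are visible to every other holder of it), while the weak reference itself does not keep the object alive.

```python
import weakref

import torch

weight = torch.nn.Parameter(torch.zeros(8, 8))
weight_ref = weakref.ref(weight)

# Dereferencing returns the original Python object, so attribute writes
# (e.g., grad_added_to_main_grad) are seen by anyone holding `weight`.
weight_ref().grad_added_to_main_grad = True
assert weight.grad_added_to_main_grad is True

# A weakref does not keep the parameter alive: once the last strong
# reference is dropped, CPython frees it and the weakref resolves to None.
del weight
assert weight_ref() is None
```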
Changes:

- Save a weakref to each weight in the forward pass and dereference it in the backward pass only when fuse_wgrad_accumulation (or MCore FSDP) is enabled, instead of saving the weight object directly on ctx.
- Applied to linear.py, layernorm_linear.py, grouped_linear.py, and layernorm_mlp.py.

Type of change
Checklist: