Commit 5685b2c

sudhu2k and a co-author authored
Te2.4 fsdp2 fp8 allgather autocast (#349)

* Initial commit
* Removed print statements; added keep_fp8_transpose cache integration with FSDP2
* Added use_fsdp flag to the Linear module; added profile code, test code, and all-reduce for amax
* Fixed unit test
* Removed the all-reduce code for amax, since TE already does the all-reduce by default when torch.distributed is initialized
* Reverted the case where out is already present
* Added unit test with regular single-GPU training
* Modified unit test to compare FSDP2 with DDP
* Bug fixes
* Code cleanup
* Initial commit to add MXFP8
* Added FP8 current scaling
* Added MXFP8; modified unit test to run based on recipes
* Extended use_fsdp to LayerNormLinear and LayerNormMLP
* Moved the amax reduce from forward to backward for FSDP2
* Added automatic detection of use_fsdp from the base module
* Used SKIP_FP8_REDUCTION_FOR_FSDP2 in backward to check whether the forward reduction is needed
* Added memory profile code; added a check before setting SKIP_FP8_REDUCTION_FOR_FSDP2
* Fix for the fused optimizer: changed _elem to _data; code cleanup
* Fixed LayerNormMLP
* Code cleanup and added the test to pytorch.sh
* Removed whitespace
* Fixed comments and license
* Added the forward reduce in the CUDA graph backward; added code to remove test artifacts; reverted the upstream test file

---------

Co-authored-by: sudhu2k <[email protected]>

1 parent 32e2d1d · commit 5685b2c
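For orientation, here is a minimal sketch of how the per-module flags exercised by this change are meant to be used under FSDP2 with FP8 autocast. The use_fsdp2 and keep_fp8_weight_transpose_cache keyword arguments follow the test added in this commit; the rest is standard PyTorch/TE API, and the sizes and layout are arbitrary (this is not code from the commit itself):

    # Sketch: one TE Linear sharded with FSDP2 (fully_shard) and run under fp8_autocast.
    # use_fsdp2 / keep_fp8_weight_transpose_cache mirror the flags added by this commit;
    # sizes and file layout here are illustrative only.
    import os
    import torch
    import torch.distributed as dist
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling
    from torch.distributed._composable.fsdp import fully_shard
    from torch.distributed.device_mesh import init_device_mesh

    # Launch with torchrun so RANK/WORLD_SIZE/LOCAL_RANK are set.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    layer = te.Linear(
        1024,
        1024,
        use_fsdp2=True,                         # tell TE this module is sharded by FSDP2
        keep_fp8_weight_transpose_cache=False,  # skip caching the FP8 weight transpose
    ).cuda()

    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    fully_shard(layer, mesh=mesh)               # shard the weights across all ranks

    x = torch.randn(32, 1024, device="cuda")
    with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
        y = layer(x)                            # FP8 GEMM with FSDP2 all-gathered weights
    y.sum().backward()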

File tree

11 files changed: +688 -13 lines


ci/pytorch.sh

Lines changed: 1 addition & 0 deletions
@@ -93,6 +93,7 @@ run_test_config_mgpu(){
         run 3 distributed/test_fusible_ops.py
         run 3 distributed/test_numerics.py
         run 3 distributed/test_torch_fsdp2.py
+        run 3 distributed/test_torch_fsdp2_fp8.py
         run 3 fused_attn/test_fused_attn_with_cp.py
     fi
 }
Lines changed: 308 additions & 0 deletions
@@ -0,0 +1,308 @@
#!/usr/bin/python3
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# See LICENSE for license information.


import os
import sys
import argparse

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling, Format, DelayedScaling, MXFP8BlockScaling

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn, optim
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.fp8 import fp8_model_init
from torch.nn.parallel import DistributedDataParallel as DDP
from pathlib import Path


class SimpleNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, use_fsdp2=False):
        super(SimpleNet, self).__init__()

        # LayerNormLinear: fuses LayerNorm + Linear
        self.ln_linear = te.LayerNormLinear(
            in_features=input_size,
            out_features=hidden_size,
            eps=1e-5,
            use_fsdp2=use_fsdp2,
            keep_fp8_weight_transpose_cache=False,
        )

        # LayerNormMLP: fuses LayerNorm + FC1 + Activation + FC2
        self.ln_mlp = te.LayerNormMLP(
            hidden_size=hidden_size,
            ffn_hidden_size=hidden_size * 4,  # Typical 4x expansion
            use_fsdp2=use_fsdp2,
            keep_fp8_weight_transpose_cache=False,
        )

        # Regular Linear for final projection
        self.fc_out = te.Linear(
            hidden_size,
            output_size,
            use_fsdp2=use_fsdp2,
            keep_fp8_weight_transpose_cache=False,
        )

    def forward(self, x):
        # LayerNormLinear: applies LayerNorm then Linear
        x = self.ln_linear(x)

        # LayerNormMLP: applies LayerNorm + FC1 + GELU + FC2
        x = self.ln_mlp(x)

        # Final Linear projection
        x = self.fc_out(x)

        return x


def save_custom_attrs(module, _SKIP_KEYS={"_data", "_module", "_transpose"}):
    custom_attrs = {}
    for name, param in module.named_parameters():
        attrs = vars(param)
        custom_attrs[name] = {k: v for k, v in attrs.items()}
        for k in _SKIP_KEYS:
            custom_attrs[name].pop(k, None)
    return custom_attrs


def restore_custom_attrs(module, custom_attrs):
    for name, param in module.named_parameters():
        if name in custom_attrs:
            for attr_name, attr_value in custom_attrs[name].items():
                setattr(param, attr_name, attr_value)


def _parse_args(argv=None, namespace=None):
    parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
    parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model")
    parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
    parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
    parser.add_argument("--batch-size", type=int, default=2048, help="Batch size for the model")
    parser.add_argument(
        "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
    )
    parser.add_argument(
        "--iter", type=int, default=10, help="Number of iterations for forward pass"
    )
    parser.add_argument('--profile', action='store_true',
                        help='Enable pytorch profiling.')
    parser.add_argument('--profile-step-start', type=int, default=6,
                        help='Global step to start profiling.')
    parser.add_argument('--profile-step-end', type=int, default=7,
                        help='Global step to stop profiling.')
    parser.add_argument('--profile-ranks', nargs='+', type=int, default=[0],
                        help='Global ranks to profile.')
    parser.add_argument('--tensorboard-dir', type=str, default='./fsdp2_tensorboard',
                        help='Write TensorBoard logs to this directory.')
    parser.add_argument('--gradients-save-file', type=str, default='all_iters.pt',
                        help='Write all the gradients across all the iterations to this file.')
    parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
    parser.add_argument("--use-fsdp2", action='store_true',
                        help='Enable new FSDP2 training.')
    parser.add_argument("--memory-profile", action='store_true',
                        help='Profile memory traces.')
    parser.add_argument(
        "--recipe",
        type=str,
        choices=["delayed", "mxfp8", "current"],
        default="delayed",
        help="Select the training recipe to use: 'delayed', 'mxfp8', or 'current'."
    )

    # Adding hsdp_dim as a list argument (space-separated)
    parser.add_argument(
        "--sharding-dims",
        type=int,
        nargs="+",
        help='FSDP/HSDP sharding dimensions ("replicate", "shard")',
    )
    args = parser.parse_args(argv, namespace)
    if args.sharding_dims:
        assert len(args.sharding_dims) <= 2
    return args


sub_modules_to_wrap = [te.Linear, te.LayerNormLinear, te.LayerNormMLP]
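# Each TE submodule type listed above is wrapped individually with fully_shard()
# in _train() below, so FSDP2 all-gathers parameters one layer at a time rather
# than for the whole model at once.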
def _train(args):
    assert "TORCHELASTIC_RUN_ID" in os.environ
    WORLD_RANK = int(os.getenv("RANK", "0"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
    LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
    LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
    assert LOCAL_SIZE == WORLD_SIZE

    # Set device and initialize RNG states
    torch.cuda.set_device(WORLD_RANK)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Initialize torch.distributed global process group and get DP/TP groups
    dist_init_kwargs = {
        "backend": "nccl",
        "rank": WORLD_RANK,
        "world_size": WORLD_SIZE,
    }
    assert dist.is_nccl_available()
    dist.init_process_group(**dist_init_kwargs)
    nccl_world = dist.new_group(backend="nccl")
    device = torch.device(f"cuda:{LOCAL_RANK}")

    # FP8 configuration
    if args.recipe == "current":
        fp8_recipe = Float8CurrentScaling()
    elif args.recipe == "mxfp8":
        fp8_recipe = MXFP8BlockScaling()
    elif args.recipe == "delayed":
        fp8_recipe = DelayedScaling()
    else:
        raise ValueError(f"Unsupported recipe: {args.recipe}")

    if args.memory_profile:
        torch.cuda.memory._record_memory_history(enabled='all', context='all', stacks='all')

    if args.fp8_init:
        # Build the model with the specified context
        with fp8_model_init(enabled=True):
            model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2)
    else:
        model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2)

    # Move the model to the correct device
    if not args.memory_profile:
        model.load_state_dict(torch.load('fsdp_model.pth'))
    model.to(device)

    # Creating a DeviceMesh for fully_shard
    world_size = int(WORLD_SIZE)
    device_ids = list(range(world_size))

    # Apply FSDP/HSDP
    if args.use_fsdp2:
        custom_attrs = save_custom_attrs(model)
        if LOCAL_RANK == 0:
            print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...")
            print(f"sharding-dims:{args.sharding_dims}")
        # Set up the sharding mesh for FSDP/HSDP
        if args.sharding_dims is None:  # FSDP
            mesh = DeviceMesh("cuda", device_ids)
        elif len(args.sharding_dims) == 1:
            assert args.sharding_dims[0] == device_ids[-1] + 1
            mesh = DeviceMesh("cuda", device_ids)
        elif len(args.sharding_dims) == 2:  # HSDP
            assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1
            mesh = init_device_mesh(
                "cuda",
                (args.sharding_dims[0], args.sharding_dims[1]),
                mesh_dim_names=("replicate", "shard"),
            )
        else:
            assert False
        for sub_module in model.modules():
            if any(
                isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap
            ):
                fully_shard(sub_module, mesh=mesh)
        fully_shard(model, mesh=mesh, reshard_after_forward=True)
        restore_custom_attrs(model, custom_attrs)
    else:
        model = DDP(model, device_ids=[LOCAL_RANK])

    optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3)

    input_path = Path("shared_input.pt")
    if input_path.exists():
        input_data = torch.load(input_path).to(device)
    else:
        input_data = torch.randn(args.batch_size, args.input_size, requires_grad=True).to(device)
        torch.save(input_data.cpu(), input_path)
        print("Generated and saved shared input tensor.")

    out_tensors = []
    prof = None
    if (
        args.profile
        and torch.distributed.get_rank() in args.profile_ranks
    ):
        prof = torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
            schedule=torch.profiler.schedule(
                wait=max(args.profile_step_start - 1, 0),
                warmup=1 if args.profile_step_start > 0 else 0,
                active=args.profile_step_end - args.profile_step_start,
                repeat=1,
            ),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir),
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        )
        prof.start()

    for iteration in range(args.iter):
        if LOCAL_RANK == 0:
            print(f"Starting iteration...{iteration}")
        if args.profile and torch.distributed.get_rank() in args.profile_ranks:
            prof.step()

        # Zero the parameter gradients
        optimizer.zero_grad()
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            output = model(input_data)
        target = torch.randn(args.batch_size, args.output_size).to(device)
        loss = F.mse_loss(output, target)
        loss.backward()
        optimizer.step()
        if LOCAL_RANK == 0:
            print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")

        if not args.profile and not args.memory_profile:
            with torch.no_grad():
                for name, p in model.named_parameters():
                    full_grad = None
                    if p.grad is not None and hasattr(p.grad, 'full_tensor'):
                        # This call is required to be executed on ALL ranks
                        # to complete the collective communication.
                        full_grad = p.grad.full_tensor().detach().clone()
                    elif p.grad is not None:
                        full_grad = p.grad.detach().clone()
                    # Only rank 0 stores the result
                    if LOCAL_RANK == 0 and p.requires_grad:
                        out_tensors.append((name, full_grad))
        if (
            args.profile
            and iteration == args.profile_step_end
            and torch.distributed.get_rank() in args.profile_ranks
        ):
            prof.stop()

    if (not args.profile and not args.memory_profile) and LOCAL_RANK == 0:
        torch.save(out_tensors, args.gradients_save_file)

    if args.memory_profile:
        snapshot = torch.cuda.memory._snapshot()
        import pickle
        with open('memory_snapshot.pickle', 'wb') as f:
            pickle.dump(snapshot, f)
        # Disable memory history recording when no longer needed
        torch.cuda.memory._record_memory_history(enabled=None)

    # NOTE: In PyTorch < 2.6 there's a teardown race where one rank may call
    # destroy_process_group() while other ranks still have in-flight NCCL ops,
    # which can trigger a NCCL/RCCL comm error. Newer releases (>= 2.6) fixed
    # this, but we kept a version-guarded barrier on older Torch for stability.
    if torch_version() < (2, 6, 0):
        dist.barrier(device_ids=[torch.cuda.current_device()])
    dist.destroy_process_group()

    return 0


if __name__ == "__main__":
    sys.exit(_train(_parse_args()))
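The training loop above appends rank-0 gradients as (name, tensor) pairs and saves them to --gradients-save-file (all_iters.pt by default). A comparison driver is not part of this diff, but a minimal sketch of checking a DDP baseline dump against an FSDP2 dump could look like the following (the file names ddp_grads.pt and fsdp2_grads.pt and the tolerances are assumptions, not from the commit):

    # Hypothetical comparison of the gradient dumps produced by a DDP run and an
    # FSDP2 run of the test above (file names and tolerances are made up here).
    import torch

    # dict() keeps the last entry per parameter name, i.e. the final iteration's grads.
    ddp_grads = dict(torch.load("ddp_grads.pt"))
    fsdp2_grads = dict(torch.load("fsdp2_grads.pt"))

    for name, ref in ddp_grads.items():
        # The DDP branch wraps the model, so its parameter names carry a "module." prefix.
        key = name[len("module."):] if name.startswith("module.") else name
        torch.testing.assert_close(fsdp2_grads[key], ref, rtol=1e-2, atol=1e-2)
        print(f"{key}: gradients match")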

0 commit comments
