Skip to content

Commit 80d2d1c

Browse files
authored
Merge branch 'dev' into gdn_thd
2 parents e8ed23c + f6f2abe commit 80d2d1c

File tree

3 files changed

+65
-8
lines changed

3 files changed

+65
-8
lines changed

megatron/core/distributed/param_and_grad_buffer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ class _ParamAndGradBucket:
7878
communication. Its application is twofold: it facilitates the averaging of gradients
7979
and the scaling of gradients in the context of the Mixture of Experts (MoE) model.
8080
bucket_id: Index of bucket in buffer.
81+
param_index_map: Mapping from param to (start, end, bucket_id) in the global buffer.
82+
Used to derive bucket-local offsets for param_to_index.
8183
"""
8284

8385
def __init__(
@@ -89,6 +91,7 @@ def __init__(
8991
numel_unpadded: int,
9092
gradient_scaling_factor: float,
9193
bucket_id: int,
94+
param_index_map: Dict[torch.nn.Parameter, tuple],
9295
):
9396
self.params_list = params
9497
self.params = set(params)
@@ -102,11 +105,11 @@ def __init__(
102105
self.numel_unpadded = numel_unpadded
103106
self.gradient_scaling_factor = gradient_scaling_factor
104107
self.bucket_id = bucket_id
108+
# Derive bucket-local param offsets by subtracting this bucket's global
# start offset (self.offset) from each param's global (start, end) range.
105109
self.param_to_index = {}
106-
offset = 0
107110
for param in params:
108-
self.param_to_index[param] = (offset, offset + param.numel())
109-
offset += param.numel()
111+
global_start, global_end, _ = param_index_map[param]
112+
self.param_to_index[param] = (global_start - self.offset, global_end - self.offset)
110113

111114

112115
class _ParamAndGradBucketGroup:
@@ -926,6 +929,7 @@ def _new_bucket(
926929
numel_unpadded=numel_unpadded,
927930
gradient_scaling_factor=self.gradient_scaling_factor,
928931
bucket_id=bucket_id,
932+
param_index_map=self.param_index_map,
929933
)
930934
for bucket_param in bucket_params:
931935
assert bucket_param not in self.param_to_bucket

tests/test_utils/recipes/mamba-static-inference.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ products:
5959
- environment: [dev]
6060
scope: [mr-broken, mr-github-broken]
6161
platforms: [dgx_h100]
62-
- test_case: [hybrid_static_inference_tp1_pp1_2B_cudagraphs]
63-
products:
64-
- environment: [dev]
65-
scope: [mr]
66-
platforms: [dgx_h100]
62+
# - test_case: [hybrid_static_inference_tp1_pp1_2B_cudagraphs]
63+
# products:
64+
# - environment: [dev]
65+
# scope: [mr]
66+
# platforms: [dgx_h100] # Broken after dev2main sync 01/27

tests/unit_tests/distributed/test_param_and_grad_buffer.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,59 @@ def _pad_param_if_needed(numel_unpadded):
162162
Utils.destroy_model_parallel()
163163

164164

165+
def test_param_to_index_alignment_with_padding():
166+
"""Ensure bucket-local param offsets honor padding when DistOpt pads params."""
167+
Utils.initialize_model_parallel()
168+
169+
# With input_dim=4, output_dim=4:
170+
# - weight: 4*4 = 16 elements
171+
# - bias: 4 elements
172+
# Since 16 % 64 != 0, the bias must be padded away from the weight,
173+
# making padding observable.
174+
input_dim = 4
175+
output_dim = 4
176+
model, param_and_grad_buffer, _ = get_model_and_buffers(
177+
input_dim=input_dim,
178+
output_dim=output_dim,
179+
num_layers=1,
180+
bias=True,
181+
shared_embedding=False,
182+
bucket_size=None, # single bucket
183+
use_distributed_optimizer=True, # enforces 64-element alignment
184+
overlap_grad_reduce=True,
185+
average_in_collective=False,
186+
)
187+
188+
bucket = param_and_grad_buffer.buckets[0]
189+
naive_offset = 0
190+
padding_observed = False
191+
192+
for param in bucket.params_list:
193+
global_start, global_end, _ = param_and_grad_buffer.param_index_map[param]
194+
expected_local_start = global_start - bucket.offset
195+
expected_local_end = global_end - bucket.offset
196+
local_start, local_end = bucket.param_to_index[param]
197+
198+
# param_to_index should match the padded offsets used in the global buffer.
199+
assert (local_start, local_end) == (expected_local_start, expected_local_end)
200+
201+
# At least one param should have been padded relative to naive packing.
202+
if local_start != naive_offset:
203+
padding_observed = True
204+
naive_offset = local_end
205+
206+
# Verify the slice retrieved via param_to_index matches param.data view.
207+
param_slice = bucket.param_data.view(-1)[local_start:local_end]
208+
torch.testing.assert_close(param_slice, param.data.view(-1))
209+
210+
assert padding_observed, (
211+
"Expected padding to be applied between params. "
212+
"Ensure model dimensions are chosen such that param sizes are not multiples of 64."
213+
)
214+
215+
Utils.destroy_model_parallel()
216+
217+
165218
@pytest.mark.parametrize("use_distributed_optimizer", [False, True])
166219
@pytest.mark.parametrize("overlap_grad_reduce", [False, True])
167220
@pytest.mark.parametrize("average_in_collective", [False, True])

0 commit comments

Comments
 (0)