@@ -162,6 +162,59 @@ def _pad_param_if_needed(numel_unpadded):
162162 Utils .destroy_model_parallel ()
163163
164164
def test_param_to_index_alignment_with_padding():
    """Ensure bucket-local param offsets honor padding when DistOpt pads params."""
    Utils.initialize_model_parallel()

    # With input_dim=4, output_dim=4:
    # - weight: 4*4 = 16 elements
    # - bias: 4 elements
    # Since 16 % 64 != 0, the bias must be padded away from the weight,
    # making padding observable.
    input_dim = 4
    output_dim = 4
    model, param_and_grad_buffer, _ = get_model_and_buffers(
        input_dim=input_dim,
        output_dim=output_dim,
        num_layers=1,
        bias=True,
        shared_embedding=False,
        bucket_size=None,  # single bucket
        use_distributed_optimizer=True,  # enforces 64-element alignment
        overlap_grad_reduce=True,
        average_in_collective=False,
    )

    bucket = param_and_grad_buffer.buckets[0]
    unpadded_cursor = 0  # where the next param would land under naive packing
    saw_padding = False

    for param in bucket.params_list:
        buffer_start, buffer_end, _ = param_and_grad_buffer.param_index_map[param]
        local_start, local_end = bucket.param_to_index[param]

        # param_to_index should match the padded offsets used in the global
        # buffer, translated into bucket-local coordinates.
        assert (local_start, local_end) == (
            buffer_start - bucket.offset,
            buffer_end - bucket.offset,
        )

        # At least one param should have been padded relative to naive packing.
        if local_start != unpadded_cursor:
            saw_padding = True
        unpadded_cursor = local_end

        # Verify the slice retrieved via param_to_index matches param.data view.
        flat_bucket = bucket.param_data.view(-1)
        torch.testing.assert_close(
            flat_bucket[local_start:local_end], param.data.view(-1)
        )

    assert saw_padding, (
        "Expected padding to be applied between params. "
        "Ensure model dimensions are chosen such that param sizes are not multiples of 64."
    )

    Utils.destroy_model_parallel()
216+
217+
165218@pytest .mark .parametrize ("use_distributed_optimizer" , [False , True ])
166219@pytest .mark .parametrize ("overlap_grad_reduce" , [False , True ])
167220@pytest .mark .parametrize ("average_in_collective" , [False , True ])
0 commit comments