
Commit 7809b14

Shagun Gupta authored and facebook-github-bot committed
Bootcamp Task: Unit Tests for Gradient Clipping for DTensors (pytorch#3253)
Summary: Implemented unit tests covering norm-based clipping for two sharded DTensors. All test cases pass. Differential Revision: D79301301
1 parent fc56fb0 commit 7809b14
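
For readers unfamiliar with the DTensor plumbing the new test exercises, below is a minimal, hypothetical single-process sketch (not part of this commit) of sharding a 1-D tensor and reading its gradient back. It assumes a recent PyTorch (>= 2.5) where these APIs live under torch.distributed.tensor; the arbitrary address/port and the world_size=1 gloo group are illustration-only, whereas distribute_tensor, Shard, and full_tensor() are the same calls used in the diff, which the real test runs across multiple ranks via DTensorTestBase.

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# Single-process (world_size=1, gloo) illustration only; address/port are arbitrary.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

# Shard a parameter and its gradient on dim 0, mirroring what the test does per rank.
param = distribute_tensor(torch.tensor([1.0, 2.0, 3.0], requires_grad=True), mesh, [Shard(0)])
param.grad = distribute_tensor(torch.tensor([12.0, 15.0, 18.0]), mesh, [Shard(0)])

# full_tensor() gathers the shards; the test uses it to compare against the plain-tensor reference.
print(param.grad.full_tensor())  # -> tensor([12., 15., 18.])

dist.destroy_process_group()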

File tree: 1 file changed, +98 −3 lines changed


torchrec/optim/tests/test_clipping.py

Lines changed: 98 additions & 3 deletions
@@ -245,19 +245,21 @@ def test_clip_no_gradients_norm_meta_device(
 @unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
 @instantiate_parametrized_tests
 class TestGradientClippingDTensor(DTensorTestBase):
+    """No tests for Replicated DTensors, as they are handled before GradientClippingOptimizer."""
+
     def _get_params_to_pg(
         self, params: List[DTensor]
     ) -> Dict[DTensor, List[ProcessGroup]]:
         return {param: [param.device_mesh.get_group()] for param in params}
 
     @with_comms
     @parametrize("norm_type", ("inf", 1, 2))
-    def test_dtensor_clip_all_gradients_norm(
+    def test_tensor_and_sharded_dtensor_clip_all_gradients_norm(
         self, norm_type: Union[float, str]
     ) -> None:
         """
         Test to ensure that the gradient clipping optimizer clips gradients
-        correctly with mixed DTensor and tensor by comparing gradients to its
+        correctly with mixed sharded DTensor and tensor by comparing gradients to their
         torch.tensor counterpart.
 
         Note that clipping for DTensor may require communication.
@@ -286,7 +288,7 @@ def test_dtensor_clip_all_gradients_norm(
         ref_param_2.grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
         ref_gradient_clipping_optimizer.step()
 
-        # create gradient clipping optimizer containing both DTensor and tensor
+        # create gradient clipping optimizer containing sharded DTensor and tensor
         device_mesh = init_device_mesh(self.device_type, (self.world_size,))
         param_1 = distribute_tensor(
             torch.tensor([1.0, 2.0, 3.0], requires_grad=True, device=self.device_type),
@@ -336,3 +338,96 @@ def test_dtensor_clip_all_gradients_norm(
                     ref_param.grad,
                     f"Expect gradient to be the same. However, found {param_grad=}, {ref_param.grad=}",
                 )
+
+    @with_comms
+    @parametrize("norm_type", ("inf", 1, 2))
+    def test_multiple_sharded_dtensors_clip_all_gradients_norm(
+        self, norm_type: Union[float, str]
+    ) -> None:
+        """
+        Test to ensure that the gradient clipping optimizer clips gradients
+        correctly with multiple sharded DTensors by comparing gradients to their
+        torch.tensor counterparts.
+
+        Note that clipping for DTensor may require communication.
+        """
+
+        # create gradient clipping optimizer containing no DTensor for reference
+        ref_param_1 = torch.nn.Parameter(
+            torch.tensor([1.0, 2.0, 3.0], device=self.device_type)
+        )
+        ref_param_2 = torch.nn.Parameter(
+            torch.tensor([4.0, 5.0, 6.0], device=self.device_type)
+        )
+        ref_keyed_optimizer = DummyKeyedOptimizer(
+            {"param_1": ref_param_1, "param_2": ref_param_2},
+            {},
+            [{"params": [ref_param_1, ref_param_2]}],
+        )
+        ref_gradient_clipping_optimizer = GradientClippingOptimizer(
+            optimizer=ref_keyed_optimizer,
+            clipping=GradientClipping.NORM,
+            max_gradient=10.0,
+            norm_type=norm_type,
+        )
+        ref_gradient_clipping_optimizer.zero_grad()
+        ref_param_1.grad = torch.tensor([12.0, 15.0, 18.0], device=self.device_type)
+        ref_param_2.grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
+        ref_gradient_clipping_optimizer.step()
+
+        # create gradient clipping optimizer containing two sharded DTensors
+        device_mesh = init_device_mesh(self.device_type, (self.world_size,))
+        param_1 = distribute_tensor(
+            torch.tensor([1.0, 2.0, 3.0], requires_grad=True, device=self.device_type),
+            device_mesh,
+            [Shard(0)],
+        )
+        param_2 = distribute_tensor(
+            torch.tensor([4.0, 5.0, 6.0], requires_grad=True, device=self.device_type),
+            device_mesh,
+            [Shard(0)],
+        )
+        param_to_pgs = self._get_params_to_pg([param_1, param_2])
+        keyed_optimizer = DummyKeyedOptimizer(
+            {"dtensor_param_1": param_1, "dtensor_param_2": param_2},
+            {},
+            [{"params": [param_1, param_2]}],
+        )
+        gradient_clipping_optimizer = GradientClippingOptimizer(
+            optimizer=keyed_optimizer,
+            clipping=GradientClipping.NORM,
+            max_gradient=10.0,
+            norm_type=norm_type,
+            enable_global_grad_clip=True,
+            param_to_pgs=param_to_pgs,  # pyre-ignore[6]
+        )
+        gradient_clipping_optimizer.zero_grad()
+        param_1.grad = distribute_tensor(
+            torch.tensor([12.0, 15.0, 18.0], device=self.device_type),
+            device_mesh,
+            [Shard(0)],
+        )
+        param_2.grad = distribute_tensor(
+            torch.tensor([20.0, 30.0, 15.0], device=self.device_type),
+            device_mesh,
+            [Shard(0)],
+        )
+        gradient_clipping_optimizer.step()
+
+        for param_group, ref_param_group in zip(
+            gradient_clipping_optimizer.param_groups,
+            ref_gradient_clipping_optimizer.param_groups,
+        ):
+            for param, ref_param in zip(
+                param_group["params"], ref_param_group["params"]
+            ):
+                param_grad = (
+                    param.grad.full_tensor()  # pyre-ignore[16]
+                    if isinstance(param, DTensor)
+                    else param.grad
+                )
+                self.assertEqual(
+                    param_grad,
+                    ref_param.grad,
+                    f"Expect gradient to be the same. However, found {param_grad=}, {ref_param.grad=}",
+                )
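
Both tests rely on the same property: per-shard DTensor gradients, combined across the process group via enable_global_grad_clip and param_to_pgs, must end up scaled exactly as a single-process run over the full tensors would scale them. As a minimal, hypothetical sketch (not part of the commit), and assuming torchrec's GradientClipping.NORM reduces to standard global-norm clipping, the reference gradients used above can be reproduced with plain torch.nn.utils.clip_grad_norm_:

import torch

# Reference gradients and max_gradient taken from the test above.
for norm_type in (float("inf"), 1.0, 2.0):
    p1 = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
    p2 = torch.nn.Parameter(torch.tensor([4.0, 5.0, 6.0]))
    p1.grad = torch.tensor([12.0, 15.0, 18.0])
    p2.grad = torch.tensor([20.0, 30.0, 15.0])
    # Global norm over both grads; grads are rescaled in place by
    # max_norm / total_norm whenever total_norm exceeds max_norm.
    total_norm = torch.nn.utils.clip_grad_norm_(
        [p1, p2], max_norm=10.0, norm_type=norm_type
    )
    print(norm_type, total_norm.item(), p1.grad.tolist(), p2.grad.tolist())

For example, with norm_type=inf the global norm is 30.0, so every gradient component is scaled by roughly 1/3. A typical local invocation of only the new test (not prescribed by the commit) would be: python -m pytest torchrec/optim/tests/test_clipping.py -k multiple_sharded_dtensors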

0 commit comments
