Commit 9bcde1b

Gavin Zhang authored and facebook-github-bot committed
fixed issue with sharded_grads with multiple process groups (pytorch#3268)
Summary:
Pull Request resolved: pytorch#3268

Added a hotfix to an issue in clipping where sharded_grads was not appropriately initialized in the case with multiple process groups.

Reviewed By: tsunghsienlee

Differential Revision: D79853515

fbshipit-source-id: 5ad1ba34898b541cac9b01a746c29daafb1ed44f
1 parent 9a91baf commit 9bcde1b

File tree

1 file changed: 3 additions, 1 deletion

torchrec/optim/clipping.py

Lines changed: 3 additions & 1 deletion
@@ -149,7 +149,9 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         sharded_grads = {
             pgs: _get_grads(dist_params) for pgs, dist_params in sharded_params.items()
         }
-        all_grads.extend(*sharded_grads.values())
+
+        for grads in sharded_grads.values():
+            all_grads.extend(grads)
 
         # Process replicated parameters and gradients
         replicate_grads = _get_grads(replicate_params)
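Why the one-line change matters: list.extend() accepts exactly one iterable. With a single process group, *sharded_grads.values() unpacks to a single gradient list and the old call happens to work; with multiple process groups it unpacks to several lists and raises a TypeError. The following is a minimal, self-contained sketch (not TorchRec code; the process-group keys and string "gradients" are placeholder values) showing the failure and the fixed pattern:

# Hypothetical per-process-group gradient lists, standing in for the
# output of _get_grads() in torchrec/optim/clipping.py.
sharded_grads = {
    "pg_0": ["grad_a", "grad_b"],
    "pg_1": ["grad_c"],
}

all_grads = []

try:
    # Old pattern: star-unpacking passes two lists to extend(),
    # which only takes one iterable, so this raises a TypeError
    # whenever there is more than one process group.
    all_grads.extend(*sharded_grads.values())
except TypeError as e:
    print(f"old pattern fails: {e}")

# Fixed pattern from this commit: extend once per process group.
for grads in sharded_grads.values():
    all_grads.extend(grads)

print(all_grads)  # ['grad_a', 'grad_b', 'grad_c']

Looping over the per-group gradient lists produces the same flattened all_grads as the single-group case while supporting any number of process groups.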
