diff --git a/open_diloco/hivemind_diloco.py b/open_diloco/hivemind_diloco.py index 308608b..2dec08e 100644 --- a/open_diloco/hivemind_diloco.py +++ b/open_diloco/hivemind_diloco.py @@ -164,6 +164,8 @@ def compute_and_load_pseudo_grad_into_averager(self): # opt_param is the param that will be all_reduce, it is suppose to be on cpu # main_param is the param that has been updated by the inner optimizer, it is suppose to be on gpu grad = opt_param.data - main_param.detach().to(opt_param.device) + mask = torch.rand_like(grad) > 0.95 + grad *= mask averaged_grad.copy_(grad, non_blocking=True) def notify_used_averaged_gradients(self):