@@ -17,7 +17,7 @@
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from deepspeed.runtime.base_optimizer import ZeROOptimizer
 from deepspeed.utils import logger
-from deepspeed.utils.torch import register_grad_hook
+from deepspeed.utils.torch import register_grad_hook, required_torch_version
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce
@@ -1251,59 +1251,21 @@ def reduce_partition_and_remove_grads(*notneeded):
                     # Partition the parameter after creating the hook
                     param.partition()
 
-        # We delay reduce-scatter for all gradients in the leaf modules until the backward pass of the leaf module is done
+        # We delay reduce for all gradients in the leaf modules until the backward pass of the leaf module is done
         for leaf_module, leaf_parameters in self.leaf_parameters.items():
 
-            def wrapper_pre_hook(params):
+            def make_hook(params):
 
-                def forward_pre_hook(module, input):
-                    """Pre-forward hook to set backward hook on input tensors to the leaf module"""
-                    module._leaf_module_inputs_remaining = 0
+                def reduce_leaf_module_grads(module, grad_input, grad_output):
+                    for param in params:
+                        if param.grad is None:
+                            param.grad = torch.zeros_like(param)
+                        self.reduce_ready_partitions_and_remove_grads(param)
 
-                    @instrument_w_nvtx
-                    def reduce_leaf_module_grads(grad):
-                        module._leaf_module_inputs_remaining -= 1
-                        # Make sure everything is done in the leaf module
-                        if module._leaf_module_inputs_remaining == 0:
-                            for param in params:
-                                if param.grad is None:
-                                    param.grad = torch.zeros_like(param)
-                                self.reduce_ready_partitions_and_remove_grads(param)
+                return reduce_leaf_module_grads
 
-                    def set_module_bwd_hook(tensor):
-                        if tensor.requires_grad:
-                            module._leaf_module_inputs_remaining += 1
-                            tensor.register_hook(reduce_leaf_module_grads)
-                        return tensor
-
-                    output = apply_to_tensors_only(set_module_bwd_hook, input)
-
-                    return output
-
-                return forward_pre_hook
-
-            def wrapper_post_hook():
-
-                def forward_post_hook(module, input, output):
-                    """Pre-forward hook to set backward hook on input tensors to the leaf module"""
-                    module._leaf_output_required_grad_num = 0
-
-                    def increment_rg_count_bwd_hook(tensor):
-                        if tensor.requires_grad:
-                            module._leaf_output_required_grad_num += 1
-                        return tensor
-
-                    apply_to_tensors_only(increment_rg_count_bwd_hook, output)
-
-                    if module._leaf_module_inputs_remaining == 0 and module._leaf_output_required_grad_num > 0:
-                        raise RuntimeError(
-                            "A module cannot be set as a leaf module when it does not have any input tensors that require gradients and has output tensors that require gradients. This is because the gradient reduction hook will not be called in this case."
-                        )
-
-                return forward_post_hook
-
-            self._leaf_module_hooks.append(leaf_module.register_forward_pre_hook(wrapper_pre_hook(leaf_parameters)))
-            self._leaf_module_hooks.append(leaf_module.register_forward_hook(wrapper_post_hook()))
+            assert required_torch_version(min_version=1.8), "Leaf module requires PyTorch >= 1.8"
+            self._leaf_module_hooks.append(leaf_module.register_full_backward_hook(make_hook(leaf_parameters)))
 
         print_rank_0('[End] Create gradient reduction hooks')
 
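For readers unfamiliar with the mechanism the new code relies on, the following is a minimal, self-contained sketch (not DeepSpeed code) of gradient reduction driven by register_full_backward_hook, which exists since PyTorch 1.8 and is what the added required_torch_version assert guards. The toy LeafBlock module and the fake_reduce function are hypothetical stand-ins; in the actual optimizer the hook body calls self.reduce_ready_partitions_and_remove_grads(param), as shown in the diff above.

import torch
import torch.nn as nn


# Illustrative stand-ins (not DeepSpeed APIs): a tiny leaf module and a dummy
# "reduce" that only reports the gradient it would hand to the collective.
class LeafBlock(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


def fake_reduce(param):
    print(f"reducing grad of shape {tuple(param.grad.shape)}")


def make_hook(params):

    def reduce_leaf_module_grads(module, grad_input, grad_output):
        # grad_input/grad_output are ignored; the hook is only used as a
        # notification that the module's backward pass has produced its
        # input gradients.
        for param in params:
            if param.grad is None:
                # A parameter that received no gradient still has to take
                # part in the reduction, so give it a zero gradient.
                param.grad = torch.zeros_like(param)
            fake_reduce(param)

    return reduce_leaf_module_grads


leaf = LeafBlock()
handle = leaf.register_full_backward_hook(make_hook(list(leaf.parameters())))

out = leaf(torch.randn(3, 4, requires_grad=True))
out.sum().backward()  # reduce_leaf_module_grads fires during this backward pass
handle.remove()

Compared with the removed approach, which attached a per-tensor grad hook to every input of the leaf module in a forward pre-hook, counted how many of those hooks had fired, and used a forward hook to reject modules whose outputs require gradients while none of their inputs do, a single full backward hook needs none of that bookkeeping: PyTorch invokes it once, after the gradients with respect to the module's inputs have been computed.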