
Commit 4f4e752

Merge branch 'master' into zenflow_zero3
2 parents: 47b10d8 + 6ea345a

2 files changed: 12 additions, 50 deletions

deepspeed/runtime/engine.py

Lines changed: 1 addition & 1 deletion
@@ -3116,7 +3116,7 @@ def load_checkpoint(self,
                                                           load_module_only=load_module_only,
                                                           custom_load_fn=custom_load_fn)
 
-        load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled())
+        load_zero_checkpoint = load_path is not None and self.zero_optimization()
         if load_zero_checkpoint:
             if (load_optimizer_states and not load_module_only) or self.load_universal_checkpoint():
                 success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
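
The effect of this one-line gating change, shown as a minimal standalone sketch (not DeepSpeed code; zero_enabled / bf16_enabled are illustrative stand-ins for engine.zero_optimization() / engine.bfloat16_enabled()): a bf16 run that does not use ZeRO no longer attempts to load a ZeRO checkpoint, while ZeRO runs behave as before.

# Minimal sketch, not DeepSpeed code.
def should_load_zero_checkpoint(load_path, zero_enabled, bf16_enabled):
    old = load_path is not None and (zero_enabled or bf16_enabled)  # before this commit
    new = load_path is not None and zero_enabled                    # after this commit
    return old, new

# A bf16 run without ZeRO used to attempt the ZeRO checkpoint load; now it does not.
assert should_load_zero_checkpoint("/tmp/ckpt", zero_enabled=False, bf16_enabled=True) == (True, False)
# ZeRO runs are unaffected.
assert should_load_zero_checkpoint("/tmp/ckpt", zero_enabled=True, bf16_enabled=False) == (True, True)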

deepspeed/runtime/zero/stage3.py

Lines changed: 11 additions & 49 deletions
@@ -17,7 +17,7 @@
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from deepspeed.runtime.base_optimizer import ZeROOptimizer
 from deepspeed.utils import logger
-from deepspeed.utils.torch import register_grad_hook
+from deepspeed.utils.torch import register_grad_hook, required_torch_version
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce
@@ -1251,59 +1251,21 @@ def reduce_partition_and_remove_grads(*notneeded):
             # Partition the parameter after creating the hook
             param.partition()
 
-        # We delay reduce-scatter for all gradients in the leaf modules until the backward pass of the leaf module is done
+        # We delay reduce for all gradients in the leaf modules until the backward pass of the leaf module is done
         for leaf_module, leaf_parameters in self.leaf_parameters.items():
 
-            def wrapper_pre_hook(params):
+            def make_hook(params):
 
-                def forward_pre_hook(module, input):
-                    """Pre-forward hook to set backward hook on input tensors to the leaf module"""
-                    module._leaf_module_inputs_remaining = 0
+                def reduce_leaf_module_grads(module, grad_input, grad_output):
+                    for param in params:
+                        if param.grad is None:
+                            param.grad = torch.zeros_like(param)
+                        self.reduce_ready_partitions_and_remove_grads(param)
 
-                    @instrument_w_nvtx
-                    def reduce_leaf_module_grads(grad):
-                        module._leaf_module_inputs_remaining -= 1
-                        # Make sure everything is done in the leaf module
-                        if module._leaf_module_inputs_remaining == 0:
-                            for param in params:
-                                if param.grad is None:
-                                    param.grad = torch.zeros_like(param)
-                                self.reduce_ready_partitions_and_remove_grads(param)
+                return reduce_leaf_module_grads
 
-                    def set_module_bwd_hook(tensor):
-                        if tensor.requires_grad:
-                            module._leaf_module_inputs_remaining += 1
-                            tensor.register_hook(reduce_leaf_module_grads)
-                        return tensor
-
-                    output = apply_to_tensors_only(set_module_bwd_hook, input)
-
-                    return output
-
-                return forward_pre_hook
-
-            def wrapper_post_hook():
-
-                def forward_post_hook(module, input, output):
-                    """Pre-forward hook to set backward hook on input tensors to the leaf module"""
-                    module._leaf_output_required_grad_num = 0
-
-                    def increment_rg_count_bwd_hook(tensor):
-                        if tensor.requires_grad:
-                            module._leaf_output_required_grad_num += 1
-                        return tensor
-
-                    apply_to_tensors_only(increment_rg_count_bwd_hook, output)
-
-                    if module._leaf_module_inputs_remaining == 0 and module._leaf_output_required_grad_num > 0:
-                        raise RuntimeError(
-                            "A module cannot be set as a leaf module when it does not have any input tensors that require gradients and has output tensors that require gradients. This is because the gradient reduction hook will not be called in this case."
-                        )
-
-                return forward_post_hook
-
-            self._leaf_module_hooks.append(leaf_module.register_forward_pre_hook(wrapper_pre_hook(leaf_parameters)))
-            self._leaf_module_hooks.append(leaf_module.register_forward_hook(wrapper_post_hook()))
+            assert required_torch_version(min_version=1.8), "Leaf module requires PyTorch >= 1.8"
+            self._leaf_module_hooks.append(leaf_module.register_full_backward_hook(make_hook(leaf_parameters)))
 
         print_rank_0('[End] Create gradient reduction hooks')
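
The replacement relies on PyTorch's module-level full backward hook instead of counting per-input tensor hooks. A minimal, standalone sketch (not DeepSpeed code) of that mechanism, with an arbitrary stand-in leaf module: register_full_backward_hook fires once per backward pass of the module, after the gradients with respect to its inputs have been computed, which is where stage3.py now hands every leaf-module parameter to reduce_ready_partitions_and_remove_grads (zero-filling any missing grads).

import torch
import torch.nn as nn

# Minimal sketch, not DeepSpeed code. The leaf module and the callback body are
# illustrative; in stage3.py the callback calls
# self.reduce_ready_partitions_and_remove_grads(param) for each leaf parameter.
leaf = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

def on_leaf_backward(module, grad_input, grad_output):
    # Runs after the backward pass through `leaf` has computed the gradients
    # with respect to its inputs.
    print("leaf backward finished; parameters to reduce:",
          sum(1 for _ in module.parameters()))

handle = leaf.register_full_backward_hook(on_leaf_backward)  # requires PyTorch >= 1.8

x = torch.randn(3, 4, requires_grad=True)
leaf(x).sum().backward()   # the hook fires exactly once for this backward pass
handle.remove()

Compared with the removed implementation, this drops the per-module `_leaf_module_inputs_remaining` counter and the forward post-hook that guarded against leaf modules whose inputs require no gradients.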

0 commit comments