diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 7e78a6b060fb..33da34af1516 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -24,7 +24,7 @@ from ..module_inject import LinearAllreduce, LinearLayer, Normalize, ReplaceWithTensorSlicing from deepspeed.accelerator import get_accelerator from ..module_inject.policy import TransformerPolicy -from ..module_inject.auto_tp import AutoTP +from ..module_inject.auto_tp import AutoTP, Loading from ..module_inject.replace_policy import generic_policies from ..module_inject.auto_tp_model_utils import build_bloom_alibi_tensor, build_mpt_atten_bias_tensor, build_mpt_alibi_tensor, get_alibi_mask @@ -363,7 +363,14 @@ def load_module_recursive(module, prefix='', level=0): child = Normalize(dim=child.weight.ds_shape[-1], dtype=child.weight.dtype, eps=child.eps) setattr(module, name, child) load(child, self.sd, prefix + name + '.') + # Load buffers for this module + if len(child._buffers) != 0: + Loading.load_buffer(child, self.sd, prefix + name + '.', self.mp_group) else: + checking_key = prefix + name + '.' 
+ # Load buffers for non-policy modules + if len(child._buffers) != 0: + Loading.load_buffer(child, self.sd, checking_key, self.mp_group) load_module_recursive(child, prefix if level == 0 else prefix + name + '.', level + 1) load_module_recursive(r_module) diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 82cd9042071e..749bc57bc623 100755 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -142,14 +142,22 @@ def is_load_module(module): ] return module.__class__ in load_layers or module._get_name() in load_layer_names - def load_buffer(module, state_dict, prefix): + def load_buffer(module, state_dict, prefix, mp_group=None): for name in module._buffers.keys(): if module._buffers[name].data.is_meta: module._buffers[name] = torch.nn.parameter.Parameter( data=torch.empty_like(module._buffers[name].data, device="cpu"), requires_grad=module._buffers[name].data.requires_grad) if prefix + name in state_dict.keys(): - module._buffers[name].data.copy_(state_dict[prefix + name]) + # Buffers are typically not sharded across devices, so we copy the full buffer + # to all devices. Ensure the buffer data is moved to the correct device. 
+ buffer_data = state_dict[prefix + name] + if not buffer_data.is_meta: + # Move buffer data to the same device as the module's buffer + target_device = module._buffers[name].data.device + if buffer_data.device != target_device: + buffer_data = buffer_data.to(target_device) + module._buffers[name].data.copy_(buffer_data) def load(module, state_dict, prefix, mp_group=None): mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group) @@ -461,7 +469,7 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''): else: continue if len(child._buffers) != 0 and self.state_dict is not None: - Loading.load_buffer(child, self.state_dict, checking_key) + Loading.load_buffer(child, self.state_dict, checking_key, self.mp_group) if child.__class__ in self.linear_policies: setattr(r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name, self.conv_linear_layer)) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 26752cfa4fec..aa5feccca480 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -709,7 +709,7 @@ def _replace_module(model, policies, prefix='', layer_id=0, level_id=0, state_di else: continue if len(child._buffers) != 0 and state_dict is not None: - Loading.load_buffer(child, state_dict, checking_key) + Loading.load_buffer(child, state_dict, checking_key, mp_group=None) _, layer_id = _replace_module(child, policies, prefix if level_id == 0 and skip_level_0_prefix(model, state_dict) else \