Commit 7c075a9

Merge pull request #2060 from saibit-tech/sd3
Fix: try aligning dtype of matrices when training with DeepSpeed and mixed_precision is set to bf16 or fp16
2 parents: 64430eb + 1684aba
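
The failure this commit addresses: under DeepSpeed with mixed_precision set to bf16 or fp16, some parameters can stay in fp32 while activations arrive in the lower precision, and PyTorch raises a dtype-mismatch RuntimeError on the matmul. torch.autocast resolves this by casting matmul operands to a common dtype on entry. A minimal CPU-only illustration of that behavior (not part of the commit; tensor names are made up):

import torch

w = torch.randn(4, 3)                        # fp32 "weight"
x = torch.randn(2, 4, dtype=torch.bfloat16)  # bf16 "activation"

try:
    x @ w  # mixed-dtype matmul outside autocast raises RuntimeError
except RuntimeError as e:
    print("dtype mismatch:", e)

# autocast casts both operands to its dtype before the matmul runs
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = x @ w
print(y.dtype)  # torch.bfloat16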

File tree

2 files changed: +46 -0 lines changed


library/deepspeed_utils.py

Lines changed: 41 additions & 0 deletions
@@ -5,6 +5,8 @@
 
 from .utils import setup_logging
 
+from .device_utils import get_preferred_device
+
 setup_logging()
 import logging
 
@@ -94,6 +96,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace):
     deepspeed_plugin.deepspeed_config["train_batch_size"] = (
         args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
     )
+
     deepspeed_plugin.set_mixed_precision(args.mixed_precision)
     if args.mixed_precision.lower() == "fp16":
         deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0  # preventing overflow.
@@ -122,18 +125,56 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
     class DeepSpeedWrapper(torch.nn.Module):
         def __init__(self, **kw_models) -> None:
             super().__init__()
+
             self.models = torch.nn.ModuleDict()
+
+            wrap_model_forward_with_torch_autocast = args.mixed_precision != "no"
 
             for key, model in kw_models.items():
                 if isinstance(model, list):
                     model = torch.nn.ModuleList(model)
+
+                if wrap_model_forward_with_torch_autocast:
+                    model = self.__wrap_model_with_torch_autocast(model)
+
                 assert isinstance(
                     model, torch.nn.Module
                 ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
+
                 self.models.update(torch.nn.ModuleDict({key: model}))
 
+        def __wrap_model_with_torch_autocast(self, model):
+            if isinstance(model, torch.nn.ModuleList):
+                model = torch.nn.ModuleList([self.__wrap_model_forward_with_torch_autocast(m) for m in model])
+            else:
+                model = self.__wrap_model_forward_with_torch_autocast(model)
+            return model
+
+        def __wrap_model_forward_with_torch_autocast(self, model):
+
+            assert hasattr(model, "forward"), "model must have a forward method."
+
+            forward_fn = model.forward
+
+            def forward(*args, **kwargs):
+                try:
+                    device_type = model.device.type
+                except AttributeError:
+                    logger.warning(
+                        "[DeepSpeed] model.device is not available. Using get_preferred_device() "
+                        "to determine the device_type for torch.autocast()."
+                    )
+                    device_type = get_preferred_device().type
+
+                with torch.autocast(device_type=device_type):
+                    return forward_fn(*args, **kwargs)
+
+            model.forward = forward
+            return model
+
         def get_models(self):
             return self.models
+
 
     ds_model = DeepSpeedWrapper(**models)
     return ds_model
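
The core trick in this diff is monkey-patching each wrapped model's bound forward so that every call runs inside torch.autocast, with the device type taken from model.device when available and from get_preferred_device() otherwise. A standalone sketch of the same forward-wrapping idea, reduced to its essentials (the helper name and CPU device here are illustrative, not from the commit):

import torch

def wrap_forward_with_autocast(model: torch.nn.Module, device_type: str = "cpu") -> torch.nn.Module:
    forward_fn = model.forward  # keep the original bound method

    def forward(*args, **kwargs):
        # every call now enters an autocast region first
        with torch.autocast(device_type=device_type):
            return forward_fn(*args, **kwargs)

    model.forward = forward
    return model

linear = wrap_forward_with_autocast(torch.nn.Linear(4, 3))  # fp32 weights
out = linear(torch.randn(2, 4, dtype=torch.bfloat16))       # bf16 input
print(out.dtype)  # torch.bfloat16: operands were aligned by autocast

One design note: patching the instance attribute model.forward leaves the class untouched, so only the models handed to DeepSpeedWrapper pay the autocast overhead.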

library/train_util.py

Lines changed: 5 additions & 0 deletions
@@ -5498,6 +5498,11 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
 
 
 def patch_accelerator_for_fp16_training(accelerator):
+
+    from accelerate import DistributedType
+    if accelerator.distributed_type == DistributedType.DEEPSPEED:
+        return
+
     org_unscale_grads = accelerator.scaler._unscale_grads_
 
     def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
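
Why the early return: under DeepSpeed, fp16 loss scaling is handled by the DeepSpeed engine itself, so accelerator.scaler is typically None there and the original patch of _unscale_grads_ would fail with an AttributeError. A sketch of the same guard pattern in isolation (the function name is hypothetical):

from accelerate import Accelerator, DistributedType

def maybe_patch_scaler(accelerator: Accelerator) -> None:
    # DeepSpeed manages fp16 loss scaling internally; nothing to patch here.
    if accelerator.distributed_type == DistributedType.DEEPSPEED:
        return
    # ... otherwise replace accelerator.scaler._unscale_grads_, as train_util.py does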
