41 changes: 41 additions & 0 deletions library/deepspeed_utils.py
@@ -5,6 +5,8 @@

from .utils import setup_logging

from .device_utils import get_preferred_device

setup_logging()
import logging

@@ -94,6 +96,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace):
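    # DeepSpeed expects the *global* batch size here: per-device batch size
    # times gradient accumulation steps times process count (e.g. 4 * 2 * 8 = 64).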
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
)

deepspeed_plugin.set_mixed_precision(args.mixed_precision)
if args.mixed_precision.lower() == "fp16":
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
@@ -122,18 +125,56 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
    class DeepSpeedWrapper(torch.nn.Module):
        def __init__(self, **kw_models) -> None:
            super().__init__()

            self.models = torch.nn.ModuleDict()

            wrap_model_forward_with_torch_autocast = args.mixed_precision != "no"

            for key, model in kw_models.items():
                if isinstance(model, list):
                    model = torch.nn.ModuleList(model)

                if wrap_model_forward_with_torch_autocast:
                    model = self.__wrap_model_with_torch_autocast(model)

                assert isinstance(
                    model, torch.nn.Module
                ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"

                self.models.update(torch.nn.ModuleDict({key: model}))

        def __wrap_model_with_torch_autocast(self, model):
            if isinstance(model, torch.nn.ModuleList):
                model = torch.nn.ModuleList([self.__wrap_model_forward_with_torch_autocast(m) for m in model])
            else:
                model = self.__wrap_model_forward_with_torch_autocast(model)
            return model

        def __wrap_model_forward_with_torch_autocast(self, model):
            assert hasattr(model, "forward"), "model must have a forward method."

            forward_fn = model.forward

            def forward(*args, **kwargs):
                try:
                    device_type = model.device.type
                except AttributeError:
                    logger.warning(
                        "[DeepSpeed] model.device is not available. Using get_preferred_device() "
                        "to determine the device_type for torch.autocast()."
                    )
                    device_type = get_preferred_device().type

                with torch.autocast(device_type=device_type):
                    return forward_fn(*args, **kwargs)

            # replace the bound forward with the autocast-wrapping closure
            model.forward = forward
            return model

        def get_models(self):
            return self.models

    ds_model = DeepSpeedWrapper(**models)
    return ds_model
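For context, a minimal sketch of how a training script might consume the wrapper; the model names and the surrounding accelerator objects are assumptions, not part of this diff:

# hypothetical caller
ds_model = prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
# accelerator.prepare hands the single wrapper module to DeepSpeed
ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
training_models = [ds_model]  # one engine stands in for all wrapped sub-models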
5 changes: 5 additions & 0 deletions library/train_util.py
@@ -5495,6 +5495,11 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio


def patch_accelerator_for_fp16_training(accelerator):
    from accelerate import DistributedType

    if accelerator.distributed_type == DistributedType.DEEPSPEED:
        # DeepSpeed manages fp16 loss scaling itself, so accelerate does not
        # create a GradScaler here and there is nothing to patch.
        return

    # wrap GradScaler._unscale_grads_ so it accepts fp16 gradients
    org_unscale_grads = accelerator.scaler._unscale_grads_

    def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,6 +1,7 @@
accelerate==0.33.0
transformers==4.44.0
diffusers[torch]==0.25.0
deepspeed==0.16.7
ftfy==6.1.1
# albumentations==1.3.0
opencv-python==4.8.1.78