|
5 | 5 |
|
6 | 6 | from .utils import setup_logging |
7 | 7 |
|
| 8 | +from .device_utils import get_preferred_device |
| 9 | + |
8 | 10 | setup_logging() |
9 | 11 | import logging |
10 | 12 |
|
@@ -94,6 +96,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace): |
94 | 96 | deepspeed_plugin.deepspeed_config["train_batch_size"] = ( |
95 | 97 | args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"]) |
96 | 98 | ) |
| 99 | + |
97 | 100 | deepspeed_plugin.set_mixed_precision(args.mixed_precision) |
98 | 101 | if args.mixed_precision.lower() == "fp16": |
99 | 102 | deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow. |
@@ -122,18 +125,56 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models): |
class DeepSpeedWrapper(torch.nn.Module):
    """Bundle several models into one top-level module for DeepSpeed.

    DeepSpeed engines expect a single ``torch.nn.Module``; this wrapper
    stores every model passed as a keyword argument in a ``ModuleDict``.
    When mixed precision is enabled (``args.mixed_precision != "no"`` —
    ``args`` is closed over from the enclosing ``prepare_deepspeed_model``),
    each model's ``forward`` is wrapped in a ``torch.autocast`` context so
    autocast still applies when DeepSpeed calls the submodules directly.

    Args (to ``__init__``):
        **kw_models: name -> model, where a model is a ``torch.nn.Module``
            or a list of modules (converted to ``torch.nn.ModuleList``).

    Raises:
        AssertionError: if a value is not (convertible to) an nn.Module.
    """

    def __init__(self, **kw_models) -> None:
        super().__init__()

        self.models = torch.nn.ModuleDict()

        # BUG FIX: the original used `args.mixed_precision is not "no"`,
        # which tests object identity rather than string equality — a
        # SyntaxWarning on CPython >= 3.8 and correct only by accident
        # (small-string interning). Use `!=` for value comparison.
        wrap_model_forward_with_torch_autocast = args.mixed_precision != "no"

        for key, model in kw_models.items():
            if isinstance(model, list):
                model = torch.nn.ModuleList(model)

            if wrap_model_forward_with_torch_autocast:
                model = self.__wrap_model_with_torch_autocast(model)

            assert isinstance(
                model, torch.nn.Module
            ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"

            self.models.update(torch.nn.ModuleDict({key: model}))

    def __wrap_model_with_torch_autocast(self, model):
        # Apply the forward wrapper to every submodule of a ModuleList,
        # or directly to a single module.
        if isinstance(model, torch.nn.ModuleList):
            return torch.nn.ModuleList(
                [self.__wrap_model_forward_with_torch_autocast(m) for m in model]
            )
        return self.__wrap_model_forward_with_torch_autocast(model)

    def __wrap_model_forward_with_torch_autocast(self, model):
        # Monkey-patch `model.forward` so every call runs inside
        # `torch.autocast` for the model's device type.
        assert hasattr(model, "forward"), "model must have a forward method."

        forward_fn = model.forward

        # NOTE: parameters renamed from the original `*args, **kwargs` —
        # `args` shadowed the closed-over argparse namespace, which would
        # silently break any future use of `args` inside this closure.
        def forward(*fwd_args, **fwd_kwargs):
            try:
                device_type = model.device.type
            except AttributeError:
                # Not every nn.Module exposes `.device`; fall back to the
                # process-preferred device to pick the autocast backend.
                logger.warning(
                    "[DeepSpeed] model.device is not available. Using get_preferred_device() "
                    "to determine the device_type for torch.autocast()."
                )
                device_type = get_preferred_device().type

            with torch.autocast(device_type=device_type):
                return forward_fn(*fwd_args, **fwd_kwargs)

        model.forward = forward
        return model

    def get_models(self):
        # Accessor used by callers to retrieve the wrapped ModuleDict.
        return self.models
| 177 | + |
137 | 178 |
|
138 | 179 | ds_model = DeepSpeedWrapper(**models) |
139 | 180 | return ds_model |
0 commit comments