Commit aeeeb5b

fix cast fp8 (#140)
1 parent 5b06c6f commit aeeeb5b

diffsynth_engine/pipelines/base.py

Lines changed: 11 additions & 12 deletions
@@ -229,7 +229,7 @@ def eval(self):
                 model.eval()
         return self
 
-    def enable_cpu_offload(self, offload_mode: str | None, offload_to_disk:bool = False):
+    def enable_cpu_offload(self, offload_mode: str | None, offload_to_disk: bool = False):
         valid_offload_mode = ("cpu_offload", "sequential_cpu_offload", "disable", None)
         if offload_mode not in valid_offload_mode:
             raise ValueError(f"offload_mode must be one of {valid_offload_mode}, but got {offload_mode}")
@@ -244,8 +244,7 @@ def enable_cpu_offload(self, offload_mode: str | None, offload_to_disk:bool = Fa
             self._enable_sequential_cpu_offload()
         self.offload_to_disk = offload_to_disk
 
-
-    def _enable_model_cpu_offload(self):
+    def _enable_model_cpu_offload(self):
         for model_name in self.model_names:
             model = getattr(self, model_name)
             if model is not None:
@@ -258,23 +257,22 @@ def _enable_sequential_cpu_offload(self):
             if model is not None:
                 enable_sequential_cpu_offload(model, self.device)
         self.offload_mode = "sequential_cpu_offload"
-
+
     def _disable_offload(self):
-        self.offload_mode = None
-        self._offload_param_dict = {}
+        self.offload_mode = None
+        self._offload_param_dict = {}
         for model_name in self.model_names:
             model = getattr(self, model_name)
             if model is not None:
                 model.to(self.device)
 
-
     def enable_fp8_autocast(
         self, model_names: List[str], compute_dtype: torch.dtype = torch.bfloat16, use_fp8_linear: bool = False
     ):
         for model_name in model_names:
             model = getattr(self, model_name)
             if model is not None:
-                model.to(device=self.device, dtype=torch.float8_e4m3fn)
+                model.to(dtype=torch.float8_e4m3fn)
                 enable_fp8_autocast(model, compute_dtype, use_fp8_linear)
         self.fp8_autocast_enabled = True
 
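The substantive fix is the last change in the hunk above: enable_fp8_autocast previously called model.to(device=self.device, dtype=torch.float8_e4m3fn), so casting weights to fp8 also moved the whole model onto the accelerator, silently undoing any cpu_offload or sequential_cpu_offload configured earlier. Dropping the device argument casts the dtype in place and leaves placement to the offload logic. A minimal sketch of the difference, using a stand-in nn.Module rather than a real pipeline model, and assuming a PyTorch build with float8_e4m3fn support:

    import torch
    import torch.nn as nn

    # Stand-in for a pipeline model that an offload mode has parked on the CPU.
    model = nn.Linear(8, 8).to("cpu")

    # Old behavior: one call both casts and moves, dragging the weights to CUDA
    # even when an offload mode expects them to stay on the CPU:
    #     model.to(device="cuda", dtype=torch.float8_e4m3fn)

    # Fixed behavior: cast the storage dtype only; device placement is untouched.
    model.to(dtype=torch.float8_e4m3fn)
    print(next(model.parameters()).device)  # cpu
    print(next(model.parameters()).dtype)   # torch.float8_e4m3fn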

@@ -298,15 +296,17 @@ def load_models_to_device(self, load_model_names: List[str] | None = None):
         for model_name in load_model_names:
             model = getattr(self, model_name)
             if model is None:
-                raise ValueError(f"model {model_name} is not loaded, maybe this model has been destroyed by model_lifecycle_finish function with offload_to_disk=True")
+                raise ValueError(
+                    f"model {model_name} is not loaded, maybe this model has been destroyed by model_lifecycle_finish function with offload_to_disk=True"
+                )
             if model is not None and (p := next(model.parameters(), None)) is not None and p.device.type != self.device:
                 model.to(self.device)
         # fresh the cuda cache
         empty_cache()
 
     def model_lifecycle_finish(self, model_names: List[str] | None = None):
         if not self.offload_to_disk or self.offload_mode is None:
-            return
+            return
         for model_name in model_names:
             model = getattr(self, model_name)
             del model
@@ -316,7 +316,6 @@ def model_lifecycle_finish(self, model_names: List[str] | None = None):
             print(f"model {model_name} has been deleted from memory")
             logger.info(f"model {model_name} has been deleted from memory")
         empty_cache()
-
-
+
     def compile(self):
         raise NotImplementedError(f"{self.__class__.__name__} does not support compile")
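With the fix, CPU offload and fp8 casting compose in the intended order. A hedged usage sketch; the pipeline class, checkpoint path, and model name below are hypothetical placeholders, not part of this diff:

    import torch

    # Hypothetical pipeline built on the base class changed above.
    pipe = MyPipeline.from_pretrained("path/to/checkpoint")
    pipe.enable_cpu_offload(offload_mode="cpu_offload", offload_to_disk=False)
    # After the fix, the fp8 cast no longer pulls every model onto the GPU;
    # weights stay where the offload mode put them until load_models_to_device
    # moves them for the current step.
    pipe.enable_fp8_autocast(model_names=["dit"], compute_dtype=torch.bfloat16, use_fp8_linear=False)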
