Commit a16b555

Simplify flux model dtype conversion in model loader
1 parent 6667c39 commit a16b555

File tree

2 files changed: +4 -29 lines changed

2 files changed

+4
-29
lines changed

invokeai/backend/model_manager/load/model_loaders/flux.py

Lines changed: 4 additions & 24 deletions
@@ -1,7 +1,6 @@
 # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
 """Class for Flux model loading in InvokeAI."""

-import gc
 from pathlib import Path
 from typing import Optional

@@ -35,7 +34,6 @@
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.model_manager.util.model_util import (
     convert_bundle_to_flux_transformer_checkpoint,
-    convert_sd_entry_to_bfloat16,
 )
 from invokeai.backend.util.silence_warnings import SilenceWarnings

@@ -197,30 +195,12 @@ def _load_from_singlefile(
         sd = load_file(model_path)
         if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
             sd = convert_bundle_to_flux_transformer_checkpoint(sd)
-        futures: list[torch.jit.Future[tuple[str, torch.Tensor]]] = []
-        cache_updated = False
+        new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
+        self._ram_cache.make_room(new_sd_size)
         for k in sd.keys():
-            v = sd[k]
-            if v.dtype != torch.bfloat16:
-                if not cache_updated:
-                    # For the first iteration we are just requesting the current size of the state dict
-                    # This is due to an expected doubling of the tensor sizes in memory after converting float8 -> float16
-                    # This should be refined in the future if not removed entirely when we support more data types
-                    sd_size = sum([ten.nelement() * ten.element_size() for ten in sd.values()])
-                    self._ram_cache.make_room(sd_size)
-                    cache_updated = True
-                futures.append(torch.jit.fork(convert_sd_entry_to_bfloat16, k, v))
-        # Clean up unused variables
-        del v
-        gc.collect()  # Force garbage collection to free memory
-        for future in futures:
-            k, v = torch.jit.wait(future)
-            sd[k] = v
-            del k, v
-        del futures
-        gc.collect()  # Force garbage collection to free memory
+            # We need to cast to bfloat16 due to it being the only currently supported dtype for inference
+            sd[k] = sd[k].to(torch.bfloat16)
         model.load_state_dict(sd, assign=True)
-
         return model
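For readers skimming the diff, here is a rough, self-contained sketch of the pattern the new code follows: estimate the post-cast bfloat16 footprint of the whole state dict, reserve that much room in the model RAM cache, then cast each tensor in place. The function name and the ram_cache parameter below are illustrative stand-ins (the loader uses self._ram_cache); only the size and cast logic mirror the change above.

import torch


def cast_state_dict_to_bfloat16(sd: dict[str, torch.Tensor], ram_cache) -> dict[str, torch.Tensor]:
    # Size of the state dict *after* conversion: every tensor will occupy
    # torch.bfloat16.itemsize (2) bytes per element, whatever its current dtype.
    new_sd_size = sum(t.nelement() * torch.bfloat16.itemsize for t in sd.values())

    # Make room in the RAM cache before the casts allocate the converted tensors.
    # ram_cache.make_room() stands in for self._ram_cache.make_room() in the loader.
    ram_cache.make_room(new_sd_size)

    # bfloat16 is currently the only dtype supported for FLUX inference, so cast
    # every entry; tensors already stored as bfloat16 come back unchanged.
    for k in sd.keys():
        sd[k] = sd[k].to(torch.bfloat16)
    return sd

Note that the estimate uses the target dtype's itemsize rather than each tensor's current element_size(), so the reserved room already accounts for float8 checkpoints doubling in size when widened to bfloat16, which the removed code handled with a separate comment and a one-shot cache update.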

invokeai/backend/model_manager/util/model_util.py

Lines changed: 0 additions & 5 deletions
@@ -159,8 +159,3 @@ def convert_bundle_to_flux_transformer_checkpoint(
         del transformer_state_dict[k]

     return original_state_dict
-
-
-@torch.jit.script
-def convert_sd_entry_to_bfloat16(key: str, tensor: torch.Tensor) -> tuple[str, torch.Tensor]:
-    return key, tensor.to(torch.bfloat16, copy=False)
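A side note on the removed helper (an illustration, not part of the commit): the unconditional sd[k].to(torch.bfloat16) in the loader is safe because Tensor.to() returns the input tensor itself when it already has the requested dtype, and copy defaults to False, the same value the helper passed explicitly. The old dtype check and the jit-scripted wrapper are therefore not needed for correctness.

import torch

t_bf16 = torch.zeros(4, dtype=torch.bfloat16)
assert t_bf16.to(torch.bfloat16) is t_bf16  # no copy when the dtype already matches

t_f32 = torch.zeros(4, dtype=torch.float32)
t_cast = t_f32.to(torch.bfloat16)  # other dtypes are actually converted
assert t_cast.dtype == torch.bfloat16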
