@@ -1,7 +1,6 @@
1 | 1 | # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team |
2 | 2 | """Class for Flux model loading in InvokeAI.""" |
3 | 3 |
4 | | -import gc |
5 | 4 | from pathlib import Path |
6 | 5 | from typing import Optional |
7 | 6 |
@@ -35,7 +34,6 @@
35 | 34 | from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry |
36 | 35 | from invokeai.backend.model_manager.util.model_util import ( |
37 | 36 | convert_bundle_to_flux_transformer_checkpoint, |
38 | | - convert_sd_entry_to_bfloat16, |
39 | 37 | ) |
40 | 38 | from invokeai.backend.util.silence_warnings import SilenceWarnings |
41 | 39 |
@@ -197,30 +195,12 @@ def _load_from_singlefile( |
197 | 195 | sd = load_file(model_path) |
198 | 196 | if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd: |
199 | 197 | sd = convert_bundle_to_flux_transformer_checkpoint(sd) |
200 | | - futures: list[torch.jit.Future[tuple[str, torch.Tensor]]] = [] |
201 | | - cache_updated = False |
| 198 | + new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()]) |
| 199 | + self._ram_cache.make_room(new_sd_size) |
202 | 200 | for k in sd.keys(): |
203 | | - v = sd[k] |
204 | | - if v.dtype != torch.bfloat16: |
205 | | - if not cache_updated: |
206 | | - # For the first iteration we are just requesting the current size of the state dict |
207 | | - # This is due to an expected doubling of the tensor sizes in memory after converting float8 -> float16 |
208 | | - # This should be refined in the future if not removed entirely when we support more data types |
209 | | - sd_size = sum([ten.nelement() * ten.element_size() for ten in sd.values()]) |
210 | | - self._ram_cache.make_room(sd_size) |
211 | | - cache_updated = True |
212 | | - futures.append(torch.jit.fork(convert_sd_entry_to_bfloat16, k, v)) |
213 | | - # Clean up unused variables |
214 | | - del v |
215 | | - gc.collect() # Force garbage collection to free memory |
216 | | - for future in futures: |
217 | | - k, v = torch.jit.wait(future) |
218 | | - sd[k] = v |
219 | | - del k, v |
220 | | - del futures |
221 | | - gc.collect() # Force garbage collection to free memory |
 | 201 | + # Cast to bfloat16: currently the only dtype supported for inference
| 202 | + sd[k] = sd[k].to(torch.bfloat16) |
222 | 203 | model.load_state_dict(sd, assign=True) |
223 | | - |
224 | 204 | return model |
225 | 205 |
226 | 206 |
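For reference, the net effect of the hunk above, rewritten as a standalone sketch. It assumes a RAM cache object exposing `make_room(num_bytes)`, as InvokeAI's model cache does; the helper name `load_bf16_state_dict` and its signature are illustrative, not part of the actual loader:

```python
import torch
from safetensors.torch import load_file


def load_bf16_state_dict(model_path, ram_cache):
    """Hypothetical helper mirroring the simplified loading path above."""
    sd = load_file(model_path)
    # Reserve room for the *post-cast* footprint: after conversion every
    # tensor occupies nelement() * torch.bfloat16.itemsize (2 bytes per
    # element), regardless of its on-disk dtype (e.g. a float8 checkpoint).
    new_sd_size = sum(t.nelement() * torch.bfloat16.itemsize for t in sd.values())
    ram_cache.make_room(new_sd_size)
    for k in sd.keys():
        # Serial in-place cast: each converted tensor immediately replaces
        # its source entry, so at most one extra tensor is alive at a time.
        sd[k] = sd[k].to(torch.bfloat16)
    return sd
```

Compared with the removed code, this drops the per-tensor `torch.jit.fork` futures, which kept every source tensor and its bfloat16 copy alive until all futures were awaited (hence the manual `gc.collect()` calls), and it sizes the single `make_room` reservation from the target dtype's `itemsize` rather than from the pre-cast state dict under the old float8-doubling assumption.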