
Commit f5d44fd

Merge pull request #2200 from kohya-ss/feat-faster-safetensors-load
feat: Speeding up loading .safetensors files
2 parents: 419a9c4 + 4568631

18 files changed: +463 −237 lines

README.md

Lines changed: 5 additions & 3 deletions
@@ -13,11 +13,13 @@ For RTX 50 series GPUs, PyTorch 2.8.0 with CUDA 12.8/9 should be used. `requirem
 
 If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed` (appropriate version is not confirmed yet).
 
-- [FLUX.1 training](#flux1-training)
-- [SD3 training](#sd3-training)
-
 ### Recent Updates
 
+Sep 13, 2025:
+- The loading speed of `.safetensors` files has been improved for SD3, FLUX.1 and Lumina. See [PR #2200](https://github.com/kohya-ss/sd-scripts/pull/2200) for more details.
+  - Model loading can be up to 1.5 times faster.
+  - This is a wide-ranging update, so there may be bugs. Please let us know if you encounter any issues.
+
 Sep 4, 2025:
 - The information about FLUX.1 and SD3/SD3.5 training that was described in the README has been organized and divided into the following documents:
   - [LoRA Training Overview](./docs/train_network.md)
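To put the "up to 1.5 times faster" figure in context, model loading can be timed with the loaders touched by this commit. The sketch below is illustrative only and not part of the repository; the checkpoint path and the model_type value are placeholders, and the actual speedup depends on storage speed and available RAM.

# Rough timing sketch (not part of the repository). Path and model_type are placeholders.
import time
import torch
from library import flux_utils

ckpt = "/path/to/flux1-dev.safetensors"  # placeholder path

start = time.perf_counter()
_, model = flux_utils.load_flow_model(
    ckpt, None, torch.device("cpu"), disable_mmap=True, model_type="flux"
)
print(f"load_flow_model with disable_mmap=True took {time.perf_counter() - start:.1f}s")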

flux_minimal_inference.py

Lines changed: 5 additions & 3 deletions
@@ -456,13 +456,13 @@ def is_fp8(dt):
     # load clip_l (skip for chroma model)
     if args.model_type == "flux":
         logger.info(f"Loading clip_l from {args.clip_l}...")
-        clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
+        clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device, disable_mmap=True)
         clip_l.eval()
     else:
         clip_l = None
 
     logger.info(f"Loading t5xxl from {args.t5xxl}...")
-    t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
+    t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device, disable_mmap=True)
     t5xxl.eval()
 
     # if is_fp8(clip_l_dtype):
@@ -471,7 +471,9 @@ def is_fp8(dt):
     # t5xxl = accelerator.prepare(t5xxl)
 
     # DiT
-    is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device, model_type=args.model_type)
+    is_schnell, model = flux_utils.load_flow_model(
+        args.ckpt_path, None, loading_device, disable_mmap=True, model_type=args.model_type
+    )
     model.eval()
     logger.info(f"Casting model to {flux_dtype}")
     model.to(flux_dtype)  # make sure model is dtype
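In the hunks above, `disable_mmap=True` is passed through to the loaders, which forward it to `load_safetensors` in the new `library/safetensors_utils.py` module (its implementation is not shown in this commit view). As a hedged sketch of what a disable_mmap code path generally looks like, assuming only the standard `safetensors` Python API, one might write:

# Illustrative sketch only -- not the repository's actual load_safetensors implementation.
import torch
from safetensors.torch import load, load_file


def load_safetensors_sketch(path: str, device, dtype=None, disable_mmap: bool = False):
    if disable_mmap:
        # Read the whole file into RAM in one sequential pass, then deserialize.
        # This avoids per-tensor page faults at the cost of peak memory.
        with open(path, "rb") as f:
            state_dict = load(f.read())  # returns CPU tensors
        for key, value in state_dict.items():
            state_dict[key] = value.to(device, dtype=dtype)
        return state_dict
    # Default path: memory-mapped load directly onto the target device.
    state_dict = load_file(path, device=str(device))
    if dtype is not None:
        for key, value in state_dict.items():
            state_dict[key] = value.to(dtype)
    return state_dict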

library/custom_offloading_utils.py

Lines changed: 28 additions & 9 deletions
@@ -1,13 +1,28 @@
 from concurrent.futures import ThreadPoolExecutor
+import gc
 import time
 from typing import Optional, Union, Callable, Tuple
 import torch
 import torch.nn as nn
 
-from library.device_utils import clean_memory_on_device
 
+# Keep these functions here for portability, and private to avoid confusion with the ones in device_utils.py
+def _clean_memory_on_device(device: torch.device):
+    r"""
+    Clean memory on the specified device, will be called from training scripts.
+    """
+    gc.collect()
+
+    # device may "cuda" or "cuda:0", so we need to check the type of device
+    if device.type == "cuda":
+        torch.cuda.empty_cache()
+    if device.type == "xpu":
+        torch.xpu.empty_cache()
+    if device.type == "mps":
+        torch.mps.empty_cache()
 
-def synchronize_device(device: torch.device):
+
+def _synchronize_device(device: torch.device):
     if device.type == "cuda":
         torch.cuda.synchronize()
     elif device.type == "xpu":
@@ -71,19 +86,18 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
         if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
             weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
 
-
     # device to cpu
     for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
         module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
 
-    synchronize_device(device)
+    _synchronize_device(device)
 
     # cpu to device
     for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
         cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
         module_to_cuda.weight.data = cuda_data_view
 
-    synchronize_device(device)
+    _synchronize_device(device)
 
 
 def weighs_to_device(layer: nn.Module, device: torch.device):
@@ -152,12 +166,15 @@ def _wait_blocks_move(self, block_idx):
 # Gradient tensors
 _grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor]
 
+
 class ModelOffloader(Offloader):
     """
     supports forward offloading
     """
 
-    def __init__(self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False):
+    def __init__(
+        self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False
+    ):
         super().__init__(len(blocks), blocks_to_swap, device, debug)
 
         # register backward hooks
@@ -172,7 +189,9 @@ def __del__(self):
         for handle in self.remove_handles:
             handle.remove()
 
-    def create_backward_hook(self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
+    def create_backward_hook(
+        self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int
+    ) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
         # -1 for 0-based index
         num_blocks_propagated = self.num_blocks - block_index - 1
         swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
@@ -213,8 +232,8 @@ def prepare_block_devices_before_forward(self, blocks: Union[list[nn.Module], nn
             b.to(self.device)  # move block to device first
             weighs_to_device(b, torch.device("cpu"))  # make sure weights are on cpu
 
-        synchronize_device(self.device)
-        clean_memory_on_device(self.device)
+        _synchronize_device(self.device)
+        _clean_memory_on_device(self.device)
 
     def wait_for_block(self, block_idx: int):
         if self.blocks_to_swap is None or self.blocks_to_swap == 0:
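The changes above mainly give this module private copies of the device helpers (`_clean_memory_on_device`, `_synchronize_device`) so it no longer depends on `library.device_utils`, and reflow a few long signatures. For orientation, a hedged usage sketch of the offloader API visible in this diff follows; the stand-in blocks, the sizes, and the omitted scheduling of finished blocks back to CPU are illustrative assumptions, not the repository's training loop.

# Hedged usage sketch; stand-in blocks and values, not the actual training loop.
import torch
import torch.nn as nn
from library.custom_offloading_utils import ModelOffloader

device = torch.device("cuda")
blocks = nn.ModuleList([nn.Linear(1024, 1024) for _ in range(8)])  # stand-in transformer blocks

offloader = ModelOffloader(blocks, blocks_to_swap=4, device=device, debug=False)
offloader.prepare_block_devices_before_forward(blocks)  # place weights before the forward pass

x = torch.randn(1, 1024, device=device)
for idx, block in enumerate(blocks):
    offloader.wait_for_block(idx)  # wait until this block's weights are on the device
    x = block(x)
    # (the real loop also submits the finished block to be swapped back to CPU; omitted here)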

library/device_utils.py

Lines changed: 20 additions & 2 deletions
@@ -1,7 +1,9 @@
 import functools
 import gc
+from typing import Optional, Union
 
 import torch
+
 try:
     # intel gpu support for pytorch older than 2.5
     # ipex is not needed after pytorch 2.5
@@ -36,12 +38,15 @@ def clean_memory():
         torch.mps.empty_cache()
 
 
-def clean_memory_on_device(device: torch.device):
+def clean_memory_on_device(device: Optional[Union[str, torch.device]]):
     r"""
     Clean memory on the specified device, will be called from training scripts.
     """
     gc.collect()
-
+    if device is None:
+        return
+    if isinstance(device, str):
+        device = torch.device(device)
     # device may "cuda" or "cuda:0", so we need to check the type of device
     if device.type == "cuda":
         torch.cuda.empty_cache()
@@ -51,6 +56,19 @@ def clean_memory_on_device(device: torch.device):
         torch.mps.empty_cache()
 
 
+def synchronize_device(device: Optional[Union[str, torch.device]]):
+    if device is None:
+        return
+    if isinstance(device, str):
+        device = torch.device(device)
+    if device.type == "cuda":
+        torch.cuda.synchronize()
+    elif device.type == "xpu":
+        torch.xpu.synchronize()
+    elif device.type == "mps":
+        torch.mps.synchronize()
+
+
 @functools.lru_cache(maxsize=None)
 def get_preferred_device() -> torch.device:
     r"""

library/flux_train_utils.py

Lines changed: 2 additions & 1 deletion
@@ -16,10 +16,11 @@
 
 from library import flux_models, flux_utils, strategy_base, train_util
 from library.device_utils import init_ipex, clean_memory_on_device
+from library.safetensors_utils import mem_eff_save_file
 
 init_ipex()
 
-from .utils import setup_logging, mem_eff_save_file
+from .utils import setup_logging
 
 setup_logging()
 import logging

library/flux_utils.py

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@
 logger = logging.getLogger(__name__)
 
 from library import flux_models
-from library.utils import load_safetensors
+from library.safetensors_utils import load_safetensors
 
 MODEL_VERSION_FLUX_V1 = "flux1"
 MODEL_NAME_DEV = "dev"
@@ -124,7 +124,7 @@ def load_flow_model(
     logger.info(f"Loading state dict from {ckpt_path}")
     sd = {}
     for ckpt_path in ckpt_paths:
-        sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))
+        sd.update(load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype))
 
     # convert Diffusers to BFL
     if is_diffusers:

library/lumina_train_util.py

Lines changed: 2 additions & 1 deletion
@@ -18,10 +18,11 @@
 from library.flux_models import AutoEncoder
 from library.device_utils import init_ipex, clean_memory_on_device
 from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
+from library.safetensors_utils import mem_eff_save_file
 
 init_ipex()
 
-from .utils import setup_logging, mem_eff_save_file
+from .utils import setup_logging
 
 setup_logging()
 import logging

library/lumina_util.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 
 from library.utils import setup_logging
 from library import lumina_models, flux_models
-from library.utils import load_safetensors
+from library.safetensors_utils import load_safetensors
 import logging
 
 setup_logging()
