
Commit 591655e

Merge branch 'main' into speedup-model-loading
2 parents: 4c81c96 + 1c6ab9e

File tree

5 files changed: +55 -17 lines changed


examples/server/requirements.in

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-torch~=2.4.0
+torch~=2.7.0
 transformers==4.46.1
 sentencepiece
 aiohttp

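The pinned lockfile below carries matching "# via" annotations, which suggests it is generated from this file with pip-tools rather than edited by hand. Under that assumption, the lockfile changes in the next diff are the output of re-running something like

    pip-compile requirements.in

after bumping the torch constraint here.
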
examples/server/requirements.txt

Lines changed: 20 additions & 15 deletions
@@ -63,36 +63,42 @@ networkx==3.2.1
     # via torch
 numpy==2.0.2
     # via transformers
-nvidia-cublas-cu12==12.1.3.1
+nvidia-cublas-cu12==12.6.4.1
     # via
     #   nvidia-cudnn-cu12
     #   nvidia-cusolver-cu12
     #   torch
-nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-cupti-cu12==12.6.80
     # via torch
-nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.6.77
     # via torch
-nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.6.77
     # via torch
-nvidia-cudnn-cu12==9.1.0.70
+nvidia-cudnn-cu12==9.5.1.17
     # via torch
-nvidia-cufft-cu12==11.0.2.54
+nvidia-cufft-cu12==11.3.0.4
     # via torch
-nvidia-curand-cu12==10.3.2.106
+nvidia-cufile-cu12==1.11.1.6
     # via torch
-nvidia-cusolver-cu12==11.4.5.107
+nvidia-curand-cu12==10.3.7.77
     # via torch
-nvidia-cusparse-cu12==12.1.0.106
+nvidia-cusolver-cu12==11.7.1.2
+    # via torch
+nvidia-cusparse-cu12==12.5.4.2
     # via
     #   nvidia-cusolver-cu12
     #   torch
-nvidia-nccl-cu12==2.20.5
+nvidia-cusparselt-cu12==0.6.3
+    # via torch
+nvidia-nccl-cu12==2.26.2
     # via torch
-nvidia-nvjitlink-cu12==12.9.86
+nvidia-nvjitlink-cu12==12.6.85
     # via
+    #   nvidia-cufft-cu12
     #   nvidia-cusolver-cu12
     #   nvidia-cusparse-cu12
-nvidia-nvtx-cu12==12.1.105
+    #   torch
+nvidia-nvtx-cu12==12.6.77
     # via torch
 packaging==24.1
     # via
@@ -137,20 +143,19 @@ sympy==1.13.3
     # via torch
 tokenizers==0.20.1
     # via transformers
-torch==2.4.1
+torch==2.7.0
     # via -r requirements.in
 tqdm==4.66.5
     # via
     #   huggingface-hub
     #   transformers
 transformers==4.46.1
     # via -r requirements.in
-triton==3.0.0
+triton==3.3.0
     # via torch
 typing-extensions==4.12.2
     # via
     #   anyio
-    #   exceptiongroup
     #   fastapi
     #   huggingface-hub
     #   multidict

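The regenerated pins move the server example to torch 2.7.0 with the CUDA 12.6 (cu126) wheel family, and pick up nvidia-cufile-cu12 and nvidia-cusparselt-cu12, both of which the lockfile marks as pulled in via torch. A hypothetical sanity check, not part of this commit, for confirming a freshly installed environment matches the new pins:

# Hypothetical check (not part of this commit): confirm the installed
# wheels match the newly pinned torch 2.7 / CUDA 12.6 stack.
import torch
import transformers

print(torch.__version__)         # expected to start with "2.7"
print(torch.version.cuda)        # "12.6" for the cu126 wheels pinned above; None on CPU-only builds
print(transformers.__version__)  # unchanged, still pinned to 4.46.1
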
src/diffusers/loaders/lora_base.py

Lines changed: 1 addition & 1 deletion
@@ -470,7 +470,7 @@ def _func_optionally_disable_offloading(_pipeline):
         for _, component in _pipeline.components.items():
             if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
                 continue
-            remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+            remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
 
     return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
 

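For context, this helper sits on the path exercised when LoRA weights are loaded into a pipeline that already has accelerate offloading hooks installed. A hedged end-user sketch of that flow; the checkpoint and adapter repo ids are placeholders, not taken from this commit:

# Sketch of the calling pattern behind _func_optionally_disable_offloading.
# The model and LoRA ids below are placeholders for illustration only.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # installs accelerate hooks on each component

# load_lora_weights() first strips those hooks (the remove_hook_from_module
# call in the hunk above), loads the adapter, then re-applies offloading.
pipe.load_lora_weights("some-user/some-lora")  # placeholder repo id
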
src/diffusers/utils/torch_utils.py

Lines changed: 2 additions & 0 deletions
@@ -175,6 +175,8 @@ def get_device():
         return "npu"
     elif hasattr(torch, "xpu") and torch.xpu.is_available():
         return "xpu"
+    elif torch.backends.mps.is_available():
+        return "mps"
     else:
         return "cpu"
 

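With the new branch, Apple Silicon hosts resolve to the MPS backend instead of falling through to "cpu". A hedged usage sketch, assuming a PyTorch build that exposes torch.backends.mps (available since 1.12):

# Usage sketch: on an Apple Silicon Mac without CUDA, get_device() should
# now return "mps" rather than "cpu".
import torch
from diffusers.utils.torch_utils import get_device

device = get_device()
print(device)  # one of "cuda", "npu", "xpu", "mps", or "cpu" depending on the host

x = torch.ones(2, 2, device=device)
print(x.device)
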
tests/lora/utils.py

Lines changed: 31 additions & 0 deletions
@@ -2510,3 +2510,34 @@ def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
             # materializes the test methods on invocation which cannot be overridden.
             return
         self._test_group_offloading_inference_denoiser(offload_type, use_stream)
+
+    @require_torch_accelerator
+    def test_lora_loading_model_cpu_offload(self):
+        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+            )
+            # reinitialize the pipeline to mimic the inference workflow.
+            components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+            pipe = self.pipeline_class(**components)
+            pipe.enable_model_cpu_offload(device=torch_device)
+            pipe.load_lora_weights(tmpdirname)
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+            output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            self.assertTrue(np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3))

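The new test follows the mixin pattern used throughout tests/lora/utils.py: the per-pipeline test modules materialize these methods, and @require_torch_accelerator skips the test on hosts without a supported accelerator. A run scoped to just this test would presumably look like

    pytest tests/lora -k test_lora_loading_model_cpu_offload

with the final assertion checking that the outputs produced before saving and after reloading under enable_model_cpu_offload() agree within the 1e-3 tolerances.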