Commit 2a8e0fe

Merge branch 'main' into tests/fix-failing-float16-cuda
2 parents e88ae2f + f3e1310 commit 2a8e0fe

File tree

5 files changed: +62 -8 lines changed

- docs/source/en/tutorials/using_peft_for_inference.md
- tests/models/test_modeling_common.py
- tests/models/transformers/test_models_transformer_flux.py
- tests/quantization/bnb/test_4bit.py
- tests/quantization/bnb/test_mixed_int8.py


docs/source/en/tutorials/using_peft_for_inference.md

Lines changed: 2 additions & 0 deletions
@@ -315,6 +315,8 @@ pipeline.load_lora_weights(
 > [!TIP]
 > Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If a model is recompiled despite following all the steps above, please open an [issue](https://github.com/huggingface/diffusers/issues) with a reproducible example.
 
+If you expect to use varied resolutions during inference with this feature, make sure to set `dynamic=True` during compilation. Refer to [this document](../optimization/fp16#dynamic-shape-compilation) for more details.
+
 There are still scenarios where recompilation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers *first*. For more details about this limitation, refer to the PEFT [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) docs.
 
 ## Merge
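
For readers following the doc change above, here is a minimal sketch of the workflow the two added sentences describe, assuming a Flux pipeline on CUDA; the checkpoint id, LoRA repository ids, prompt, and resolutions are placeholders rather than values taken from the docs:

```python
import torch
from diffusers import DiffusionPipeline

# Placeholder checkpoint and LoRA repositories; substitute your own.
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Opt into hotswapping before loading the first adapter, then compile once.
pipeline.enable_lora_hotswap(target_rank=8)
pipeline.load_lora_weights("<lora-repo-0>", adapter_name="adapter0")

# dynamic=True keeps the spatial dimensions symbolic, so inference at other
# resolutions can reuse the compiled graph instead of recompiling.
pipeline.transformer = torch.compile(
    pipeline.transformer, mode="reduce-overhead", fullgraph=True, dynamic=True
)

# error_on_recompile=True turns any unexpected recompilation into a hard error,
# which is the detection trick the TIP above recommends.
with torch._dynamo.config.patch(error_on_recompile=True):
    for height, width in [(1024, 1024), (768, 1344)]:
        _ = pipeline("a photo of an astronaut", height=height, width=width).images[0]

    # Hotswap a second adapter; with the steps above this should not recompile.
    pipeline.load_lora_weights("<lora-repo-1>", hotswap=True, adapter_name="adapter0")
    _ = pipeline("a photo of an astronaut", height=1024, width=1024).images[0]
```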

tests/models/test_modeling_common.py

Lines changed: 40 additions & 6 deletions
@@ -1350,7 +1350,6 @@ def test_model_parallelism(self):
             new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
             # Making sure part of the model will actually end up offloaded
             self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
-            print(f" new_model.hf_device_map:{new_model.hf_device_map}")
 
             self.check_device_map_is_respected(new_model, new_model.hf_device_map)
 
@@ -2019,6 +2018,8 @@ class LoraHotSwappingForModelTesterMixin:
 
     """
 
+    different_shapes_for_compilation = None
+
     def tearDown(self):
         # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
         # there will be recompilation errors, as torch caches the model when run in the same process.
@@ -2056,11 +2057,13 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
         - hotswap the second adapter
         - check that the outputs are correct
         - optionally compile the model
+        - optionally check if recompilations happen on different shapes
 
         Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
         fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
         fine.
         """
+        different_shapes = self.different_shapes_for_compilation
         # create 2 adapters with different ranks and alphas
         torch.manual_seed(0)
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -2110,19 +2113,30 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
             model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
 
             if do_compile:
-                model = torch.compile(model, mode="reduce-overhead")
+                model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)
 
             with torch.inference_mode():
-                output0_after = model(**inputs_dict)["sample"]
-                assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
+                # additionally check if dynamic compilation works.
+                if different_shapes is not None:
+                    for height, width in different_shapes:
+                        new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
+                        _ = model(**new_inputs_dict)
+                else:
+                    output0_after = model(**inputs_dict)["sample"]
+                    assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
 
             # hotswap the 2nd adapter
             model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
 
             # we need to call forward to potentially trigger recompilation
             with torch.inference_mode():
-                output1_after = model(**inputs_dict)["sample"]
-                assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
+                if different_shapes is not None:
+                    for height, width in different_shapes:
+                        new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
+                        _ = model(**new_inputs_dict)
+                else:
+                    output1_after = model(**inputs_dict)["sample"]
+                    assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
 
             # check error when not passing valid adapter name
             name = "does-not-exist"
@@ -2240,3 +2254,23 @@ def test_hotswap_second_adapter_targets_more_layers_raises(self):
                 do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
             )
         assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)
+
+    @parameterized.expand([(11, 11), (7, 13), (13, 7)])
+    @require_torch_version_greater("2.7.1")
+    def test_hotswapping_compile_on_different_shapes(self, rank0, rank1):
+        different_shapes_for_compilation = self.different_shapes_for_compilation
+        if different_shapes_for_compilation is None:
+            pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
+        # Specifying `use_duck_shape=False` instructs the compiler whether it should use the same symbolic
+        # variable to represent input sizes that are the same. For more details,
+        # check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
+        torch.fx.experimental._config.use_duck_shape = False
+
+        target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            self.check_model_hotswap(
+                do_compile=True,
+                rank0=rank0,
+                rank1=rank1,
+                target_modules0=target_modules,
+            )
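
A side note on the `use_duck_shape` comment in the new test: by default, dynamo's duck shaping assigns one symbolic size to dimensions that happen to be equal when the graph is first traced, so compiling on a square input and then running a non-square one can still recompile even with `dynamic=True`. The standalone sketch below is independent of the test suite above and uses an illustrative toy function; it assumes PyTorch 2.7+, matching the version gate in the test, where the flag lives at `torch.fx.experimental._config.use_duck_shape`:

```python
import torch
from torch.fx.experimental import _config as fx_config

# Give every dynamic dimension its own symbol instead of letting dimensions
# that are equal at first trace share one (duck shaping).
fx_config.use_duck_shape = False

def toy(x):
    return x * 2

compiled = torch.compile(toy, dynamic=True)

with torch._dynamo.config.patch(error_on_recompile=True):
    compiled(torch.randn(1, 4, 8, 8))   # height == width on the first call
    # With duck shaping, the two trailing dims would have shared a symbol and
    # this non-square call would trigger a recompile (an error under this patch).
    compiled(torch.randn(1, 4, 8, 16))
```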

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 4 additions & 0 deletions
@@ -186,6 +186,10 @@ def prepare_dummy_input(self, height, width):
 
 class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
     model_class = FluxTransformer2DModel
+    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
 
     def prepare_init_args_and_inputs_for_common(self):
         return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
+
+    def prepare_dummy_input(self, height, width):
+        return FluxTransformerTests().prepare_dummy_input(height=height, width=width)

tests/quantization/bnb/test_4bit.py

Lines changed: 8 additions & 1 deletion
@@ -98,7 +98,14 @@ class Base4bitTests(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
-        torch.use_deterministic_algorithms(True)
+        cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled()
+        if not cls.is_deterministic_enabled:
+            torch.use_deterministic_algorithms(True)
+
+    @classmethod
+    def tearDownClass(cls):
+        if not cls.is_deterministic_enabled:
+            torch.use_deterministic_algorithms(False)
 
     def get_dummy_inputs(self):
         prompt_embeds = load_pt(

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 8 additions & 1 deletion
@@ -99,7 +99,14 @@ class Base8bitTests(unittest.TestCase):
 
     @classmethod
    def setUpClass(cls):
-        torch.use_deterministic_algorithms(True)
+        cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled()
+        if not cls.is_deterministic_enabled:
+            torch.use_deterministic_algorithms(True)
+
+    @classmethod
+    def tearDownClass(cls):
+        if not cls.is_deterministic_enabled:
+            torch.use_deterministic_algorithms(False)
 
     def get_dummy_inputs(self):
         prompt_embeds = load_pt(
