Skip to content

Commit 1335d7e

Browse files
committed
add simple lora inference tests
1 parent 77a32a7 commit 1335d7e

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

tests/lora/utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2098,3 +2098,43 @@ def test_correct_lora_configs_with_different_ranks(self):
20982098
lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
20992099
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
21002100
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
2101+
2102+
def test_layerwise_upcasting_inference_denoiser(self):
    """Smoke-test LoRA inference with the denoiser's weights kept in a
    low-precision storage dtype and upcast layer-by-layer at compute time.

    Runs the pipeline under three configurations — plain fp32 (baseline,
    no upcasting), fp8-e4m3 storage with fp32 compute, and fp8-e4m3
    storage with bf16 compute — asserting only that each run completes;
    outputs are discarded, so no numerical comparison is made.
    """

    def make_pipeline(storage_dtype=None, compute_dtype=torch.float32):
        # Build fresh components per call so every dtype configuration
        # starts from identical, untouched weights.
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
        pipe = self.pipeline_class(**components).to(torch_device, dtype=compute_dtype)
        pipe.set_progress_bar_config(disable=None)

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(
                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
            )

        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
        denoiser.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        # Only wrap the denoiser in layerwise upcasting when a storage
        # dtype was actually requested; the fp32 baseline skips it.
        if storage_dtype is not None:
            denoiser.enable_layerwise_upcasting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)

        return pipe

    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    # (storage dtype, compute dtype) pairs, baseline first — same order
    # as the three original sequential runs.
    dtype_configs = [
        (None, torch.float32),
        (torch.float8_e4m3fn, torch.float32),
        (torch.float8_e4m3fn, torch.bfloat16),
    ]
    for storage_dtype, compute_dtype in dtype_configs:
        pipe = make_pipeline(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
        # Fixed seed keeps each run deterministic; result is discarded.
        pipe(**inputs, generator=torch.manual_seed(0))[0]

0 commit comments

Comments
 (0)