Skip to content

Commit b713511

Browse files
committed
check layer dtypes in lora test
1 parent 245137f commit b713511

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

tests/lora/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
import inspect
1616
import os
17+
import re
1718
import tempfile
1819
import unittest
1920
from itertools import product
@@ -2100,6 +2101,23 @@ def test_correct_lora_configs_with_different_ranks(self):
21002101
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
21012102

21022103
def test_layerwise_upcasting_inference_denoiser(self):
2104+
from diffusers.hooks.layerwise_upcasting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
2105+
2106+
def check_linear_dtype(module, storage_dtype, compute_dtype):
2107+
patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
2108+
if getattr(module, "_precision_sensitive_module_patterns", None) is not None:
2109+
patterns_to_check += tuple(module._precision_sensitive_module_patterns)
2110+
for name, submodule in module.named_modules():
2111+
if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
2112+
continue
2113+
dtype_to_check = storage_dtype
2114+
if "lora" in name or any(re.search(pattern, name) for pattern in patterns_to_check):
2115+
dtype_to_check = compute_dtype
2116+
if getattr(submodule, "weight", None) is not None:
2117+
self.assertEqual(submodule.weight.dtype, dtype_to_check)
2118+
if getattr(submodule, "bias", None) is not None:
2119+
self.assertEqual(submodule.bias.dtype, dtype_to_check)
2120+
21032121
def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
21042122
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
21052123
pipe = self.pipeline_class(**components)
@@ -2125,6 +2143,7 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
21252143

21262144
if storage_dtype is not None:
21272145
denoiser.enable_layerwise_upcasting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
2146+
check_linear_dtype(denoiser, storage_dtype, compute_dtype)
21282147

21292148
return pipe
21302149

tests/models/test_modeling_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
require_torch_2,
6363
require_torch_accelerator,
6464
require_torch_accelerator_with_training,
65+
require_torch_gpu,
6566
require_torch_multi_gpu,
6667
run_test_in_subprocess,
6768
torch_all_close,

0 commit comments

Comments
 (0)