Commit 0299150

feat: verify function called while converting model
1 parent c878418 commit 0299150

1 file changed: +17 −0 lines changed

tests/backend/patches/lora_conversions/test_flux_diffusers_lora_conversion_utils.py

Lines changed: 17 additions & 0 deletions
@@ -1,5 +1,7 @@
+import unittest.mock
 import pytest
 import torch
+import unittest


 from invokeai.backend.patches.layers.utils import swap_shift_scale_for_linear_weight
@@ -150,5 +152,20 @@ def test_approximate_adaLN_from_state_dict_should_work(dtype: torch.dtype, rtol:

     assert close_rate > rate

+def test_adaLN_should_be_approximated_if_present_while_converting():
+    """AdaLN layer should be approximated if present in the given model."""
+    state_dict = keys_to_mock_state_dict(flux_diffusers_with_norm_out_state_dict_keys)

+    adaLN_layer_key = 'final_layer.adaLN_modulation.1'
+    prefixed_layer_key = FLUX_LORA_TRANSFORMER_PREFIX + adaLN_layer_key

+    with unittest.mock.patch(
+        'invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils.approximate_flux_adaLN_lora_layer_from_diffusers_state_dict'
+    ) as mock_approximate_func:
+        model = lora_model_from_flux_diffusers_state_dict(state_dict, alpha=8.0)
+
+    # Check that every converted layer key carries the transformer prefix.
+    assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k in model.layers.keys())
+
+    assert prefixed_layer_key in model.layers.keys()
+    assert mock_approximate_func.call_count == 1
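
For readers unfamiliar with the pattern used in the new test, the sketch below shows the same mock-and-count technique in a self-contained form. It patches a standard-library function (os.path.exists) purely for illustration, and load_config is a hypothetical stand-in, not part of this commit or of InvokeAI; the committed test applies the identical pattern to approximate_flux_adaLN_lora_layer_from_diffusers_state_dict via its fully qualified import path.

import os
import unittest.mock


def load_config(path: str) -> dict:
    # Toy stand-in for "the code under test": it calls os.path.exists internally.
    return {"path": path, "exists": os.path.exists(path)}


def test_load_config_checks_the_filesystem_exactly_once():
    # Patch the dependency by its import path, run the code under test inside the
    # context manager, then assert on the recorded call count afterwards.
    with unittest.mock.patch("os.path.exists", return_value=True) as mock_exists:
        config = load_config("/tmp/anything")

    assert mock_exists.call_count == 1
    assert config["exists"] is True


if __name__ == "__main__":
    test_load_config_checks_the_filesystem_exactly_once()
    print("ok")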
