Commit 12dbb40

Merge branch 'fixes-issue-10872' of https://github.com/ishan-modi/diffusers into fixes-issue-10872

2 parents e3046a5 + 0059947
File tree: 4 files changed (+104, -2 lines)

src/diffusers/loaders/lora_conversion_utils.py
Lines changed: 71 additions & 0 deletions

@@ -1276,3 +1276,74 @@ def remap_single_transformer_blocks_(key, state_dict):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
 
     return converted_state_dict
+
+
+def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
+    # Remove the "diffusion_model." prefix from keys.
+    state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
+    converted_state_dict = {}
+
+    def get_num_layers(keys, pattern):
+        layers = set()
+        for key in keys:
+            match = re.search(pattern, key)
+            if match:
+                layers.add(int(match.group(1)))
+        return len(layers)
+
+    def process_block(prefix, index, convert_norm):
+        # Process attention qkv: pop lora_A and lora_B weights.
+        lora_down = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_A.weight")
+        lora_up = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_B.weight")
+        for attn_key in ["to_q", "to_k", "to_v"]:
+            converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_A.weight"] = lora_down
+        for attn_key, weight in zip(["to_q", "to_k", "to_v"], torch.split(lora_up, [2304, 768, 768], dim=0)):
+            converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_B.weight"] = weight
+
+        # Process attention out weights.
+        converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_A.weight"] = state_dict.pop(
+            f"{prefix}.{index}.attention.out.lora_A.weight"
+        )
+        converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_B.weight"] = state_dict.pop(
+            f"{prefix}.{index}.attention.out.lora_B.weight"
+        )
+
+        # Process feed-forward weights for layers 1, 2, and 3.
+        for layer in range(1, 4):
+            converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_A.weight"] = state_dict.pop(
+                f"{prefix}.{index}.feed_forward.w{layer}.lora_A.weight"
+            )
+            converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_B.weight"] = state_dict.pop(
+                f"{prefix}.{index}.feed_forward.w{layer}.lora_B.weight"
+            )
+
+        if convert_norm:
+            converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_A.weight"] = state_dict.pop(
+                f"{prefix}.{index}.adaLN_modulation.1.lora_A.weight"
+            )
+            converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_B.weight"] = state_dict.pop(
+                f"{prefix}.{index}.adaLN_modulation.1.lora_B.weight"
+            )
+
+    noise_refiner_pattern = r"noise_refiner\.(\d+)\."
+    num_noise_refiner_layers = get_num_layers(state_dict.keys(), noise_refiner_pattern)
+    for i in range(num_noise_refiner_layers):
+        process_block("noise_refiner", i, convert_norm=True)
+
+    context_refiner_pattern = r"context_refiner\.(\d+)\."
+    num_context_refiner_layers = get_num_layers(state_dict.keys(), context_refiner_pattern)
+    for i in range(num_context_refiner_layers):
+        process_block("context_refiner", i, convert_norm=False)
+
+    core_transformer_pattern = r"layers\.(\d+)\."
+    num_core_transformer_layers = get_num_layers(state_dict.keys(), core_transformer_pattern)
+    for i in range(num_core_transformer_layers):
+        process_block("layers", i, convert_norm=True)
+
+    if len(state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
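For context, a minimal sketch (not part of the commit) of what this converter does to a single-block checkpoint. The rank and tensor shapes below are illustrative assumptions; only the key names and the [2304, 768, 768] qkv split come from the code above. The shared lora_A is duplicated across to_q/to_k/to_v because the original checkpoint fuses qkv into one projection while diffusers keeps three separate modules.

# Hedged sketch: feed the converter one synthetic noise_refiner block.
# Shapes and rank are assumptions made up for illustration.
import torch

from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_lumina2_lora_to_diffusers

rank = 4
prefix = "diffusion_model.noise_refiner.0"
sample = {
    f"{prefix}.attention.qkv.lora_A.weight": torch.zeros(rank, 2304),
    f"{prefix}.attention.qkv.lora_B.weight": torch.zeros(2304 + 768 + 768, rank),
    f"{prefix}.attention.out.lora_A.weight": torch.zeros(rank, 2304),
    f"{prefix}.attention.out.lora_B.weight": torch.zeros(2304, rank),
    f"{prefix}.adaLN_modulation.1.lora_A.weight": torch.zeros(rank, 2304),
    f"{prefix}.adaLN_modulation.1.lora_B.weight": torch.zeros(4 * 2304, rank),
}
for layer in range(1, 4):
    sample[f"{prefix}.feed_forward.w{layer}.lora_A.weight"] = torch.zeros(rank, 2304)
    sample[f"{prefix}.feed_forward.w{layer}.lora_B.weight"] = torch.zeros(2304, rank)

converted = _convert_non_diffusers_lumina2_lora_to_diffusers(sample)
# The fused qkv lora_B is split row-wise into (2304, rank), (768, rank),
# (768, rank) pieces, and every key gains a "transformer." prefix, e.g.:
print(converted["transformer.noise_refiner.0.attn.to_k.lora_B.weight"].shape)  # torch.Size([768, 4])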

src/diffusers/loaders/lora_pipeline.py
Lines changed: 6 additions & 1 deletion

@@ -41,6 +41,7 @@
     _convert_hunyuan_video_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
     _convert_non_diffusers_lora_to_diffusers,
+    _convert_non_diffusers_lumina2_lora_to_diffusers,
     _convert_xlabs_flux_lora_to_diffusers,
     _maybe_map_sgm_blocks_to_diffusers,
 )
@@ -3815,7 +3816,6 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
 
     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -3909,6 +3909,11 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
+        # Convert non-diffusers (original-repo) Lumina2 LoRA checkpoints to the diffusers format.
+        non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
+        if non_diffusers:
+            state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
+
         return state_dict
 
     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
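The format detection is a plain key-prefix check; a hedged illustration (the key names below are made-up examples, not taken from the commit):

# Original (non-diffusers) Lumina2 checkpoints prefix every key with
# "diffusion_model.", while diffusers-format keys carry a "transformer." prefix.
original_fmt = {"diffusion_model.layers.0.attention.qkv.lora_A.weight": None}
diffusers_fmt = {"transformer.layers.0.attn.to_q.lora_A.weight": None}
print(any(k.startswith("diffusion_model.") for k in original_fmt))   # True  -> convert
print(any(k.startswith("diffusion_model.") for k in diffusers_fmt))  # False -> pass through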

src/diffusers/quantizers/quantization_config.py
Lines changed: 13 additions & 1 deletion

@@ -47,6 +47,16 @@ class QuantizationMethod(str, Enum):
     TORCHAO = "torchao"
 
 
+if is_torchao_available():
+    from torchao.quantization.quant_primitives import MappingType
+
+    class TorchAoJSONEncoder(json.JSONEncoder):
+        def default(self, obj):
+            if isinstance(obj, MappingType):
+                return obj.name
+            return super().default(obj)
+
+
 @dataclass
 class QuantizationConfigMixin:
     """
@@ -673,4 +683,6 @@ def __repr__(self):
         ```
         """
         config_dict = self.to_dict()
-        return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
+        return (
+            f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n"
+        )
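The custom encoder exists because torchao's MappingType is a Python enum, which json.dumps cannot serialize natively; the encoder falls back to the member's name. A minimal sketch of the behavior (assumes torchao is installed; the encoder's import path follows from the file above):

import json

from torchao.quantization.quant_primitives import MappingType
from diffusers.quantizers.quantization_config import TorchAoJSONEncoder

payload = {"act_mapping_type": MappingType.SYMMETRIC, "group_size": 64}
# json.dumps(payload) alone would raise TypeError ("Object of type MappingType
# is not JSON serializable"); the encoder emits the enum member by name instead.
print(json.dumps(payload, cls=TorchAoJSONEncoder, sort_keys=True))
# {"act_mapping_type": "SYMMETRIC", "group_size": 64}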

tests/quantization/torchao/test_torchao.py
Lines changed: 14 additions & 0 deletions

@@ -76,6 +76,7 @@ def forward(self, input, *args, **kwargs):
 if is_torchao_available():
     from torchao.dtypes import AffineQuantizedTensor
     from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
+    from torchao.quantization.quant_primitives import MappingType
     from torchao.utils import get_model_size_in_bytes
 
 
@@ -122,6 +123,19 @@ def test_repr(self):
         quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
         self.assertEqual(quantization_repr, expected_repr)
 
+        quantization_config = TorchAoConfig("int4dq", group_size=64, act_mapping_type=MappingType.SYMMETRIC)
+        expected_repr = """TorchAoConfig {
+            "modules_to_not_convert": null,
+            "quant_method": "torchao",
+            "quant_type": "int4dq",
+            "quant_type_kwargs": {
+                "act_mapping_type": "SYMMETRIC",
+                "group_size": 64
+            }
+        }""".replace(" ", "").replace("\n", "")
+        quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
+        self.assertEqual(quantization_repr, expected_repr)
+
 
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch

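To exercise just the updated repr check locally (standard pytest test selection; requires torch and torchao to be installed):

pytest tests/quantization/torchao/test_torchao.py -k test_repr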