Commit ac245cb
feat(backend): add support for xlabs Flux LoRA format (#8686)
Add support for loading Flux LoRA models in the xlabs format, which uses keys like `double_blocks.X.processor.{qkv|proj}_lora{1|2}.{down|up}.weight`.

The xlabs format maps:
- lora1 -> img_attn (image attention stream)
- lora2 -> txt_attn (text attention stream)
- qkv -> query/key/value projection
- proj -> output projection

Changes:
- Add FluxLoRAFormat.XLabs enum value
- Add flux_xlabs_lora_conversion_utils.py with detection and conversion
- Update formats.py to detect xlabs format
- Update lora.py loader to handle xlabs format
- Update model probe to accept recognized Flux LoRA formats
- Add unit tests for xlabs format detection and conversion

Co-authored-by: Lincoln Stein <[email protected]>
1 parent 5be1e03 commit ac245cb
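
To make the key translation concrete, here is a minimal standalone sketch of the mapping described above. This is an editorial illustration only (map_xlabs_key is a hypothetical helper); the commit's actual implementation is in flux_xlabs_lora_conversion_utils.py below.

    import re

    # Same pattern as the commit's FLUX_XLABS_KEY_REGEX.
    XLABS_KEY = r"double_blocks\.(\d+)\.processor\.(qkv|proj)_lora([12])\.(down|up)\.weight"

    def map_xlabs_key(key: str) -> str:
        match = re.match(XLABS_KEY, key)
        assert match is not None, f"not an xlabs key: {key}"
        block, component, stream, direction = match.groups()
        attn = "img_attn" if stream == "1" else "txt_attn"  # lora1 -> img, lora2 -> txt
        return f"double_blocks.{block}.{attn}.{component}.lora_{direction}.weight"

    print(map_xlabs_key("double_blocks.0.processor.qkv_lora2.down.weight"))
    # -> double_blocks.0.txt_attn.qkv.lora_down.weight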

File tree (7 files changed, +241 -1 lines changed):
- invokeai/backend/model_manager/configs/lora.py
- invokeai/backend/model_manager/load/model_loaders/lora.py
- invokeai/backend/model_manager/taxonomy.py
- invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py
- invokeai/backend/patches/lora_conversions/formats.py
- tests/backend/patches/lora_conversions/lora_state_dicts/flux_lora_xlabs_format.py
- tests/backend/patches/lora_conversions/test_flux_xlabs_lora_conversion_utils.py

invokeai/backend/model_manager/configs/lora.py

Lines changed: 6 additions & 1 deletion

@@ -150,11 +150,16 @@ def _validate_base(cls, mod: ModelOnDisk) -> None:
 
     @classmethod
     def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
-        # First rule out ControlLoRA and Diffusers LoRA
+        # First rule out ControlLoRA
         flux_format = _get_flux_lora_format(mod)
         if flux_format in [FluxLoRAFormat.Control]:
             raise NotAMatchError("model looks like Control LoRA")
 
+        # If it's a recognized Flux LoRA format (Kohya, Diffusers, OneTrainer, AIToolkit, XLabs, etc.),
+        # it's valid and we skip the heuristic check
+        if flux_format is not None:
+            return
+
         # Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
         # Some main models have these keys, likely due to the creator merging in a LoRA.
         has_key_with_lora_prefix = state_dict_has_any_keys_starting_with(
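
The net effect of this change: any recognized Flux LoRA format now passes validation immediately, and only unrecognized state dicts fall through to the key-prefix heuristics. A hypothetical distillation of the control flow (not the real class; names and the fallback prefix are simplified for illustration):

    from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat

    def _passes_lora_validation(flux_format: FluxLoRAFormat | None, state_dict: dict) -> bool:
        if flux_format == FluxLoRAFormat.Control:
            return False  # ControlLoRA is rejected, as in the diff above
        if flux_format is not None:
            return True  # any other recognized Flux LoRA format short-circuits the heuristics
        # Fall back to key heuristics; "lora_" is an illustrative prefix, not the real list.
        return any(str(k).startswith("lora_") for k in state_dict.keys())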

invokeai/backend/model_manager/load/model_loaders/lora.py

Lines changed: 6 additions & 0 deletions

@@ -41,6 +41,10 @@
     is_state_dict_likely_in_flux_onetrainer_format,
     lora_model_from_flux_onetrainer_state_dict,
 )
+from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils import (
+    is_state_dict_likely_in_flux_xlabs_format,
+    lora_model_from_flux_xlabs_state_dict,
+)
 from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
 from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
 from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import lora_model_from_z_image_state_dict
@@ -118,6 +122,8 @@ def _load_model(
                model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
            elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict):
                model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
+           elif is_state_dict_likely_in_flux_xlabs_format(state_dict=state_dict):
+               model = lora_model_from_flux_xlabs_state_dict(state_dict=state_dict)
            else:
                raise ValueError("LoRA model is in unsupported FLUX format")
        else:

invokeai/backend/model_manager/taxonomy.py

Lines changed: 1 addition & 0 deletions

@@ -171,6 +171,7 @@ class FluxLoRAFormat(str, Enum):
     OneTrainer = "flux.onetrainer"
     Control = "flux.control"
     AIToolkit = "flux.aitoolkit"
+    XLabs = "flux.xlabs"
 
 
 AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, FluxVariantType]
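
Because FluxLoRAFormat subclasses str, the new member compares and serializes as its raw string value; a one-line sanity check (a sketch, assuming the import path above):

    from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat

    assert FluxLoRAFormat.XLabs == "flux.xlabs"  # str-Enum members equal their string values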
invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py

Lines changed: 92 additions & 0 deletions

@@ -0,0 +1,92 @@
+import re
+from typing import Any, Dict
+
+import torch
+
+from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
+from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+
+# A regex pattern that matches all of the transformer keys in the xlabs FLUX LoRA format.
+# Example keys:
+#   double_blocks.0.processor.qkv_lora1.down.weight
+#   double_blocks.0.processor.qkv_lora1.up.weight
+#   double_blocks.0.processor.proj_lora1.down.weight
+#   double_blocks.0.processor.proj_lora1.up.weight
+#   double_blocks.0.processor.qkv_lora2.down.weight
+#   double_blocks.0.processor.proj_lora2.up.weight
+FLUX_XLABS_KEY_REGEX = r"double_blocks\.(\d+)\.processor\.(qkv|proj)_lora([12])\.(down|up)\.weight"
+
+
+def is_state_dict_likely_in_flux_xlabs_format(state_dict: dict[str | int, Any]) -> bool:
+    """Checks if the provided state dict is likely in the xlabs FLUX LoRA format.
+
+    The xlabs format is characterized by keys matching the pattern:
+        double_blocks.{block_idx}.processor.{qkv|proj}_lora{1|2}.{down|up}.weight
+
+    Where:
+    - lora1 corresponds to the image attention stream (img_attn)
+    - lora2 corresponds to the text attention stream (txt_attn)
+    """
+    if not state_dict:
+        return False
+
+    # Check that all keys match the xlabs pattern
+    for key in state_dict.keys():
+        if not isinstance(key, str):
+            continue
+        if not re.match(FLUX_XLABS_KEY_REGEX, key):
+            return False
+
+    # Ensure we have at least some valid keys
+    return any(isinstance(k, str) and re.match(FLUX_XLABS_KEY_REGEX, k) for k in state_dict.keys())
+
+
+def lora_model_from_flux_xlabs_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
+    """Converts an xlabs FLUX LoRA state dict to the InvokeAI ModelPatchRaw format.
+
+    The xlabs format uses:
+    - lora1 for image attention stream (img_attn)
+    - lora2 for text attention stream (txt_attn)
+    - qkv for query/key/value projection
+    - proj for output projection
+
+    Key mapping:
+    - double_blocks.X.processor.qkv_lora1  -> double_blocks.X.img_attn.qkv
+    - double_blocks.X.processor.proj_lora1 -> double_blocks.X.img_attn.proj
+    - double_blocks.X.processor.qkv_lora2  -> double_blocks.X.txt_attn.qkv
+    - double_blocks.X.processor.proj_lora2 -> double_blocks.X.txt_attn.proj
+    """
+    # Group keys by layer (without the .down.weight/.up.weight suffix)
+    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
+
+    for key, value in state_dict.items():
+        match = re.match(FLUX_XLABS_KEY_REGEX, key)
+        if not match:
+            raise ValueError(f"Key '{key}' does not match the expected pattern for xlabs FLUX LoRA weights.")
+
+        block_idx = match.group(1)
+        component = match.group(2)  # qkv or proj
+        lora_stream = match.group(3)  # 1 or 2
+        direction = match.group(4)  # down or up
+
+        # Map lora1 -> img_attn, lora2 -> txt_attn
+        attn_type = "img_attn" if lora_stream == "1" else "txt_attn"
+
+        # Create the InvokeAI-style layer key
+        layer_key = f"double_blocks.{block_idx}.{attn_type}.{component}"
+
+        if layer_key not in grouped_state_dict:
+            grouped_state_dict[layer_key] = {}
+
+        # Map down/up to lora_down/lora_up
+        param_name = f"lora_{direction}.weight"
+        grouped_state_dict[layer_key][param_name] = value
+
+    # Create LoRA layers
+    layers: dict[str, BaseLayerPatch] = {}
+    for layer_key, layer_state_dict in grouped_state_dict.items():
+        layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+
+    return ModelPatchRaw(layers=layers)
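
A usage sketch for the two functions above, assuming an InvokeAI development environment and a toy rank-4 state dict (shapes are illustrative):

    import torch

    from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils import (
        is_state_dict_likely_in_flux_xlabs_format,
        lora_model_from_flux_xlabs_state_dict,
    )

    sd = {
        "double_blocks.0.processor.qkv_lora1.down.weight": torch.zeros(4, 3072),
        "double_blocks.0.processor.qkv_lora1.up.weight": torch.zeros(9216, 4),
    }

    assert is_state_dict_likely_in_flux_xlabs_format(sd)
    patch = lora_model_from_flux_xlabs_state_dict(sd)
    # The down/up pair is grouped into a single layer keyed
    # FLUX_LORA_TRANSFORMER_PREFIX + "double_blocks.0.img_attn.qkv".
    print(list(patch.layers.keys()))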

invokeai/backend/patches/lora_conversions/formats.py

Lines changed: 5 additions & 0 deletions

@@ -14,6 +14,9 @@
 from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
     is_state_dict_likely_in_flux_onetrainer_format,
 )
+from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils import (
+    is_state_dict_likely_in_flux_xlabs_format,
+)
 
 
 def flux_format_from_state_dict(
@@ -30,5 +33,7 @@ def flux_format_from_state_dict(
         return FluxLoRAFormat.Control
     elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict, metadata):
         return FluxLoRAFormat.AIToolkit
+    elif is_state_dict_likely_in_flux_xlabs_format(state_dict):
+        return FluxLoRAFormat.XLabs
     else:
         return None
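
Note that the xlabs detector is strict: every string key must match FLUX_XLABS_KEY_REGEX, so state dicts from other formats fall through to the later branches (or to None). A quick sketch; the detector only inspects keys, so placeholder values suffice, and the non-xlabs key below is purely illustrative:

    from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils import (
        is_state_dict_likely_in_flux_xlabs_format,
    )

    xlabs_sd = {"double_blocks.3.processor.proj_lora2.up.weight": None}
    other_sd = {"transformer.single_blocks.0.linear1.lora_A.weight": None}

    assert is_state_dict_likely_in_flux_xlabs_format(xlabs_sd)
    assert not is_state_dict_likely_in_flux_xlabs_format(other_sd)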
tests/backend/patches/lora_conversions/lora_state_dicts/flux_lora_xlabs_format.py

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+# A sample state dict in the xlabs FLUX LoRA format.
+# The xlabs format uses:
+# - lora1 for image attention stream (img_attn)
+# - lora2 for text attention stream (txt_attn)
+# - qkv for query/key/value projection
+# - proj for output projection
+state_dict_keys = {
+    "double_blocks.0.processor.proj_lora1.down.weight": [16, 3072],
+    "double_blocks.0.processor.proj_lora1.up.weight": [3072, 16],
+    "double_blocks.0.processor.proj_lora2.down.weight": [16, 3072],
+    "double_blocks.0.processor.proj_lora2.up.weight": [3072, 16],
+    "double_blocks.0.processor.qkv_lora1.down.weight": [16, 3072],
+    "double_blocks.0.processor.qkv_lora1.up.weight": [9216, 16],
+    "double_blocks.0.processor.qkv_lora2.down.weight": [16, 3072],
+    "double_blocks.0.processor.qkv_lora2.up.weight": [9216, 16],
+    "double_blocks.1.processor.proj_lora1.down.weight": [16, 3072],
+    "double_blocks.1.processor.proj_lora1.up.weight": [3072, 16],
+    "double_blocks.1.processor.proj_lora2.down.weight": [16, 3072],
+    "double_blocks.1.processor.proj_lora2.up.weight": [3072, 16],
+    "double_blocks.1.processor.qkv_lora1.down.weight": [16, 3072],
+    "double_blocks.1.processor.qkv_lora1.up.weight": [9216, 16],
+    "double_blocks.1.processor.qkv_lora2.down.weight": [16, 3072],
+    "double_blocks.1.processor.qkv_lora2.up.weight": [9216, 16],
+    "double_blocks.10.processor.proj_lora1.down.weight": [16, 3072],
+    "double_blocks.10.processor.proj_lora1.up.weight": [3072, 16],
+    "double_blocks.10.processor.proj_lora2.down.weight": [16, 3072],
+    "double_blocks.10.processor.proj_lora2.up.weight": [3072, 16],
+    "double_blocks.10.processor.qkv_lora1.down.weight": [16, 3072],
+    "double_blocks.10.processor.qkv_lora1.up.weight": [9216, 16],
+    "double_blocks.10.processor.qkv_lora2.down.weight": [16, 3072],
+    "double_blocks.10.processor.qkv_lora2.up.weight": [9216, 16],
+}
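
The shapes encode the FLUX geometry: down-projections are [rank, 3072] (rank 16 against the 3072-dim hidden stream), proj up-projections are [3072, 16], and qkv up-projections are [9216, 16] because FLUX fuses query, key, and value into one matrix (9216 = 3 x 3072). A small sketch verifying the arithmetic:

    import torch

    down = torch.zeros(16, 3072)  # rank 16, hidden size 3072
    up = torch.zeros(9216, 16)    # fused q/k/v: 3 * 3072 = 9216
    assert (up @ down).shape == (9216, 3072)  # full delta-W for the fused qkv projection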
tests/backend/patches/lora_conversions/test_flux_xlabs_lora_conversion_utils.py

Lines changed: 99 additions & 0 deletions

@@ -0,0 +1,99 @@
+import accelerate
+import pytest
+import torch
+
+from invokeai.backend.flux.model import Flux
+from invokeai.backend.flux.util import get_flux_transformers_params
+from invokeai.backend.model_manager.taxonomy import FluxVariantType
+from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
+from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils import (
+    is_state_dict_likely_in_flux_xlabs_format,
+    lora_model_from_flux_xlabs_state_dict,
+)
+from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import (
+    state_dict_keys as flux_diffusers_state_dict_keys,
+)
+from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_kohya_format import (
+    state_dict_keys as flux_kohya_state_dict_keys,
+)
+from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_xlabs_format import (
+    state_dict_keys as flux_xlabs_state_dict_keys,
+)
+from tests.backend.patches.lora_conversions.lora_state_dicts.utils import keys_to_mock_state_dict
+
+
+def test_is_state_dict_likely_in_flux_xlabs_format_true():
+    """Test that is_state_dict_likely_in_flux_xlabs_format() can identify a state dict in the xlabs FLUX LoRA format."""
+    state_dict = keys_to_mock_state_dict(flux_xlabs_state_dict_keys)
+    assert is_state_dict_likely_in_flux_xlabs_format(state_dict)
+
+
+@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys, flux_kohya_state_dict_keys])
+def test_is_state_dict_likely_in_flux_xlabs_format_false(sd_keys: dict[str, list[int]]):
+    """Test that is_state_dict_likely_in_flux_xlabs_format() returns False for state dicts in other formats."""
+    state_dict = keys_to_mock_state_dict(sd_keys)
+    assert not is_state_dict_likely_in_flux_xlabs_format(state_dict)
+
+
+def test_lora_model_from_flux_xlabs_state_dict():
+    """Test that a ModelPatchRaw can be created from a state dict in the xlabs FLUX LoRA format."""
+    state_dict = keys_to_mock_state_dict(flux_xlabs_state_dict_keys)
+
+    lora_model = lora_model_from_flux_xlabs_state_dict(state_dict)
+
+    # Verify the expected layer keys are created
+    expected_layer_keys = {
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.0.img_attn.proj",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.0.img_attn.qkv",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.0.txt_attn.proj",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.0.txt_attn.qkv",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.1.img_attn.proj",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.1.img_attn.qkv",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.1.txt_attn.proj",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.1.txt_attn.qkv",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.10.img_attn.proj",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.10.img_attn.qkv",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.10.txt_attn.proj",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.10.txt_attn.qkv",
+    }
+
+    assert set(lora_model.layers.keys()) == expected_layer_keys
+
+
+def test_lora_model_from_flux_xlabs_state_dict_matches_model_keys():
+    """Test that the converted xlabs LoRA keys match the actual FLUX model keys."""
+    state_dict = keys_to_mock_state_dict(flux_xlabs_state_dict_keys)
+
+    lora_model = lora_model_from_flux_xlabs_state_dict(state_dict)
+
+    # Extract the layer prefixes (without the lora_transformer- prefix)
+    converted_key_prefixes: list[str] = []
+    for k in lora_model.layers.keys():
+        # Remove the transformer prefix
+        k = k.replace(FLUX_LORA_TRANSFORMER_PREFIX, "")
+        converted_key_prefixes.append(k)
+
+    # Initialize a FLUX model on the meta device.
+    with accelerate.init_empty_weights():
+        model = Flux(get_flux_transformers_params(FluxVariantType.Schnell))
+    model_keys = set(model.state_dict().keys())
+
+    # Assert that the converted keys match prefixes in the actual model.
+    for converted_key_prefix in converted_key_prefixes:
+        found_match = False
+        for model_key in model_keys:
+            if model_key.startswith(converted_key_prefix):
+                found_match = True
+                break
+        if not found_match:
+            raise AssertionError(f"Could not find a match for the converted key prefix: {converted_key_prefix}")
+
+
+def test_lora_model_from_flux_xlabs_state_dict_error():
+    """Test that an error is raised if the input state_dict contains unexpected keys."""
+    state_dict = {
+        "unexpected_key.down.weight": torch.empty(1),
+    }
+
+    with pytest.raises(ValueError):
+        lora_model_from_flux_xlabs_state_dict(state_dict)