Skip to content

Commit c309d0e

Browse files
authored
Merge branch 'main' into feat/z-image-regional-guidance
2 parents 2795ab3 + ddb85ca commit c309d0e

File tree

16 files changed

+970
-655
lines changed

16 files changed

+970
-655
lines changed

invokeai/backend/model_manager/configs/lora.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,16 @@ def _validate_base(cls, mod: ModelOnDisk) -> None:
150150

151151
@classmethod
152152
def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
153-
# First rule out ControlLoRA and Diffusers LoRA
153+
# First rule out ControlLoRA
154154
flux_format = _get_flux_lora_format(mod)
155155
if flux_format in [FluxLoRAFormat.Control]:
156156
raise NotAMatchError("model looks like Control LoRA")
157157

158+
# If it's a recognized Flux LoRA format (Kohya, Diffusers, OneTrainer, AIToolkit, XLabs, etc.),
159+
# it's valid and we skip the heuristic check
160+
if flux_format is not None:
161+
return
162+
158163
# Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
159164
# Some main models have these keys, likely due to the creator merging in a LoRA.
160165
has_key_with_lora_prefix = state_dict_has_any_keys_starting_with(

invokeai/backend/model_manager/load/model_loaders/lora.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@
4141
is_state_dict_likely_in_flux_onetrainer_format,
4242
lora_model_from_flux_onetrainer_state_dict,
4343
)
44+
from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils import (
45+
is_state_dict_likely_in_flux_xlabs_format,
46+
lora_model_from_flux_xlabs_state_dict,
47+
)
4448
from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
4549
from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
4650
from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import lora_model_from_z_image_state_dict
@@ -118,6 +122,8 @@ def _load_model(
118122
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
119123
elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict):
120124
model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
125+
elif is_state_dict_likely_in_flux_xlabs_format(state_dict=state_dict):
126+
model = lora_model_from_flux_xlabs_state_dict(state_dict=state_dict)
121127
else:
122128
raise ValueError("LoRA model is in unsupported FLUX format")
123129
else:

invokeai/backend/model_manager/taxonomy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ class FluxLoRAFormat(str, Enum):
171171
OneTrainer = "flux.onetrainer"
172172
Control = "flux.control"
173173
AIToolkit = "flux.aitoolkit"
174+
XLabs = "flux.xlabs"
174175

175176

176177
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, FluxVariantType]
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import re
2+
from typing import Any, Dict
3+
4+
import torch
5+
6+
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
7+
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
8+
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
9+
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
10+
11+
# A regex pattern that matches all of the transformer keys in the xlabs FLUX LoRA format.
12+
# Example keys:
13+
# double_blocks.0.processor.qkv_lora1.down.weight
14+
# double_blocks.0.processor.qkv_lora1.up.weight
15+
# double_blocks.0.processor.proj_lora1.down.weight
16+
# double_blocks.0.processor.proj_lora1.up.weight
17+
# double_blocks.0.processor.qkv_lora2.down.weight
18+
# double_blocks.0.processor.proj_lora2.up.weight
19+
FLUX_XLABS_KEY_REGEX = r"double_blocks\.(\d+)\.processor\.(qkv|proj)_lora([12])\.(down|up)\.weight"
20+
21+
22+
def is_state_dict_likely_in_flux_xlabs_format(state_dict: dict[str | int, Any]) -> bool:
23+
"""Checks if the provided state dict is likely in the xlabs FLUX LoRA format.
24+
25+
The xlabs format is characterized by keys matching the pattern:
26+
double_blocks.{block_idx}.processor.{qkv|proj}_lora{1|2}.{down|up}.weight
27+
28+
Where:
29+
- lora1 corresponds to the image attention stream (img_attn)
30+
- lora2 corresponds to the text attention stream (txt_attn)
31+
"""
32+
if not state_dict:
33+
return False
34+
35+
# Check that all keys match the xlabs pattern
36+
for key in state_dict.keys():
37+
if not isinstance(key, str):
38+
continue
39+
if not re.match(FLUX_XLABS_KEY_REGEX, key):
40+
return False
41+
42+
# Ensure we have at least some valid keys
43+
return any(isinstance(k, str) and re.match(FLUX_XLABS_KEY_REGEX, k) for k in state_dict.keys())
44+
45+
46+
def lora_model_from_flux_xlabs_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
    """Converts an xlabs FLUX LoRA state dict to the InvokeAI ModelPatchRaw format.

    The xlabs format uses:
    - lora1 for image attention stream (img_attn)
    - lora2 for text attention stream (txt_attn)
    - qkv for query/key/value projection
    - proj for output projection

    Key mapping:
    - double_blocks.X.processor.qkv_lora1  -> double_blocks.X.img_attn.qkv
    - double_blocks.X.processor.proj_lora1 -> double_blocks.X.img_attn.proj
    - double_blocks.X.processor.qkv_lora2  -> double_blocks.X.txt_attn.qkv
    - double_blocks.X.processor.proj_lora2 -> double_blocks.X.txt_attn.proj

    Args:
        state_dict: An xlabs-format FLUX LoRA state dict.

    Returns:
        The converted ModelPatchRaw.

    Raises:
        ValueError: If any key does not conform to the xlabs key pattern.
    """
    # Group each layer's down/up weight pair under a single InvokeAI-style layer key
    # (i.e. the key without the ".down.weight"/".up.weight" suffix, remapped per the table above).
    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}

    for key, value in state_dict.items():
        # fullmatch (not match) so that a key with trailing junk is rejected with a clear error
        # instead of being silently converted and potentially overwriting a legitimate entry.
        match = re.fullmatch(FLUX_XLABS_KEY_REGEX, key)
        if not match:
            raise ValueError(f"Key '{key}' does not match the expected pattern for xlabs FLUX LoRA weights.")

        # Groups: block index, "qkv"|"proj", stream "1"|"2", direction "down"|"up".
        block_idx, component, lora_stream, direction = match.groups()

        # lora1 targets the image attention stream; lora2 targets the text attention stream.
        attn_type = "img_attn" if lora_stream == "1" else "txt_attn"
        layer_key = f"double_blocks.{block_idx}.{attn_type}.{component}"

        # Map xlabs down/up naming to the kohya-style lora_down/lora_up naming expected by
        # any_lora_layer_from_state_dict().
        grouped_state_dict.setdefault(layer_key, {})[f"lora_{direction}.weight"] = value

    # Build one LoRA layer patch per grouped layer, keyed with the FLUX transformer prefix.
    layers: dict[str, BaseLayerPatch] = {
        FLUX_LORA_TRANSFORMER_PREFIX + layer_key: any_lora_layer_from_state_dict(layer_state_dict)
        for layer_key, layer_state_dict in grouped_state_dict.items()
    }

    return ModelPatchRaw(layers=layers)

invokeai/backend/patches/lora_conversions/formats.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
1515
is_state_dict_likely_in_flux_onetrainer_format,
1616
)
17+
from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils import (
18+
is_state_dict_likely_in_flux_xlabs_format,
19+
)
1720

1821

1922
def flux_format_from_state_dict(
@@ -30,5 +33,7 @@ def flux_format_from_state_dict(
3033
return FluxLoRAFormat.Control
3134
elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict, metadata):
3235
return FluxLoRAFormat.AIToolkit
36+
elif is_state_dict_likely_in_flux_xlabs_format(state_dict):
37+
return FluxLoRAFormat.XLabs
3338
else:
3439
return None

invokeai/frontend/web/src/common/util/promptAST.test.ts

Lines changed: 73 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -7,72 +7,76 @@ describe('promptAST', () => {
77
it('should tokenize basic text', () => {
88
const tokens = tokenize('a cat');
99
expect(tokens).toEqual([
10-
{ type: 'word', value: 'a' },
11-
{ type: 'whitespace', value: ' ' },
12-
{ type: 'word', value: 'cat' },
10+
{ type: 'word', value: 'a', start: 0, end: 1 },
11+
{ type: 'whitespace', value: ' ', start: 1, end: 2 },
12+
{ type: 'word', value: 'cat', start: 2, end: 5 },
1313
]);
1414
});
1515

1616
it('should tokenize groups with parentheses', () => {
1717
const tokens = tokenize('(a cat)');
1818
expect(tokens).toEqual([
19-
{ type: 'lparen' },
20-
{ type: 'word', value: 'a' },
21-
{ type: 'whitespace', value: ' ' },
22-
{ type: 'word', value: 'cat' },
23-
{ type: 'rparen' },
19+
{ type: 'lparen', start: 0, end: 1 },
20+
{ type: 'word', value: 'a', start: 1, end: 2 },
21+
{ type: 'whitespace', value: ' ', start: 2, end: 3 },
22+
{ type: 'word', value: 'cat', start: 3, end: 6 },
23+
{ type: 'rparen', start: 6, end: 7 },
2424
]);
2525
});
2626

2727
it('should tokenize escaped parentheses', () => {
2828
const tokens = tokenize('\\(medium\\)');
2929
expect(tokens).toEqual([
30-
{ type: 'escaped_paren', value: '(' },
31-
{ type: 'word', value: 'medium' },
32-
{ type: 'escaped_paren', value: ')' },
30+
{ type: 'escaped_paren', value: '(', start: 0, end: 2 },
31+
{ type: 'word', value: 'medium', start: 2, end: 8 },
32+
{ type: 'escaped_paren', value: ')', start: 8, end: 10 },
3333
]);
3434
});
3535

3636
it('should tokenize mixed escaped and unescaped parentheses', () => {
3737
const tokens = tokenize('colored pencil \\(medium\\) (enhanced)');
3838
expect(tokens).toEqual([
39-
{ type: 'word', value: 'colored' },
40-
{ type: 'whitespace', value: ' ' },
41-
{ type: 'word', value: 'pencil' },
42-
{ type: 'whitespace', value: ' ' },
43-
{ type: 'escaped_paren', value: '(' },
44-
{ type: 'word', value: 'medium' },
45-
{ type: 'escaped_paren', value: ')' },
46-
{ type: 'whitespace', value: ' ' },
47-
{ type: 'lparen' },
48-
{ type: 'word', value: 'enhanced' },
49-
{ type: 'rparen' },
39+
{ type: 'word', value: 'colored', start: 0, end: 7 },
40+
{ type: 'whitespace', value: ' ', start: 7, end: 8 },
41+
{ type: 'word', value: 'pencil', start: 8, end: 14 },
42+
{ type: 'whitespace', value: ' ', start: 14, end: 15 },
43+
{ type: 'escaped_paren', value: '(', start: 15, end: 17 },
44+
{ type: 'word', value: 'medium', start: 17, end: 23 },
45+
{ type: 'escaped_paren', value: ')', start: 23, end: 25 },
46+
{ type: 'whitespace', value: ' ', start: 25, end: 26 },
47+
{ type: 'lparen', start: 26, end: 27 },
48+
{ type: 'word', value: 'enhanced', start: 27, end: 35 },
49+
{ type: 'rparen', start: 35, end: 36 },
5050
]);
5151
});
5252

5353
it('should tokenize groups with weights', () => {
5454
const tokens = tokenize('(a cat)1.2');
5555
expect(tokens).toEqual([
56-
{ type: 'lparen' },
57-
{ type: 'word', value: 'a' },
58-
{ type: 'whitespace', value: ' ' },
59-
{ type: 'word', value: 'cat' },
60-
{ type: 'rparen' },
61-
{ type: 'weight', value: 1.2 },
56+
{ type: 'lparen', start: 0, end: 1 },
57+
{ type: 'word', value: 'a', start: 1, end: 2 },
58+
{ type: 'whitespace', value: ' ', start: 2, end: 3 },
59+
{ type: 'word', value: 'cat', start: 3, end: 6 },
60+
{ type: 'rparen', start: 6, end: 7 },
61+
{ type: 'weight', value: 1.2, start: 7, end: 10 },
6262
]);
6363
});
6464

6565
it('should tokenize words with weights', () => {
6666
const tokens = tokenize('cat+');
6767
expect(tokens).toEqual([
68-
{ type: 'word', value: 'cat' },
69-
{ type: 'weight', value: '+' },
68+
{ type: 'word', value: 'cat', start: 0, end: 3 },
69+
{ type: 'weight', value: '+', start: 3, end: 4 },
7070
]);
7171
});
7272

7373
it('should tokenize embeddings', () => {
7474
const tokens = tokenize('<embedding_name>');
75-
expect(tokens).toEqual([{ type: 'lembed' }, { type: 'word', value: 'embedding_name' }, { type: 'rembed' }]);
75+
expect(tokens).toEqual([
76+
{ type: 'lembed', start: 0, end: 1 },
77+
{ type: 'word', value: 'embedding_name', start: 1, end: 15 },
78+
{ type: 'rembed', start: 15, end: 16 },
79+
]);
7680
});
7781
});
7882

@@ -81,9 +85,9 @@ describe('promptAST', () => {
8185
const tokens = tokenize('a cat');
8286
const ast = parseTokens(tokens);
8387
expect(ast).toEqual([
84-
{ type: 'word', text: 'a' },
85-
{ type: 'whitespace', value: ' ' },
86-
{ type: 'word', text: 'cat' },
88+
{ type: 'word', text: 'a', range: { start: 0, end: 1 }, attention: undefined },
89+
{ type: 'whitespace', value: ' ', range: { start: 1, end: 2 } },
90+
{ type: 'word', text: 'cat', range: { start: 2, end: 5 }, attention: undefined },
8791
]);
8892
});
8993

@@ -93,10 +97,12 @@ describe('promptAST', () => {
9397
expect(ast).toEqual([
9498
{
9599
type: 'group',
100+
range: { start: 0, end: 7 },
101+
attention: undefined,
96102
children: [
97-
{ type: 'word', text: 'a' },
98-
{ type: 'whitespace', value: ' ' },
99-
{ type: 'word', text: 'cat' },
103+
{ type: 'word', text: 'a', range: { start: 1, end: 2 }, attention: undefined },
104+
{ type: 'whitespace', value: ' ', range: { start: 2, end: 3 } },
105+
{ type: 'word', text: 'cat', range: { start: 3, end: 6 }, attention: undefined },
100106
],
101107
},
102108
]);
@@ -106,27 +112,29 @@ describe('promptAST', () => {
106112
const tokens = tokenize('\\(medium\\)');
107113
const ast = parseTokens(tokens);
108114
expect(ast).toEqual([
109-
{ type: 'escaped_paren', value: '(' },
110-
{ type: 'word', text: 'medium' },
111-
{ type: 'escaped_paren', value: ')' },
115+
{ type: 'escaped_paren', value: '(', range: { start: 0, end: 2 } },
116+
{ type: 'word', text: 'medium', range: { start: 2, end: 8 }, attention: undefined },
117+
{ type: 'escaped_paren', value: ')', range: { start: 8, end: 10 } },
112118
]);
113119
});
114120

115121
it('should parse mixed escaped and unescaped parentheses', () => {
116122
const tokens = tokenize('colored pencil \\(medium\\) (enhanced)');
117123
const ast = parseTokens(tokens);
118124
expect(ast).toEqual([
119-
{ type: 'word', text: 'colored' },
120-
{ type: 'whitespace', value: ' ' },
121-
{ type: 'word', text: 'pencil' },
122-
{ type: 'whitespace', value: ' ' },
123-
{ type: 'escaped_paren', value: '(' },
124-
{ type: 'word', text: 'medium' },
125-
{ type: 'escaped_paren', value: ')' },
126-
{ type: 'whitespace', value: ' ' },
125+
{ type: 'word', text: 'colored', range: { start: 0, end: 7 }, attention: undefined },
126+
{ type: 'whitespace', value: ' ', range: { start: 7, end: 8 } },
127+
{ type: 'word', text: 'pencil', range: { start: 8, end: 14 }, attention: undefined },
128+
{ type: 'whitespace', value: ' ', range: { start: 14, end: 15 } },
129+
{ type: 'escaped_paren', value: '(', range: { start: 15, end: 17 } },
130+
{ type: 'word', text: 'medium', range: { start: 17, end: 23 }, attention: undefined },
131+
{ type: 'escaped_paren', value: ')', range: { start: 23, end: 25 } },
132+
{ type: 'whitespace', value: ' ', range: { start: 25, end: 26 } },
127133
{
128134
type: 'group',
129-
children: [{ type: 'word', text: 'enhanced' }],
135+
range: { start: 26, end: 36 },
136+
attention: undefined,
137+
children: [{ type: 'word', text: 'enhanced', range: { start: 27, end: 35 }, attention: undefined }],
130138
},
131139
]);
132140
});
@@ -138,10 +146,11 @@ describe('promptAST', () => {
138146
{
139147
type: 'group',
140148
attention: 1.2,
149+
range: { start: 0, end: 10 },
141150
children: [
142-
{ type: 'word', text: 'a' },
143-
{ type: 'whitespace', value: ' ' },
144-
{ type: 'word', text: 'cat' },
151+
{ type: 'word', text: 'a', range: { start: 1, end: 2 }, attention: undefined },
152+
{ type: 'whitespace', value: ' ', range: { start: 2, end: 3 } },
153+
{ type: 'word', text: 'cat', range: { start: 3, end: 6 }, attention: undefined },
145154
],
146155
},
147156
]);
@@ -150,13 +159,13 @@ describe('promptAST', () => {
150159
it('should parse words with attention', () => {
151160
const tokens = tokenize('cat+');
152161
const ast = parseTokens(tokens);
153-
expect(ast).toEqual([{ type: 'word', text: 'cat', attention: '+' }]);
162+
expect(ast).toEqual([{ type: 'word', text: 'cat', attention: '+', range: { start: 0, end: 4 } }]);
154163
});
155164

156165
it('should parse embeddings', () => {
157166
const tokens = tokenize('<embedding_name>');
158167
const ast = parseTokens(tokens);
159-
expect(ast).toEqual([{ type: 'embedding', value: 'embedding_name' }]);
168+
expect(ast).toEqual([{ type: 'embedding', value: 'embedding_name', range: { start: 0, end: 16 } }]);
160169
});
161170
});
162171

@@ -243,19 +252,20 @@ describe('promptAST', () => {
243252

244253
// Should have escaped parens as nodes and a group with attention
245254
expect(ast).toEqual([
246-
{ type: 'word', text: 'portrait' },
247-
{ type: 'whitespace', value: ' ' },
248-
{ type: 'escaped_paren', value: '(' },
249-
{ type: 'word', text: 'realistic' },
250-
{ type: 'escaped_paren', value: ')' },
251-
{ type: 'whitespace', value: ' ' },
255+
{ type: 'word', text: 'portrait', range: { start: 0, end: 8 }, attention: undefined },
256+
{ type: 'whitespace', value: ' ', range: { start: 8, end: 9 } },
257+
{ type: 'escaped_paren', value: '(', range: { start: 9, end: 11 } },
258+
{ type: 'word', text: 'realistic', range: { start: 11, end: 20 }, attention: undefined },
259+
{ type: 'escaped_paren', value: ')', range: { start: 20, end: 22 } },
260+
{ type: 'whitespace', value: ' ', range: { start: 22, end: 23 } },
252261
{
253262
type: 'group',
254263
attention: 1.2,
264+
range: { start: 23, end: 40 },
255265
children: [
256-
{ type: 'word', text: 'high' },
257-
{ type: 'whitespace', value: ' ' },
258-
{ type: 'word', text: 'quality' },
266+
{ type: 'word', text: 'high', range: { start: 24, end: 28 }, attention: undefined },
267+
{ type: 'whitespace', value: ' ', range: { start: 28, end: 29 } },
268+
{ type: 'word', text: 'quality', range: { start: 29, end: 36 }, attention: undefined },
259269
],
260270
},
261271
]);

0 commit comments

Comments
 (0)