Skip to content

Commit 5b69403

Browse files
authored
Merge branch 'main' into copilot/add-unload-model-option
2 parents 83deb02 + ac245cb commit 5b69403

File tree

24 files changed

+762
-109
lines changed

24 files changed

+762
-109
lines changed

invokeai/app/invocations/fields.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,12 @@ class BoardField(BaseModel):
243243
board_id: str = Field(description="The id of the board")
244244

245245

246+
class StylePresetField(BaseModel):
    """Primitive field identifying a stored style preset by its id."""

    style_preset_id: str = Field(description="The id of the style preset")
250+
251+
246252
class DenoiseMaskField(BaseModel):
247253
"""An inpaint mask field"""
248254

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
2+
from invokeai.app.invocations.fields import InputField, OutputField, StylePresetField, UIComponent
3+
from invokeai.app.services.shared.invocation_context import InvocationContext
4+
5+
6+
@invocation_output("prompt_template_output")
class PromptTemplateOutput(BaseInvocationOutput):
    """Output of the Prompt Template node: the rendered positive and negative prompts."""

    positive_prompt: str = OutputField(description="The positive prompt with the template applied")
    negative_prompt: str = OutputField(description="The negative prompt with the template applied")
12+
13+
14+
@invocation(
    "prompt_template",
    title="Prompt Template",
    tags=["prompt", "template", "style", "preset"],
    category="prompt",
    version="1.0.1",
)
class PromptTemplateInvocation(BaseInvocation):
    """Applies a Style Preset template to positive and negative prompts.

    Select a Style Preset and provide positive/negative prompts. The node replaces
    {prompt} placeholders in the template with your input prompts. If a template
    has no {prompt} placeholder, the input prompt is prepended to the template
    instead of being silently discarded.
    """

    style_preset: StylePresetField = InputField(
        description="The Style Preset to use as a template",
    )
    positive_prompt: str = InputField(
        default="",
        description="The positive prompt to insert into the template's {prompt} placeholder",
        ui_component=UIComponent.Textarea,
    )
    negative_prompt: str = InputField(
        default="",
        description="The negative prompt to insert into the template's {prompt} placeholder",
        ui_component=UIComponent.Textarea,
    )

    @staticmethod
    def _render(template: str, prompt: str) -> str:
        """Insert `prompt` into `template`'s {prompt} placeholder.

        Bug fix vs. 1.0.0: if the template contains no {prompt} placeholder, the
        user's prompt is no longer dropped — the prompt and template are joined
        with a space (empty parts omitted), mirroring how style presets are
        applied elsewhere in the app.
        """
        if "{prompt}" in template:
            return template.replace("{prompt}", prompt)
        return " ".join(part for part in (prompt, template) if part)

    def invoke(self, context: InvocationContext) -> PromptTemplateOutput:
        # Fetch the style preset record from the database.
        # NOTE(review): this reaches into the private `_services` attribute of
        # InvocationContext; prefer a public accessor if one exists — confirm.
        style_preset = context._services.style_preset_records.get(self.style_preset.style_preset_id)

        # Render both prompts against their respective templates.
        rendered_positive = self._render(style_preset.preset_data.positive_prompt, self.positive_prompt)
        rendered_negative = self._render(style_preset.preset_data.negative_prompt, self.negative_prompt)

        return PromptTemplateOutput(
            positive_prompt=rendered_positive,
            negative_prompt=rendered_negative,
        )

invokeai/backend/model_manager/configs/lora.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,16 @@ def _validate_base(cls, mod: ModelOnDisk) -> None:
150150

151151
@classmethod
152152
def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
153-
# First rule out ControlLoRA and Diffusers LoRA
153+
# First rule out ControlLoRA
154154
flux_format = _get_flux_lora_format(mod)
155155
if flux_format in [FluxLoRAFormat.Control]:
156156
raise NotAMatchError("model looks like Control LoRA")
157157

158+
# If it's a recognized Flux LoRA format (Kohya, Diffusers, OneTrainer, AIToolkit, XLabs, etc.),
159+
# it's valid and we skip the heuristic check
160+
if flux_format is not None:
161+
return
162+
158163
# Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
159164
# Some main models have these keys, likely due to the creator merging in a LoRA.
160165
has_key_with_lora_prefix = state_dict_has_any_keys_starting_with(

invokeai/backend/model_manager/load/model_loaders/lora.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@
4141
is_state_dict_likely_in_flux_onetrainer_format,
4242
lora_model_from_flux_onetrainer_state_dict,
4343
)
44+
from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils import (
45+
is_state_dict_likely_in_flux_xlabs_format,
46+
lora_model_from_flux_xlabs_state_dict,
47+
)
4448
from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
4549
from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
4650
from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import lora_model_from_z_image_state_dict
@@ -118,6 +122,8 @@ def _load_model(
118122
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
119123
elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict):
120124
model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
125+
elif is_state_dict_likely_in_flux_xlabs_format(state_dict=state_dict):
126+
model = lora_model_from_flux_xlabs_state_dict(state_dict=state_dict)
121127
else:
122128
raise ValueError("LoRA model is in unsupported FLUX format")
123129
else:

invokeai/backend/model_manager/taxonomy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ class FluxLoRAFormat(str, Enum):
171171
OneTrainer = "flux.onetrainer"
172172
Control = "flux.control"
173173
AIToolkit = "flux.aitoolkit"
174+
XLabs = "flux.xlabs"
174175

175176

176177
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, FluxVariantType]
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import re
2+
from typing import Any, Dict
3+
4+
import torch
5+
6+
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
7+
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
8+
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
9+
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
10+
11+
# A regex pattern that matches all of the transformer keys in the xlabs FLUX LoRA format.
12+
# Example keys:
13+
# double_blocks.0.processor.qkv_lora1.down.weight
14+
# double_blocks.0.processor.qkv_lora1.up.weight
15+
# double_blocks.0.processor.proj_lora1.down.weight
16+
# double_blocks.0.processor.proj_lora1.up.weight
17+
# double_blocks.0.processor.qkv_lora2.down.weight
18+
# double_blocks.0.processor.proj_lora2.up.weight
19+
FLUX_XLABS_KEY_REGEX = r"double_blocks\.(\d+)\.processor\.(qkv|proj)_lora([12])\.(down|up)\.weight"
20+
21+
22+
def is_state_dict_likely_in_flux_xlabs_format(state_dict: dict[str | int, Any]) -> bool:
23+
"""Checks if the provided state dict is likely in the xlabs FLUX LoRA format.
24+
25+
The xlabs format is characterized by keys matching the pattern:
26+
double_blocks.{block_idx}.processor.{qkv|proj}_lora{1|2}.{down|up}.weight
27+
28+
Where:
29+
- lora1 corresponds to the image attention stream (img_attn)
30+
- lora2 corresponds to the text attention stream (txt_attn)
31+
"""
32+
if not state_dict:
33+
return False
34+
35+
# Check that all keys match the xlabs pattern
36+
for key in state_dict.keys():
37+
if not isinstance(key, str):
38+
continue
39+
if not re.match(FLUX_XLABS_KEY_REGEX, key):
40+
return False
41+
42+
# Ensure we have at least some valid keys
43+
return any(isinstance(k, str) and re.match(FLUX_XLABS_KEY_REGEX, k) for k in state_dict.keys())
44+
45+
46+
def lora_model_from_flux_xlabs_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
    """Converts an xlabs FLUX LoRA state dict to the InvokeAI ModelPatchRaw format.

    The xlabs format uses:
    - lora1 for image attention stream (img_attn)
    - lora2 for text attention stream (txt_attn)
    - qkv for query/key/value projection
    - proj for output projection

    Key mapping:
    - double_blocks.X.processor.qkv_lora1  -> double_blocks.X.img_attn.qkv
    - double_blocks.X.processor.proj_lora1 -> double_blocks.X.img_attn.proj
    - double_blocks.X.processor.qkv_lora2  -> double_blocks.X.txt_attn.qkv
    - double_blocks.X.processor.proj_lora2 -> double_blocks.X.txt_attn.proj
    """
    # lora stream "1" targets the image attention, "2" the text attention.
    stream_to_attn = {"1": "img_attn", "2": "txt_attn"}

    # Collect the down/up tensor pair for each target layer, keyed by the
    # InvokeAI-style layer name (i.e. without the .down.weight/.up.weight suffix).
    grouped: dict[str, dict[str, torch.Tensor]] = {}
    for key, tensor in state_dict.items():
        match = re.match(FLUX_XLABS_KEY_REGEX, key)
        if match is None:
            raise ValueError(f"Key '{key}' does not match the expected pattern for xlabs FLUX LoRA weights.")

        block_idx, component, stream, direction = match.groups()
        layer_key = f"double_blocks.{block_idx}.{stream_to_attn[stream]}.{component}"
        # down/up become lora_down.weight / lora_up.weight respectively.
        grouped.setdefault(layer_key, {})[f"lora_{direction}.weight"] = tensor

    # Build one LoRA layer patch per grouped layer, namespaced under the FLUX
    # transformer prefix.
    layers: dict[str, BaseLayerPatch] = {
        FLUX_LORA_TRANSFORMER_PREFIX + layer_key: any_lora_layer_from_state_dict(layer_sd)
        for layer_key, layer_sd in grouped.items()
    }
    return ModelPatchRaw(layers=layers)

invokeai/backend/patches/lora_conversions/formats.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
1515
is_state_dict_likely_in_flux_onetrainer_format,
1616
)
17+
from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils import (
18+
is_state_dict_likely_in_flux_xlabs_format,
19+
)
1720

1821

1922
def flux_format_from_state_dict(
@@ -30,5 +33,7 @@ def flux_format_from_state_dict(
3033
return FluxLoRAFormat.Control
3134
elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict, metadata):
3235
return FluxLoRAFormat.AIToolkit
36+
elif is_state_dict_likely_in_flux_xlabs_format(state_dict):
37+
return FluxLoRAFormat.XLabs
3338
else:
3439
return None

invokeai/frontend/web/public/locales/en.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1985,6 +1985,7 @@
19851985
"allLoaded": "All Workflows Loaded",
19861986
"searchPlaceholder": "Search by name, description or tags",
19871987
"filterByTags": "Filter by Tags",
1988+
"tags": "Tags",
19881989
"yourWorkflows": "Your Workflows",
19891990
"recentlyOpened": "Recently Opened",
19901991
"noRecentWorkflows": "No Recent Workflows",
@@ -2676,7 +2677,9 @@
26762677
"useForTemplate": "Use For Prompt Template",
26772678
"viewList": "View Template List",
26782679
"viewModeTooltip": "This is how your prompt will look with your currently selected template. To edit your prompt, click anywhere in the text box.",
2679-
"togglePromptPreviews": "Toggle Prompt Previews"
2680+
"togglePromptPreviews": "Toggle Prompt Previews",
2681+
"selectPreset": "Select Style Preset",
2682+
"noMatchingPresets": "No matching presets"
26802683
},
26812684

26822685
"ui": {

invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ import {
5555
isStringFieldInputTemplate,
5656
isStringGeneratorFieldInputInstance,
5757
isStringGeneratorFieldInputTemplate,
58+
isStylePresetFieldInputInstance,
59+
isStylePresetFieldInputTemplate,
5860
} from 'features/nodes/types/field';
5961
import type { NodeFieldElement } from 'features/nodes/types/workflow';
6062
import { memo } from 'react';
@@ -67,6 +69,7 @@ import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
6769
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
6870
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
6971
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
72+
import StylePresetFieldInputComponent from './inputs/StylePresetFieldInputComponent';
7073

7174
type Props = {
7275
nodeId: string;
@@ -206,6 +209,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
206209
return <BoardFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
207210
}
208211

212+
if (isStylePresetFieldInputTemplate(template)) {
213+
if (!isStylePresetFieldInputInstance(field)) {
214+
return null;
215+
}
216+
return <StylePresetFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
217+
}
218+
209219
if (isModelIdentifierFieldInputTemplate(template)) {
210220
if (!isModelIdentifierFieldInputInstance(field)) {
211221
return null;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
2+
import { Combobox } from '@invoke-ai/ui-library';
3+
import { useAppDispatch } from 'app/store/storeHooks';
4+
import { fieldStylePresetValueChanged } from 'features/nodes/store/nodesSlice';
5+
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
6+
import type { StylePresetFieldInputInstance, StylePresetFieldInputTemplate } from 'features/nodes/types/field';
7+
import { memo, useCallback, useMemo } from 'react';
8+
import { useTranslation } from 'react-i18next';
9+
import { useListStylePresetsQuery } from 'services/api/endpoints/stylePresets';
10+
11+
import type { FieldComponentProps } from './types';
12+
13+
const StylePresetFieldInputComponent = (
14+
props: FieldComponentProps<StylePresetFieldInputInstance, StylePresetFieldInputTemplate>
15+
) => {
16+
const { nodeId, field } = props;
17+
const dispatch = useAppDispatch();
18+
const { t } = useTranslation();
19+
const { data: stylePresets, isLoading } = useListStylePresetsQuery();
20+
21+
const options = useMemo<ComboboxOption[]>(() => {
22+
const _options: ComboboxOption[] = [];
23+
if (stylePresets) {
24+
for (const preset of stylePresets) {
25+
_options.push({
26+
label: preset.name,
27+
value: preset.id,
28+
});
29+
}
30+
}
31+
return _options;
32+
}, [stylePresets]);
33+
34+
const onChange = useCallback<ComboboxOnChange>(
35+
(v) => {
36+
if (!v) {
37+
return;
38+
}
39+
40+
dispatch(
41+
fieldStylePresetValueChanged({
42+
nodeId,
43+
fieldName: field.name,
44+
value: { style_preset_id: v.value },
45+
})
46+
);
47+
},
48+
[dispatch, field.name, nodeId]
49+
);
50+
51+
const value = useMemo(() => {
52+
const _value = field.value;
53+
if (!_value) {
54+
return null;
55+
}
56+
return options.find((o) => o.value === _value.style_preset_id) ?? null;
57+
}, [field.value, options]);
58+
59+
const noOptionsMessage = useCallback(() => t('stylePresets.noMatchingPresets'), [t]);
60+
61+
return (
62+
<Combobox
63+
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
64+
value={value}
65+
options={options}
66+
onChange={onChange}
67+
placeholder={isLoading ? t('common.loading') : t('stylePresets.selectPreset')}
68+
noOptionsMessage={noOptionsMessage}
69+
/>
70+
);
71+
};
72+
73+
export default memo(StylePresetFieldInputComponent);

0 commit comments

Comments
 (0)