Skip to content

Commit 01100a2

Browse files
fix(ui): check for ref image config compatibility for flux kontext dev
1 parent ce2e6d8 commit 01100a2

File tree

1 file changed

+25
-11
lines changed

1 file changed

+25
-11
lines changed

invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ import { getPrefixedId } from 'features/controlLayers/konva/util';
33
import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
44
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
55
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
6+
import { isFluxKontextReferenceImageConfig } from 'features/controlLayers/store/types';
7+
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
68
import { addFLUXFill } from 'features/nodes/util/graph/generation/addFLUXFill';
79
import { addFLUXLoRAs } from 'features/nodes/util/graph/generation/addFLUXLoRAs';
810
import { addFLUXReduxes } from 'features/nodes/util/graph/generation/addFLUXRedux';
@@ -88,10 +90,9 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
8890
}
8991

9092
const isFluxKontextDev = model.name?.toLowerCase().includes('kontext');
91-
const fluxKontextDevConditioning = refImages.entities[0]?.config.image?.image_name;
9293
if (isFluxKontextDev) {
9394
if (generationMode !== 'txt2img') {
94-
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'FLUX Kontext' }));
95+
throw new UnsupportedGenerationModeError(t('toast.fluxKontextIncompatibleGenerationMode'));
9596
}
9697

9798
guidance = 30;
@@ -136,15 +137,28 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
136137
id: getPrefixedId('flux_vae_decode'),
137138
});
138139

139-
if (isFluxKontextDev && fluxKontextDevConditioning) {
140-
const kontextConditioning = g.addNode({
141-
type: 'flux_kontext',
142-
id: getPrefixedId('flux_kontext'),
143-
image: {
144-
image_name: fluxKontextDevConditioning,
145-
},
146-
});
147-
g.addEdge(kontextConditioning, 'kontext_cond', denoise, 'kontext_conditioning');
140+
if (isFluxKontextDev) {
141+
const validFLUXKontextConfigs = selectRefImagesSlice(state)
142+
.entities.filter((entity) => entity.isEnabled)
143+
.filter((entity) => isFluxKontextReferenceImageConfig(entity.config))
144+
.filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0);
145+
146+
// FLUX Kontext supports only a single conditioning image - we'll just take the first one.
147+
// In the future, we can explore concatenating multiple conditioning images in image or latent space.
148+
const firstValidFLUXKontextConfig = validFLUXKontextConfigs[0];
149+
150+
if (firstValidFLUXKontextConfig) {
151+
const { image } = firstValidFLUXKontextConfig.config;
152+
153+
assert(image, 'getGlobalReferenceImageWarnings checks if the image is there, this should never raise');
154+
155+
const kontextConditioning = g.addNode({
156+
type: 'flux_kontext',
157+
id: getPrefixedId('flux_kontext'),
158+
image,
159+
});
160+
g.addEdge(kontextConditioning, 'kontext_cond', denoise, 'kontext_conditioning');
161+
}
148162
}
149163

150164
g.addEdge(modelLoader, 'transformer', denoise, 'transformer');

0 commit comments

Comments
 (0)