@@ -15,7 +15,7 @@ import {
1515} from 'features/nodes/util/graph/constants' ;
1616import { isVectorMaskLayer } from 'features/regionalPrompts/store/regionalPromptsSlice' ;
1717import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs' ;
18- import { size , sumBy } from 'lodash-es' ;
18+ import { size } from 'lodash-es' ;
1919import { imagesApi } from 'services/api/endpoints/images' ;
2020import type { CollectInvocation , Edge , IPAdapterInvocation , NonNullableGraph , S } from 'services/api/types' ;
2121import { assert } from 'tsafe' ;
@@ -39,6 +39,16 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
3939 return hasTextPrompt || hasIPAdapter ;
4040 } ) ;
4141
42+ const regionalIPAdapters = selectAllIPAdapters ( state . controlAdapters ) . filter (
43+ ( { id, model, controlImage, isEnabled } ) => {
44+ const hasModel = Boolean ( model ) ;
45+ const doesBaseMatch = model ?. base === state . generation . model ?. base ;
46+ const hasControlImage = controlImage ;
47+ const isRegional = layers . some ( ( l ) => l . ipAdapterIds . includes ( id ) ) ;
48+ return isEnabled && hasModel && doesBaseMatch && hasControlImage && isRegional ;
49+ }
50+ ) ;
51+
4252 const layerIds = layers . map ( ( l ) => l . id ) ;
4353 const blobs = await getRegionalPromptLayerBlobs ( layerIds ) ;
4454 assert ( size ( blobs ) === size ( layerIds ) , 'Mismatch between layer IDs and blobs' ) ;
@@ -105,7 +115,7 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
105115 } ,
106116 } ) ;
107117
108- if ( ! graph . nodes [ IP_ADAPTER_COLLECT ] && sumBy ( layers , ( l ) => l . ipAdapterIds . length ) > 0 ) {
118+ if ( ! graph . nodes [ IP_ADAPTER_COLLECT ] && regionalIPAdapters . length > 0 ) {
109119 const ipAdapterCollectNode : CollectInvocation = {
110120 id : IP_ADAPTER_COLLECT ,
111121 type : 'collect' ,
@@ -284,8 +294,16 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
284294 }
285295
286296 for ( const ipAdapterId of layer . ipAdapterIds ) {
287- const ipAdapter = selectAllIPAdapters ( state . controlAdapters ) . find ( ( ca ) => ca . id === ipAdapterId ) ;
288- console . log ( ipAdapter ) ;
297+ const ipAdapter = selectAllIPAdapters ( state . controlAdapters )
298+ . filter ( ( { id, model, controlImage, isEnabled } ) => {
299+ const hasModel = Boolean ( model ) ;
300+ const doesBaseMatch = model ?. base === state . generation . model ?. base ;
301+ const hasControlImage = controlImage ;
302+ const isRegional = layers . some ( ( l ) => l . ipAdapterIds . includes ( id ) ) ;
303+ return isEnabled && hasModel && doesBaseMatch && hasControlImage && isRegional ;
304+ } )
305+ . find ( ( ca ) => ca . id === ipAdapterId ) ;
306+
289307 if ( ! ipAdapter ?. model ) {
290308 return ;
291309 }
0 commit comments