Skip to content

Commit 23b0d2e

Browse files
authored
Add front end support for type matching (#6582)
This PR implements front end logic to handle MatchType inputs and outputs. See comfyanonymous/ComfyUI#10644 This allows for the implementation of nodes such as a "switch node" where input types change based on the connections made. ![switch-node](https://github.com/user-attachments/assets/090515ba-484c-4295-b7b3-204b0c72fc4a) As part of this implementation, significant cleanup is being performed in the reroute code. Extra testing will be required to make sure these changes don't introduce regressions. ┆Issue is synchronized with this [Notion page](https://www.notion.so/PR-6582-Add-front-end-support-for-type-matching-2a16d73d36508189b042cd23f82a332e) by [Unito](https://www.unito.io)
1 parent cfbd536 commit 23b0d2e

File tree

9 files changed

+364
-216
lines changed

9 files changed

+364
-216
lines changed

src/extensions/core/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import './groupNodeManage'
1010
import './groupOptions'
1111
import './load3d'
1212
import './maskeditor'
13+
import './matchType'
1314
import './nodeTemplates'
1415
import './noteNode'
1516
import './previewAny'

src/extensions/core/matchType.ts

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
import { without } from 'es-toolkit'
2+
3+
import { useChainCallback } from '@/composables/functional/useChainCallback'
4+
import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
5+
import { LiteGraph } from '@/lib/litegraph/src/litegraph'
6+
import type { LLink } from '@/lib/litegraph/src/LLink'
7+
import type { ISlotType } from '@/lib/litegraph/src/interfaces'
8+
import { app } from '@/scripts/app'
9+
10+
const MATCH_TYPE = 'COMFY_MATCHTYPE_V3'
11+
12+
app.registerExtension({
13+
name: 'Comfy.MatchType',
14+
beforeRegisterNodeDef(nodeType, nodeData) {
15+
const inputs = {
16+
...nodeData.input?.required,
17+
...nodeData.input?.optional
18+
}
19+
if (!Object.values(inputs).some((w) => w[0] === MATCH_TYPE)) return
20+
nodeType.prototype.onNodeCreated = useChainCallback(
21+
nodeType.prototype.onNodeCreated,
22+
function (this: LGraphNode) {
23+
const inputGroups: Record<string, [string, ISlotType][]> = {}
24+
const outputGroups: Record<string, number[]> = {}
25+
for (const input of this.inputs) {
26+
if (input.type !== MATCH_TYPE) continue
27+
const template = inputs[input.name][1]?.template
28+
if (!template) continue
29+
input.type = template.allowed_types ?? '*'
30+
inputGroups[template.template_id] ??= []
31+
inputGroups[template.template_id].push([input.name, input.type])
32+
}
33+
this.outputs.forEach((output, i) => {
34+
if (output.type !== MATCH_TYPE) return
35+
const id = nodeData.output_matchtypes?.[i]
36+
if (id == undefined) return
37+
outputGroups[id] ??= []
38+
outputGroups[id].push(i)
39+
})
40+
for (const groupId in inputGroups) {
41+
addConnectionGroup(this, inputGroups[groupId], outputGroups[groupId])
42+
}
43+
}
44+
)
45+
}
46+
})
47+
function addConnectionGroup(
48+
node: LGraphNode,
49+
inputPairs: [string, ISlotType][],
50+
outputs?: number[]
51+
) {
52+
const connectedTypes: ISlotType[] = new Array(inputPairs.length).fill('*')
53+
node.onConnectionsChange = useChainCallback(
54+
node.onConnectionsChange,
55+
function (
56+
this: LGraphNode,
57+
contype: ISlotType,
58+
slot: number,
59+
iscon: boolean,
60+
linf: LLink | null | undefined
61+
) {
62+
const input = this.inputs[slot]
63+
if (contype !== LiteGraph.INPUT || !this.graph || !input) return
64+
const pairIndex = inputPairs.findIndex(([name]) => name === input.name)
65+
if (pairIndex == -1) return
66+
connectedTypes[pairIndex] = inputPairs[pairIndex][1]
67+
if (iscon && linf) {
68+
const { output, subgraphInput } = linf.resolve(this.graph)
69+
const connectingType = (output ?? subgraphInput)?.type
70+
if (connectingType)
71+
linf.type = connectedTypes[pairIndex] = connectingType
72+
}
73+
//An input slot can accept a connection that is
74+
// - Compatible with original type
75+
// - Compatible with all other input types
76+
//An output slot can output
77+
// - Only what every input can output
78+
for (let i = 0; i < inputPairs.length; i++) {
79+
//NOTE: This isn't great. Originally, I kept direct references to each
80+
//input, but these were becoming orphaned
81+
const input = this.inputs.find((inp) => inp.name === inputPairs[i][0])
82+
if (!input) continue
83+
const otherConnected = [...connectedTypes]
84+
otherConnected.splice(i, 1)
85+
const validType = combineTypes(...otherConnected, inputPairs[i][1])
86+
if (!validType) throw new Error('invalid connection')
87+
input.type = validType
88+
}
89+
if (outputs) {
90+
const outputType = combineTypes(...connectedTypes)
91+
if (!outputType) throw new Error('invalid connection')
92+
changeOutputType(this, outputType, outputs)
93+
}
94+
}
95+
)
96+
}
97+
98+
function changeOutputType(
99+
node: LGraphNode,
100+
combinedType: ISlotType,
101+
outputs: number[]
102+
) {
103+
if (!node.graph) return
104+
for (const index of outputs) {
105+
if (node.outputs[index].type === combinedType) continue
106+
node.outputs[index].type = combinedType
107+
108+
//check and potentially remove links
109+
for (let link_id of node.outputs[index].links ?? []) {
110+
let link = node.graph.links[link_id]
111+
if (!link) continue
112+
const { input, inputNode, subgraphOutput } = link.resolve(node.graph)
113+
const inputType = (input ?? subgraphOutput)?.type
114+
if (!inputType) continue
115+
const keep = LiteGraph.isValidConnection(combinedType, inputType)
116+
if (!keep && subgraphOutput) subgraphOutput.disconnect()
117+
else if (!keep && inputNode) inputNode.disconnectInput(link.target_slot)
118+
if (input && inputNode?.onConnectionsChange)
119+
inputNode.onConnectionsChange(
120+
LiteGraph.INPUT,
121+
link.target_slot,
122+
keep,
123+
link,
124+
input
125+
)
126+
}
127+
app.canvas.setDirty(true, true)
128+
}
129+
}
130+
function isStrings(types: ISlotType[]): types is string[] {
131+
return !types.some((t) => typeof t !== 'string')
132+
}
133+
134+
function combineTypes(...types: ISlotType[]): ISlotType | undefined {
135+
if (!isStrings(types)) return undefined
136+
137+
const withoutWildcards = without(types, '*')
138+
if (withoutWildcards.length === 0) return '*'
139+
140+
const typeLists: string[][] = withoutWildcards.map((type) => type.split(','))
141+
142+
const combinedTypes = intersection(...typeLists)
143+
if (combinedTypes.length === 0) return undefined
144+
145+
return combinedTypes.join(',')
146+
}
147+
function intersection(...sets: string[][]): string[] {
148+
const itemCounts: Record<string, number> = {}
149+
for (const set of sets)
150+
for (const item of new Set(set))
151+
itemCounts[item] = (itemCounts[item] ?? 0) + 1
152+
return Object.entries(itemCounts)
153+
.filter(([, count]) => count == sets.length)
154+
.map(([key]) => key)
155+
}

0 commit comments

Comments
 (0)