@@ -9,6 +9,7 @@ import type {
9
9
Dimensions ,
10
10
Edge ,
11
11
EdgeMarkerType ,
12
+ Element ,
12
13
ElementData ,
13
14
Elements ,
14
15
FlowElements ,
@@ -151,24 +152,63 @@ export function parseEdge(edge: Edge, defaults: Partial<GraphEdge> = {}): GraphE
151
152
return Object . assign ( { } , defaults , edge , { id : edge . id . toString ( ) } ) as GraphEdge
152
153
}
153
154
154
- function getConnectedElements < T extends Elements = FlowElements > (
155
+ function getConnectedElements < T extends Node = Node > (
155
156
nodeOrId : Node | { id : string } | string ,
156
- elements : T ,
157
+ nodes : T [ ] ,
158
+ edges : Edge [ ] ,
157
159
dir : 'source' | 'target' ,
158
- ) : T extends FlowElements ? GraphNode [ ] : Node [ ] {
160
+ ) : T [ ] {
159
161
const id = isString ( nodeOrId ) ? nodeOrId : nodeOrId . id
160
162
163
+ const connectedIds = new Set ( )
164
+
161
165
const origin = dir === 'source' ? 'target' : 'source'
162
- const ids = elements . filter ( ( e ) => isEdge ( e ) && e [ origin ] === id ) . map ( ( e ) => isEdge ( e ) && e [ dir ] )
163
166
164
- return elements . filter ( ( e ) => ids . includes ( e . id ) ) as T extends FlowElements ? GraphNode [ ] : Node [ ]
167
+ edges . forEach ( ( edge ) => {
168
+ if ( edge [ origin ] === id ) {
169
+ connectedIds . add ( edge [ dir ] )
170
+ }
171
+ } )
172
+
173
+ return nodes . filter ( ( n ) => connectedIds . has ( n . id ) )
165
174
}
166
- export function getOutgoers < T extends Elements = FlowElements > ( nodeOrId : Node | { id : string } | string , elements : T ) {
167
- return getConnectedElements ( nodeOrId , elements , 'target' )
175
+
176
+ export function getOutgoers < N extends Node > ( nodeOrId : Node | { id : string } | string , nodes : N [ ] , edges : Edge [ ] ) : N [ ]
177
+ export function getOutgoers < T extends Elements > (
178
+ nodeOrId : Node | { id : string } | string ,
179
+ elements : T ,
180
+ ) : T extends FlowElements ? GraphNode [ ] : Node [ ]
181
+ export function getOutgoers ( ...args : any [ ] ) {
182
+ if ( args . length === 3 ) {
183
+ const [ nodeOrId , nodes , edges ] = args
184
+ return getConnectedElements ( nodeOrId , nodes , edges , 'target' )
185
+ }
186
+
187
+ const [ nodeOrId , elements ] = args
188
+ const node : Node = isString ( nodeOrId ) ? { id : nodeOrId } : nodeOrId
189
+
190
+ const outgoers = elements . filter ( ( el : Element ) => isEdge ( el ) && el . source === node . id )
191
+
192
+ return outgoers . map ( ( edge : Edge ) => elements . find ( ( el : Element ) => isNode ( el ) && el . id === edge . target ) )
168
193
}
169
194
170
- export function getIncomers < T extends Elements = FlowElements > ( nodeOrId : Node | { id : string } | string , elements : T ) {
171
- return getConnectedElements ( nodeOrId , elements , 'source' )
195
+ export function getIncomers < N extends Node > ( nodeOrId : Node | { id : string } | string , nodes : N [ ] , edges : Edge [ ] ) : N [ ]
196
+ export function getIncomers < T extends Elements > (
197
+ nodeOrId : Node | { id : string } | string ,
198
+ elements : T ,
199
+ ) : T extends FlowElements ? GraphNode [ ] : Node [ ]
200
+ export function getIncomers ( ...args : any [ ] ) {
201
+ if ( args . length === 3 ) {
202
+ const [ nodeOrId , nodes , edges ] = args
203
+ return getConnectedElements ( nodeOrId , nodes , edges , 'source' )
204
+ }
205
+
206
+ const [ nodeOrId , elements ] = args
207
+ const node : Node = isString ( nodeOrId ) ? { id : nodeOrId } : nodeOrId
208
+
209
+ const incomers = elements . filter ( ( el : Element ) => isEdge ( el ) && el . target === node . id )
210
+
211
+ return incomers . map ( ( edge : Edge ) => elements . find ( ( el : Element ) => isNode ( el ) && el . id === edge . source ) )
172
212
}
173
213
174
214
export function getEdgeId ( { source, sourceHandle, target, targetHandle } : Connection ) {
@@ -364,26 +404,34 @@ export function getNodesInside(
364
404
} )
365
405
}
366
406
367
- export function getConnectedEdges < N extends Node | { id : string } | string , E extends Edge > ( nodes : N [ ] , edges : E [ ] ) {
368
- const nodeIds = nodes . map ( ( node ) => ( isString ( node ) ? node : node . id ) )
407
+ export function getConnectedEdges < E extends Edge > ( nodesOrId : Node [ ] | string , edges : E [ ] ) {
408
+ const nodeIds = new Set ( )
369
409
370
- return edges . filter ( ( edge ) => nodeIds . includes ( edge . source ) || nodeIds . includes ( edge . target ) )
410
+ if ( isString ( nodesOrId ) ) {
411
+ nodeIds . add ( nodesOrId )
412
+ } else if ( nodesOrId . length >= 1 ) {
413
+ nodesOrId . forEach ( ( n ) => nodeIds . add ( n . id ) )
414
+ }
415
+
416
+ return edges . filter ( ( edge ) => nodeIds . has ( edge . source ) || nodeIds . has ( edge . target ) )
371
417
}
372
418
373
- export function getConnectedNodes < N extends Node | { id : string } | string , E extends Edge > ( nodes : N [ ] , edges : E [ ] ) {
374
- const nodeIds = nodes . map ( ( node ) => ( isString ( node ) ? node : node . id ) )
419
+ export function getConnectedNodes < N extends Node | { id : string } | string > ( nodes : N [ ] , edges : Edge [ ] ) {
420
+ const nodeIds = new Set ( )
421
+
422
+ nodes . forEach ( ( node ) => nodeIds . add ( isString ( node ) ? node : node . id ) )
375
423
376
424
const connectedNodeIds = edges . reduce ( ( acc , edge ) => {
377
- if ( nodeIds . includes ( edge . source ) ) {
425
+ if ( nodeIds . has ( edge . source ) ) {
378
426
acc . add ( edge . target )
379
427
}
380
428
381
- if ( nodeIds . includes ( edge . target ) ) {
429
+ if ( nodeIds . has ( edge . target ) ) {
382
430
acc . add ( edge . source )
383
431
}
384
432
385
433
return acc
386
- } , new Set < string > ( ) )
434
+ } , new Set ( ) )
387
435
388
436
return nodes . filter ( ( node ) => connectedNodeIds . has ( isString ( node ) ? node : node . id ) )
389
437
}
0 commit comments