@@ -7,7 +7,7 @@ import type {RankDir} from "../layout";
77interface UseFocusOptions {
88 nodes : Node < NodeData > [ ] ;
99 edges : Edge < EdgeData > [ ] ;
10- activeNodeId : string | null ;
10+ activeNodeIds : string [ ] ;
1111 rankDir ?: RankDir ;
1212 initialMode ?: ViewMode ;
1313}
@@ -37,12 +37,12 @@ function getNeighbourIds(nodeId: string, nodes: Node<NodeData>[], edges: Edge<Ed
3737 return [ before ?. id , after ?. id ] . filter ( ( id ) : id is string => id !== undefined ) . map ( ( id ) => ( { id} ) ) ;
3838}
3939
40- export function useFocus ( { nodes, edges, activeNodeId , rankDir = "TB" , initialMode = "auto" } : UseFocusOptions ) {
40+ export function useFocus ( { nodes, edges, activeNodeIds , rankDir = "TB" , initialMode = "auto" } : UseFocusOptions ) {
4141 const { fitView} = useReactFlow ( ) ;
4242 const [ mode , setMode ] = useState < "auto" | "manual" > ( initialMode ) ;
4343 const prevMode = useRef < "auto" | "manual" > ( mode ) ;
4444 const initialDone = useRef ( false ) ;
45- const prevFocusId = useRef < string | null > ( null ) ;
45+ const prevFocusKey = useRef < string > ( "" ) ;
4646
4747 const isManual = useMemo ( ( ) => mode === "manual" , [ mode ] ) ;
4848 const isHorizontal = useMemo ( ( ) => [ "LR" , "RL" ] . includes ( rankDir ) , [ rankDir ] ) ;
@@ -72,28 +72,34 @@ export function useFocus({nodes, edges, activeNodeId, rankDir = "TB", initialMod
7272 duration : 0 ,
7373 } ) . then ( ) ;
7474 }
75- prevFocusId . current = null ;
75+ prevFocusKey . current = "" ;
7676 prevMode . current = mode ;
7777 return ;
7878 }
7979
80- if ( mode === "auto" && prevMode . current !== "auto" ) prevFocusId . current = null ;
80+ if ( mode === "auto" && prevMode . current !== "auto" ) prevFocusKey . current = "" ;
8181 prevMode . current = mode ;
8282
8383 if ( mode !== "auto" ) return ;
8484
85- if ( activeNodeId && activeNodeId !== prevFocusId . current ) {
86- prevFocusId . current = activeNodeId ;
87-
88- const activeNode = nodes . find ( ( n ) => n . id === activeNodeId ) ;
89- if ( activeNode ?. data . nodeType === "node" ) {
85+ const focusKey = [ ...activeNodeIds ] . sort ( ) . join ( "," ) ;
86+ if ( activeNodeIds . length > 0 && focusKey !== prevFocusKey . current ) {
87+ prevFocusKey . current = focusKey ;
88+
89+ const activeNodes = nodes . filter (
90+ ( n ) => activeNodeIds . includes ( n . id ) && n . data . nodeType === "node" ,
91+ ) ;
92+ if ( activeNodes . length > 0 ) {
93+ const neighbours = activeNodes . flatMap ( ( n ) =>
94+ getNeighbourIds ( n . id , nodes , edges , isHorizontal ) ,
95+ ) ;
9096 fitView ( {
91- nodes : [ { id : activeNodeId } , ...getNeighbourIds ( activeNodeId , nodes , edges , isHorizontal ) ] ,
97+ nodes : [ ... activeNodes . map ( ( n ) => ( { id : n . id } ) ) , ...neighbours ] ,
9298 duration : FIT_VIEW_DURATION ,
9399 } ) . then ( ) ;
94100 }
95101 }
96- } , [ nodes , edges , activeNodeId , fitView , mode , isHorizontal ] ) ;
102+ } , [ nodes , edges , activeNodeIds , fitView , mode , isHorizontal ] ) ;
97103
98104 return { isManual, goAuto, goManual, fitContent} ;
99105}
0 commit comments