@@ -2,19 +2,20 @@ import {useCallback, useEffect, useMemo, useRef, useState} from "react";
22import type { Edge , Node } from "@xyflow/react" ;
33import { useReactFlow } from "@xyflow/react" ;
44import type { EdgeData , NodeData } from "../types" ;
5- import { IS_HORIZONTAL } from "../layout" ;
5+ import type { RankDir } from "../layout" ;
66
77interface UseFocusOptions {
88 nodes : Node < NodeData > [ ] ;
99 edges : Edge < EdgeData > [ ] ;
1010 activeNodeId : string | null ;
11+ rankDir ?: RankDir ;
1112}
1213
1314const FIT_VIEW_DURATION = 1500 ;
1415
15- function getNeighbourIds ( nodeId : string , nodes : Node < NodeData > [ ] , edges : Edge < EdgeData > [ ] ) : any [ ] {
16+ function getNeighbourIds ( nodeId : string , nodes : Node < NodeData > [ ] , edges : Edge < EdgeData > [ ] , isHorizontal : boolean ) : any [ ] {
1617 const nodeRank = new Map < string , number > (
17- nodes . map ( ( n ) => [ n . id , IS_HORIZONTAL ? n . position . x : n . position . y ] ) ,
18+ nodes . map ( ( n ) => [ n . id , isHorizontal ? n . position . x : n . position . y ] ) ,
1819 ) ;
1920 const selfRank = nodeRank . get ( nodeId ) ?? 0 ;
2021
@@ -35,20 +36,25 @@ function getNeighbourIds(nodeId: string, nodes: Node<NodeData>[], edges: Edge<Ed
3536 return [ before ?. id , after ?. id ] . filter ( ( id ) : id is string => id !== undefined ) . map ( ( id ) => ( { id} ) ) ;
3637}
3738
38- export function useFocus ( { nodes, edges, activeNodeId} : UseFocusOptions ) {
39+ export function useFocus ( { nodes, edges, activeNodeId, rankDir = "TB" } : UseFocusOptions ) {
3940 const { fitView} = useReactFlow ( ) ;
4041 const [ mode , setMode ] = useState < "auto" | "manual" > ( "auto" ) ;
4142 const prevMode = useRef < "auto" | "manual" > ( mode ) ;
4243 const initialDone = useRef ( false ) ;
4344 const prevFocusId = useRef < string | null > ( null ) ;
4445
4546 const isManual = useMemo ( ( ) => mode === "manual" , [ mode ] ) ;
47+ const isHorizontal = useMemo ( ( ) => [ "LR" , "RL" ] . includes ( rankDir ) , [ rankDir ] ) ;
4648
47- const goAuto = useCallback ( async ( ) => {
48- setMode ( "auto" ) ;
49+ const fitContent = useCallback ( async ( ) => {
4950 await fitView ( { duration : FIT_VIEW_DURATION } ) ;
5051 } , [ fitView ] )
5152
53+ const goAuto = useCallback ( async ( ) => {
54+ setMode ( "auto" ) ;
55+ await fitContent ( ) ;
56+ } , [ fitContent ] )
57+
5258 const goManual = useCallback ( ( ) => {
5359 setMode ( "manual" ) ;
5460 } , [ ] )
@@ -61,7 +67,7 @@ export function useFocus({nodes, edges, activeNodeId}: UseFocusOptions) {
6167 const startNode = nodes . find ( ( n ) => n . data . nodeType === "start" ) ;
6268 if ( startNode ) {
6369 fitView ( {
64- nodes : [ startNode , ...getNeighbourIds ( startNode . id , nodes , edges ) ] ,
70+ nodes : [ startNode , ...getNeighbourIds ( startNode . id , nodes , edges , isHorizontal ) ] ,
6571 duration : 0 ,
6672 } ) . then ( ) ;
6773 }
@@ -83,12 +89,12 @@ export function useFocus({nodes, edges, activeNodeId}: UseFocusOptions) {
8389 fitView ( { duration : FIT_VIEW_DURATION } ) . then ( ) ;
8490 } else {
8591 fitView ( {
86- nodes : [ { id : activeNodeId } , ...getNeighbourIds ( activeNodeId , nodes , edges ) ] ,
92+ nodes : [ { id : activeNodeId } , ...getNeighbourIds ( activeNodeId , nodes , edges , isHorizontal ) ] ,
8793 duration : FIT_VIEW_DURATION ,
8894 } ) . then ( ) ;
8995 }
9096 }
91- } , [ nodes , edges , activeNodeId , fitView , mode ] ) ;
97+ } , [ nodes , edges , activeNodeId , fitView , mode , isHorizontal ] ) ;
9298
93- return { isManual, goAuto, goManual} ;
99+ return { isManual, goAuto, goManual, fitContent } ;
94100}
0 commit comments