Skip to content

Commit 57478b8

Browse files
committed
refactor(tree-view): improve virtualization dx
1 parent a5aa6f4 commit 57478b8

File tree

8 files changed

+321
-15
lines changed

8 files changed

+321
-15
lines changed

.changeset/cyan-wolves-travel.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
"@zag-js/tabs": patch
44
---
55

6-
bugfix: `tabs` and `radio-group` machine no longer show the `indicator` prematurely
6+
Fix issue where `tabs` and `radio-group` machine no longer show the `indicator` prematurely
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
---
2+
"@zag-js/tree-view": minor
3+
---
4+
5+
- Added `scrollToIndexFn` prop to enable keyboard navigation in virtualized trees
6+
- **[Breaking]:** `getVisibleNodes()` now returns `{ node, indexPath }[]` instead of `node[]`. Returning the index path
7+
perhaps the most useful use of this function, hence the change.
8+
9+
**Migration:**
10+
11+
```tsx
12+
// Before
13+
const nodes = api.getVisibleNodes()
14+
nodes.forEach((node) => console.log(node.id))
15+
16+
// After
17+
const visibleNodes = api.getVisibleNodes()
18+
visibleNodes.forEach(({ node }) => console.log(node.id))
19+
```
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
import { useVirtualizer, type Virtualizer } from "@tanstack/react-virtual"
2+
import { normalizeProps, useMachine } from "@zag-js/react"
3+
import { treeviewControls } from "@zag-js/shared"
4+
import * as tree from "@zag-js/tree-view"
5+
import { ChevronRightIcon, FileIcon, FolderIcon } from "lucide-react"
6+
import { useId, useRef } from "react"
7+
import { StateVisualizer } from "../components/state-visualizer"
8+
import { Toolbar } from "../components/toolbar"
9+
import { useControls } from "../hooks/use-controls"
10+
11+
interface Node {
12+
id: string
13+
name: string
14+
children?: Node[]
15+
}
16+
17+
// Generate a large tree for virtualization demo
18+
function generateLargeTree(): Node {
19+
const folders: Node[] = []
20+
21+
for (let i = 0; i < 50; i++) {
22+
const children: Node[] = []
23+
for (let j = 0; j < 20; j++) {
24+
children.push({
25+
id: `folder-${i}/file-${j}.ts`,
26+
name: `file-${j}.ts`,
27+
})
28+
}
29+
folders.push({
30+
id: `folder-${i}`,
31+
name: `folder-${i}`,
32+
children,
33+
})
34+
}
35+
36+
return {
37+
id: "ROOT",
38+
name: "",
39+
children: folders,
40+
}
41+
}
42+
43+
const collection = tree.collection<Node>({
44+
nodeToValue: (node) => node.id,
45+
nodeToString: (node) => node.name,
46+
rootNode: generateLargeTree(),
47+
})
48+
49+
const ROW_HEIGHT = 32
50+
51+
interface TreeNodeProps {
52+
node: Node
53+
indexPath: number[]
54+
api: tree.Api
55+
}
56+
57+
const TreeNode = (props: TreeNodeProps) => {
58+
const { node, indexPath, api } = props
59+
60+
const nodeProps = { indexPath, node }
61+
const nodeState = api.getNodeState(nodeProps)
62+
63+
// Calculate indentation based on depth
64+
const indent = nodeState.depth * 20
65+
66+
if (nodeState.isBranch) {
67+
return (
68+
<div
69+
{...api.getBranchControlProps(nodeProps)}
70+
style={{
71+
paddingLeft: indent,
72+
height: ROW_HEIGHT,
73+
display: "flex",
74+
alignItems: "center",
75+
gap: "4px",
76+
}}
77+
>
78+
<span {...api.getBranchIndicatorProps(nodeProps)}>
79+
<ChevronRightIcon
80+
size={16}
81+
style={{
82+
transform: nodeState.expanded ? "rotate(90deg)" : "rotate(0deg)",
83+
transition: "transform 0.15s",
84+
}}
85+
/>
86+
</span>
87+
<FolderIcon size={16} />
88+
<span {...api.getBranchTextProps(nodeProps)}>{node.name}</span>
89+
</div>
90+
)
91+
}
92+
93+
return (
94+
<div
95+
{...api.getItemProps(nodeProps)}
96+
style={{
97+
paddingLeft: indent,
98+
height: ROW_HEIGHT,
99+
display: "flex",
100+
alignItems: "center",
101+
gap: "4px",
102+
}}
103+
>
104+
<FileIcon size={16} />
105+
<span {...api.getItemTextProps(nodeProps)}>{node.name}</span>
106+
</div>
107+
)
108+
}
109+
110+
export default function Page() {
111+
const controls = useControls(treeviewControls)
112+
const parentRef = useRef<HTMLDivElement>(null)
113+
const virtualizerRef = useRef<Virtualizer<HTMLDivElement, Element>>(null)
114+
115+
const service = useMachine(tree.machine, {
116+
id: useId(),
117+
collection,
118+
...controls.context,
119+
scrollToIndexFn(details) {
120+
virtualizerRef.current?.scrollToIndex(details.index, { align: "auto" })
121+
},
122+
})
123+
124+
const api = tree.connect(service, normalizeProps)
125+
126+
// Get visible nodes (now returns { node, indexPath }[])
127+
const visibleNodes = api.getVisibleNodes()
128+
129+
const virtualizer = useVirtualizer({
130+
count: visibleNodes.length,
131+
getScrollElement: () => parentRef.current,
132+
estimateSize: () => ROW_HEIGHT,
133+
overscan: 10,
134+
})
135+
136+
// Keep ref updated for scrollToIndexFn
137+
virtualizerRef.current = virtualizer
138+
139+
return (
140+
<>
141+
<main className="tree-view">
142+
<div {...api.getRootProps()}>
143+
<h3 {...api.getLabelProps()}>Virtualized Tree ({visibleNodes.length} visible nodes)</h3>
144+
<div style={{ display: "flex", gap: "10px", marginBottom: "10px" }}>
145+
<button onClick={() => api.collapse()}>Collapse All</button>
146+
<button onClick={() => api.expand()}>Expand All</button>
147+
{controls.context.selectionMode === "multiple" && (
148+
<>
149+
<button onClick={() => api.select()}>Select All</button>
150+
<button onClick={() => api.deselect()}>Deselect All</button>
151+
</>
152+
)}
153+
</div>
154+
155+
{/* Scrollable container */}
156+
<div
157+
ref={parentRef}
158+
{...api.getTreeProps()}
159+
style={{
160+
height: 400,
161+
overflow: "auto",
162+
border: "1px solid #ccc",
163+
borderRadius: "4px",
164+
}}
165+
>
166+
{/* Total size container */}
167+
<div
168+
style={{
169+
height: `${virtualizer.getTotalSize()}px`,
170+
width: "100%",
171+
position: "relative",
172+
}}
173+
>
174+
{/* Only render visible items */}
175+
{virtualizer.getVirtualItems().map((virtualItem) => {
176+
const { node, indexPath } = visibleNodes[virtualItem.index]
177+
178+
return (
179+
<div
180+
key={node.id}
181+
data-index={virtualItem.index}
182+
style={{
183+
position: "absolute",
184+
top: 0,
185+
left: 0,
186+
width: "100%",
187+
height: `${virtualItem.size}px`,
188+
transform: `translateY(${virtualItem.start}px)`,
189+
}}
190+
>
191+
<TreeNode node={node} indexPath={indexPath} api={api} />
192+
</div>
193+
)
194+
})}
195+
</div>
196+
</div>
197+
</div>
198+
</main>
199+
200+
<Toolbar controls={controls.ui}>
201+
<StateVisualizer state={service} omit={["collection"]} />
202+
</Toolbar>
203+
</>
204+
)
205+
}

packages/machines/tree-view/src/index.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ export type {
2222
NodeProps,
2323
NodeState,
2424
NodeWithError,
25+
ScrollToIndexDetails,
2526
SelectionChangeDetails,
2627
TreeLoadingStatus,
2728
TreeLoadingStatusMap,
2829
TreeNode,
30+
VisibleNode,
2931
} from "./tree-view.types"

packages/machines/tree-view/src/tree-view.connect.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ export function connect<T extends PropTypes, V extends TreeNode = TreeNode>(
8787
send({ type: value ? "NODE.SELECT" : "SELECTED.ALL", value, isTrusted: false })
8888
},
8989
getVisibleNodes() {
90-
return computed("visibleNodes").map(({ node }) => node)
90+
return computed("visibleNodes")
9191
},
9292
focus(value) {
9393
dom.focusNode(scope, value)

packages/machines/tree-view/src/tree-view.machine.ts

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import type { TreeNode } from "@zag-js/collection"
2+
import type { Params } from "@zag-js/core"
23
import { createGuards, createMachine } from "@zag-js/core"
3-
import { getByTypeahead, setElementValue } from "@zag-js/dom-query"
4+
import { getByTypeahead, raf, setElementValue } from "@zag-js/dom-query"
45
import { addOrRemove, diff, first, isArray, isEqual, last, remove, toArray, uniq } from "@zag-js/utils"
56
import { collection } from "./tree-view.collection"
67
import * as dom from "./tree-view.dom"
@@ -358,47 +359,63 @@ export const machine = createMachine<TreeViewSchema>({
358359
clearSelected({ context }) {
359360
context.set("selectedValue", [])
360361
},
361-
focusTreeFirstNode({ prop, scope }) {
362+
focusTreeFirstNode(params) {
363+
const { prop, scope } = params
362364
const collection = prop("collection")
363365
const firstNode = collection.getFirstNode()
364366
const firstValue = collection.getNodeValue(firstNode)
365-
dom.focusNode(scope, firstValue)
367+
const scrolled = scrollToNode(params, firstValue)
368+
if (scrolled) raf(() => dom.focusNode(scope, firstValue))
369+
else dom.focusNode(scope, firstValue)
366370
},
367371
focusTreeLastNode(params) {
368372
const { prop, scope } = params
369373
const collection = prop("collection")
370374
const lastNode = collection.getLastNode(undefined, { skip: skipFn(params) })
371375
const lastValue = collection.getNodeValue(lastNode)
372-
dom.focusNode(scope, lastValue)
376+
const scrolled = scrollToNode(params, lastValue)
377+
if (scrolled) raf(() => dom.focusNode(scope, lastValue))
378+
else dom.focusNode(scope, lastValue)
373379
},
374-
focusBranchFirstNode({ event, prop, scope }) {
380+
focusBranchFirstNode(params) {
381+
const { event, prop, scope } = params
375382
const collection = prop("collection")
376383
const branchNode = collection.findNode(event.id)
377384
const firstNode = collection.getFirstNode(branchNode)
378385
const firstValue = collection.getNodeValue(firstNode)
379-
dom.focusNode(scope, firstValue)
386+
const scrolled = scrollToNode(params, firstValue)
387+
if (scrolled) raf(() => dom.focusNode(scope, firstValue))
388+
else dom.focusNode(scope, firstValue)
380389
},
381390
focusTreeNextNode(params) {
382391
const { event, prop, scope } = params
383392
const collection = prop("collection")
384393
const nextNode = collection.getNextNode(event.id, { skip: skipFn(params) })
385394
if (!nextNode) return
386395
const nextValue = collection.getNodeValue(nextNode)
387-
dom.focusNode(scope, nextValue)
396+
const scrolled = scrollToNode(params, nextValue)
397+
if (scrolled) raf(() => dom.focusNode(scope, nextValue))
398+
else dom.focusNode(scope, nextValue)
388399
},
389400
focusTreePrevNode(params) {
390401
const { event, prop, scope } = params
391402
const collection = prop("collection")
392403
const prevNode = collection.getPreviousNode(event.id, { skip: skipFn(params) })
393404
if (!prevNode) return
394405
const prevValue = collection.getNodeValue(prevNode)
395-
dom.focusNode(scope, prevValue)
406+
const scrolled = scrollToNode(params, prevValue)
407+
if (scrolled) raf(() => dom.focusNode(scope, prevValue))
408+
else dom.focusNode(scope, prevValue)
396409
},
397-
focusBranchNode({ event, prop, scope }) {
410+
focusBranchNode(params) {
411+
const { event, prop, scope } = params
398412
const collection = prop("collection")
399413
const parentNode = collection.getParentNode(event.id)
400414
const parentValue = parentNode ? collection.getNodeValue(parentNode) : undefined
401-
dom.focusNode(scope, parentValue)
415+
if (!parentValue) return
416+
const scrolled = scrollToNode(params, parentValue)
417+
if (scrolled) raf(() => dom.focusNode(scope, parentValue))
418+
else dom.focusNode(scope, parentValue)
402419
},
403420
selectAllNodes({ context, prop }) {
404421
context.set("selectedValue", prop("collection").getValues())
@@ -416,7 +433,10 @@ export const machine = createMachine<TreeViewSchema>({
416433
key: event.key,
417434
})
418435

419-
dom.focusNode(scope, node?.id)
436+
if (!node?.id) return
437+
const scrolled = scrollToNode(params, node.id)
438+
if (scrolled) raf(() => dom.focusNode(scope, node.id))
439+
else dom.focusNode(scope, node.id)
420440
},
421441
toggleNodeSelection({ context, event }) {
422442
const selectedValue = addOrRemove(context.get("selectedValue"), event.id)
@@ -641,3 +661,26 @@ export const machine = createMachine<TreeViewSchema>({
641661
},
642662
},
643663
})
664+
665+
function scrollToNode(params: Params<TreeViewSchema>, value: string): boolean {
666+
const { prop, scope, computed } = params
667+
const scrollToIndexFn = prop("scrollToIndexFn")
668+
if (!scrollToIndexFn) return false
669+
670+
const collection = prop("collection")
671+
const visibleNodes = computed("visibleNodes")
672+
673+
for (let i = 0; i < visibleNodes.length; i++) {
674+
const { node, indexPath } = visibleNodes[i]
675+
if (collection.getNodeValue(node) !== value) continue
676+
scrollToIndexFn({
677+
index: i,
678+
node,
679+
indexPath,
680+
getElement: () => scope.getById(dom.getNodeId(scope, value)),
681+
})
682+
return true
683+
}
684+
685+
return false
686+
}

packages/machines/tree-view/src/tree-view.props.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ export const props = createProps<TreeViewProps>()([
3030
"onRenameStart",
3131
"onBeforeRename",
3232
"onRenameComplete",
33+
"scrollToIndexFn",
3334
])
3435

3536
export const splitProps = createSplitProps<Partial<TreeViewProps>>(props)

0 commit comments

Comments
 (0)