Skip to content

Commit 6ad8771

Browse files
Fix graph traversing for multi-branch DAGs (GH-8)
resolves #6
2 parents 8a1f678 + 4274635 commit 6ad8771

File tree

6 files changed

+150
-44
lines changed

6 files changed

+150
-44
lines changed

examples/fanout_agent.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import asyncio
2+
import operator
3+
import random
4+
from typing import Annotated, TypedDict
5+
6+
from langgraph.graph import END, START, StateGraph
7+
from langgraph.types import Send
8+
9+
from langgraphics import watch
10+
11+
RANDOM_FANOUT = False
12+
FANOUT_NODES = ["fanout_a", "fanout_b", "fanout_c"]
13+
FANOUT_CHANCE = 3 / 4
14+
15+
16+
class GraphState(TypedDict):
17+
initial_result: str
18+
fanout_results: Annotated[list[str], operator.add]
19+
final_message: str
20+
21+
22+
async def initial_node(state: GraphState) -> dict:
23+
await asyncio.sleep(2)
24+
return {"initial_result": "initial done"}
25+
26+
27+
async def fanout_a(state: GraphState) -> dict:
28+
await asyncio.sleep(2)
29+
return {"fanout_results": ["fanout_a done"]}
30+
31+
32+
async def fanout_b(state: GraphState) -> dict:
33+
await asyncio.sleep(2)
34+
return {"fanout_results": ["fanout_b done"]}
35+
36+
37+
async def fanout_c(state: GraphState) -> dict:
38+
await asyncio.sleep(2)
39+
return {"fanout_results": ["fanout_c done"]}
40+
41+
42+
async def final_node(state: GraphState) -> dict:
43+
await asyncio.sleep(2)
44+
return {"final_message": "\n".join(state["fanout_results"])}
45+
46+
47+
def route_fanout(state: GraphState) -> list[Send]:
48+
chosen = [n for n in FANOUT_NODES if random.random() < FANOUT_CHANCE] or [random.choice(FANOUT_NODES)]
49+
return [Send(name, state) for name in chosen]
50+
51+
52+
builder = StateGraph(GraphState)
53+
54+
builder.add_node("initial", initial_node)
55+
builder.add_node("fanout_a", fanout_a)
56+
builder.add_node("fanout_b", fanout_b)
57+
builder.add_node("fanout_c", fanout_c)
58+
builder.add_node("final", final_node)
59+
60+
builder.add_edge(START, "initial")
61+
62+
if RANDOM_FANOUT:
63+
builder.add_conditional_edges("initial", route_fanout, {n: n for n in FANOUT_NODES})
64+
else:
65+
builder.add_edge("initial", "fanout_a")
66+
builder.add_edge("initial", "fanout_b")
67+
builder.add_edge("initial", "fanout_c")
68+
69+
builder.add_edge("fanout_a", "final")
70+
builder.add_edge("fanout_b", "final")
71+
builder.add_edge("fanout_c", "final")
72+
builder.add_edge("final", END)
73+
74+
graph = builder.compile()
75+
graph = watch(graph)
76+
77+
78+
async def main() -> None:
79+
await graph.ainvoke({"initial_result": "", "fanout_results": [], "final_message": ""})
80+
81+
82+
if __name__ == "__main__":
83+
asyncio.run(main())

langgraphics-web/src/components/GraphCanvas.tsx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ interface GraphCanvasProps {
1212
nodes: Node<NodeData>[];
1313
edges: Edge<EdgeData>[];
1414
events: ExecutionEvent[];
15-
activeNodeId: string | null;
15+
activeNodeIds: string[];
1616
inspect: ReactNode;
1717
initialMode: ViewMode;
1818
initialRankDir?: RankDir;
@@ -21,11 +21,11 @@ interface GraphCanvasProps {
2121
onRankDirChange?: (v: RankDir) => void;
2222
}
2323

24-
export function GraphCanvas({nodes, edges, events, activeNodeId, inspect, initialMode = "auto", initialInspect = "off", initialColorMode = "system", initialRankDir = "TB", onRankDirChange}: GraphCanvasProps) {
24+
export function GraphCanvas({nodes, edges, events, activeNodeIds, inspect, initialMode = "auto", initialInspect = "off", initialColorMode = "system", initialRankDir = "TB", onRankDirChange}: GraphCanvasProps) {
2525
const [rankDir, setRankDir] = useState<RankDir>(initialRankDir);
2626
const [colorMode, setColorMode] = useState<ColorMode>(initialColorMode);
2727
const [inspectorMode, setInspectorMode] = useState<InspectorMode>(initialInspect);
28-
const {isManual, goAuto, goManual, fitContent} = useFocus({nodes, edges, activeNodeId, rankDir, initialMode});
28+
const {isManual, goAuto, goManual, fitContent} = useFocus({nodes, edges, activeNodeIds, rankDir, initialMode});
2929

3030
const handleRankDirChange = useCallback(async (v: RankDir) => {
3131
setRankDir(v);

langgraphics-web/src/hooks/useFocus.ts

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import type {RankDir} from "../layout";
77
interface UseFocusOptions {
88
nodes: Node<NodeData>[];
99
edges: Edge<EdgeData>[];
10-
activeNodeId: string | null;
10+
activeNodeIds: string[];
1111
rankDir?: RankDir;
1212
initialMode?: ViewMode;
1313
}
@@ -37,12 +37,12 @@ function getNeighbourIds(nodeId: string, nodes: Node<NodeData>[], edges: Edge<Ed
3737
return [before?.id, after?.id].filter((id): id is string => id !== undefined).map((id) => ({id}));
3838
}
3939

40-
export function useFocus({nodes, edges, activeNodeId, rankDir = "TB", initialMode = "auto"}: UseFocusOptions) {
40+
export function useFocus({nodes, edges, activeNodeIds, rankDir = "TB", initialMode = "auto"}: UseFocusOptions) {
4141
const {fitView} = useReactFlow();
4242
const [mode, setMode] = useState<"auto" | "manual">(initialMode);
4343
const prevMode = useRef<"auto" | "manual">(mode);
4444
const initialDone = useRef(false);
45-
const prevFocusId = useRef<string | null>(null);
45+
const prevFocusKey = useRef<string>("");
4646

4747
const isManual = useMemo(() => mode === "manual", [mode]);
4848
const isHorizontal = useMemo(() => ["LR", "RL"].includes(rankDir), [rankDir]);
@@ -72,28 +72,34 @@ export function useFocus({nodes, edges, activeNodeId, rankDir = "TB", initialMod
7272
duration: 0,
7373
}).then();
7474
}
75-
prevFocusId.current = null;
75+
prevFocusKey.current = "";
7676
prevMode.current = mode;
7777
return;
7878
}
7979

80-
if (mode === "auto" && prevMode.current !== "auto") prevFocusId.current = null;
80+
if (mode === "auto" && prevMode.current !== "auto") prevFocusKey.current = "";
8181
prevMode.current = mode;
8282

8383
if (mode !== "auto") return;
8484

85-
if (activeNodeId && activeNodeId !== prevFocusId.current) {
86-
prevFocusId.current = activeNodeId;
87-
88-
const activeNode = nodes.find((n) => n.id === activeNodeId);
89-
if (activeNode?.data.nodeType === "node") {
85+
const focusKey = [...activeNodeIds].sort().join(",");
86+
if (activeNodeIds.length > 0 && focusKey !== prevFocusKey.current) {
87+
prevFocusKey.current = focusKey;
88+
89+
const activeNodes = nodes.filter(
90+
(n) => activeNodeIds.includes(n.id) && n.data.nodeType === "node",
91+
);
92+
if (activeNodes.length > 0) {
93+
const neighbours = activeNodes.flatMap((n) =>
94+
getNeighbourIds(n.id, nodes, edges, isHorizontal),
95+
);
9096
fitView({
91-
nodes: [{id: activeNodeId}, ...getNeighbourIds(activeNodeId, nodes, edges, isHorizontal)],
97+
nodes: [...activeNodes.map((n) => ({id: n.id})), ...neighbours],
9298
duration: FIT_VIEW_DURATION,
9399
}).then();
94100
}
95101
}
96-
}, [nodes, edges, activeNodeId, fitView, mode, isHorizontal]);
102+
}, [nodes, edges, activeNodeIds, fitView, mode, isHorizontal]);
97103

98104
return {isManual, goAuto, goManual, fitContent};
99105
}

langgraphics-web/src/hooks/useGraphState.ts

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,27 @@ export function computeStatuses(events: ExecutionEvent[]): {
99
} {
1010
const nodeStatuses = new Map<string, NodeStatus>();
1111
const edgeStatuses = new Map<string, EdgeStatus>();
12+
const edgeInfo = new Map<string, {source: string; target: string}>();
1213

1314
for (const event of events) {
1415
if (event.type === "run_start") {
1516
nodeStatuses.clear();
1617
edgeStatuses.clear();
18+
edgeInfo.clear();
1719
} else if (event.type === "edge_active") {
18-
for (const [id, status] of edgeStatuses) {
19-
if (status === "active") edgeStatuses.set(id, "traversed");
20+
edgeInfo.set(event.edge_id, {source: event.source, target: event.target});
21+
if (nodeStatuses.get(event.source) === "active") {
22+
nodeStatuses.set(event.source, "completed");
2023
}
21-
for (const [id, status] of nodeStatuses) {
22-
if (status === "active") nodeStatuses.set(id, "completed");
24+
for (const [id, info] of edgeInfo) {
25+
if (info.target === event.source && edgeStatuses.get(id) === "active") {
26+
edgeStatuses.set(id, "traversed");
27+
}
2328
}
2429
edgeStatuses.set(event.edge_id, "active");
25-
nodeStatuses.set(event.target, "active");
30+
if (nodeStatuses.get(event.target) !== "error") {
31+
nodeStatuses.set(event.target, "active");
32+
}
2633
} else if (event.type === "error") {
2734
for (const [id, status] of edgeStatuses) {
2835
if (status === "active") edgeStatuses.set(id, "traversed");
@@ -52,14 +59,14 @@ export function useGraphState(topology: GraphMessage | null, events: ExecutionEv
5259
}, [topology, rankDir]);
5360

5461
return useMemo(() => {
55-
if (events.length === 0) return {nodes: base.nodes, edges: base.edges, activeNodeId: null as string | null};
62+
if (events.length === 0) return {nodes: base.nodes, edges: base.edges, activeNodeIds: [] as string[]};
5663

5764
const {nodeStatuses, edgeStatuses} = computeStatuses(events);
5865

59-
let activeNodeId: string | null = null;
66+
const activeNodeIds: string[] = [];
6067
const nodes = base.nodes.map((node) => {
6168
const status = nodeStatuses.get(node.id);
62-
if (status === "active") activeNodeId = node.id;
69+
if (status === "active") activeNodeIds.push(node.id);
6370
return {...node, className: status};
6471
});
6572

@@ -81,6 +88,6 @@ export function useGraphState(topology: GraphMessage | null, events: ExecutionEv
8188
return {...edge, className, animated: status === "active", markerEnd};
8289
});
8390

84-
return {nodes, edges, activeNodeId};
91+
return {nodes, edges, activeNodeIds};
8592
}, [base, events]);
8693
}

langgraphics-web/src/main.tsx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ const {theme, mode, inspect, direction} = parseParams();
3232
function Index() {
3333
const [rankDir, setRankDir] = useState<RankDir>(direction);
3434
const {topology, events, nodeEntries} = useWebSocket(WS_URL);
35-
const {nodes, edges, activeNodeId} = useGraphState(topology, events, rankDir);
35+
const {nodes, edges, activeNodeIds} = useGraphState(topology, events, rankDir);
3636

3737
return (
3838
<ReactFlowProvider>
@@ -44,7 +44,7 @@ function Index() {
4444
initialInspect={inspect}
4545
initialColorMode={theme}
4646
initialRankDir={direction}
47-
activeNodeId={activeNodeId}
47+
activeNodeIds={activeNodeIds}
4848
onRankDirChange={setRankDir}
4949
inspect={<InspectPanel nodeEntries={nodeEntries}/>}
5050
/>

langgraphics/streamer.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,14 @@ def __init__(
9696
self.edge_lookup = edge_lookup
9797
self.http_server = http_server
9898
self.node_names: set[str] = set()
99+
self.predecessors: dict[str, set[str]] = {}
99100
for src, tgt in edge_lookup:
100101
self.node_names.add(src)
101102
self.node_names.add(tgt)
103+
self.predecessors.setdefault(tgt, set()).add(src)
102104
self.node_names -= {"__start__", "__end__"}
105+
self.generation: dict[str, int] = {"__start__": 0}
106+
self.linked: set[tuple[str, int, str]] = set()
103107

104108
def __getattr__(self, name: str) -> Any:
105109
return getattr(self.graph, name)
@@ -118,17 +122,23 @@ async def broadcast(self, message: dict[str, Any]) -> None:
118122
except Exception:
119123
pass
120124

121-
async def _emit_edge(self, source: str, target: str) -> None:
122-
edge_id = self.edge_lookup.get((source, target))
123-
if edge_id:
124-
await self.broadcast(
125-
{
126-
"type": "edge_active",
127-
"source": source,
128-
"target": target,
129-
"edge_id": edge_id,
130-
}
131-
)
125+
async def _emit_edge(self, target: str) -> None:
126+
for source in self.predecessors.get(target, set()):
127+
if src_gen := self.generation.get(source) is None:
128+
continue
129+
if (key := (source, src_gen, target)) in self.linked:
130+
continue
131+
self.linked.add(key)
132+
if edge_id := self.edge_lookup.get((source, target)):
133+
await self.broadcast(
134+
{
135+
"type": "edge_active",
136+
"source": source,
137+
"target": target,
138+
"edge_id": edge_id,
139+
}
140+
)
141+
self.generation[target] = self.generation.get(target, -1) + 1
132142

133143
async def _emit_error(self, last_node: str) -> None:
134144
for (src, tgt), eid in self.edge_lookup.items():
@@ -166,14 +176,14 @@ async def ainvoke(self, input: Any, config: Any = None, **kwargs: Any) -> Any:
166176
input, config=merged_config, stream_mode="updates", **kwargs
167177
):
168178
if isinstance(chunk, dict):
169-
for node_name in chunk:
179+
for node_name, node_result in chunk.items():
170180
if node_name == "__metadata__":
171181
continue
172-
await self._emit_edge(last_node, node_name)
182+
await self._emit_edge(node_name)
173183
last_node = node_name
174-
result = chunk[node_name]
184+
result = node_result
175185

176-
await self._emit_edge(last_node, "__end__")
186+
await self._emit_edge("__end__")
177187
await self.broadcast({"type": "run_end", "run_id": run_id})
178188
except Exception:
179189
await self._emit_error(last_node)
@@ -204,12 +214,12 @@ async def astream(
204214
for node_name in chunk:
205215
if node_name == "__metadata__":
206216
continue
207-
await self._emit_edge(last_node, node_name)
217+
await self._emit_edge(node_name)
208218
last_node = node_name
209219
yield chunk
210220

211-
if last_node != "__start__":
212-
await self._emit_edge(last_node, "__end__")
221+
if len(self.generation) > 1:
222+
await self._emit_edge("__end__")
213223

214224
await self.broadcast({"type": "run_end", "run_id": run_id})
215225
except Exception:

0 commit comments

Comments
 (0)