Skip to content

Commit 19c6103

Browse files
Add scripts to convert ATIF and IST to crisp function trees for A/B
Signed-off-by: Anuradha Karuppiah <26330987+AnuradhaKaruppiah@users.noreply.github.com>
1 parent e8ed030 commit 19c6103

File tree

2 files changed

+339
-0
lines changed

2 files changed

+339
-0
lines changed
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
#!/usr/bin/env python3
2+
# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
"""Print a readable function ancestry tree from ATIF workflow output.
17+
18+
Example:
19+
python packages/nvidia_nat_eval/scripts/print_atif_function_tree.py \
20+
".tmp/nat/examples/evaluation_and_profiling/simple_web_query_eval/atif/workflow_output_atif.json"
21+
"""
22+
23+
from __future__ import annotations
24+
25+
import argparse
26+
import json
27+
from collections import defaultdict
28+
from dataclasses import dataclass
29+
from pathlib import Path
30+
from typing import Any
31+
32+
33+
@dataclass
class NodeStats:
    """Per-function record accumulated while scanning ATIF ancestry metadata.

    ``_add_ancestry`` keeps one instance per distinct ``function_id`` and bumps
    one of the two counters each time that function is reported again.
    """

    # Identity of the function this record describes.
    function_id: str
    function_name: str
    # Caller link as reported by the first ancestry entry seen for this
    # function; ``None`` when the entry carried no parent value.
    parent_id: str | None
    parent_name: str | None
    # Times the function appeared in a step's ``extra["ancestry"]`` entry.
    seen_in_step_ancestry: int = 0
    # Times the function appeared in a step's ``extra["tool_ancestry"]`` list.
    seen_in_tool_ancestry: int = 0
41+
42+
43+
def _load_json(path: Path) -> Any:
    """Read ``path`` and return the decoded JSON document.

    The file is decoded as ``utf-8-sig`` so that a leading UTF-8 byte-order
    mark is dropped before parsing; ``json.loads`` rejects text that still
    starts with a BOM. Files without a BOM decode exactly as with ``utf-8``.

    Args:
        path: Location of the JSON file to read.

    Returns:
        Whatever top-level JSON value the file contains.

    Raises:
        OSError: If the file cannot be read.
        UnicodeDecodeError: If the file is not valid UTF-8.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    return json.loads(path.read_text(encoding="utf-8-sig"))
45+
46+
47+
def _iter_trajectories(payload: Any) -> list[tuple[str, dict[str, Any]]]:
    """Turn a decoded ATIF document into ``(label, trajectory_dict)`` pairs.

    Two entry shapes are recognised, both for a single dict payload and for
    each element of a list payload:

    * an eval-sample wrapper whose ``"trajectory"`` value is a dict, labelled
      ``item=<item_id>`` (falling back to the entry's position);
    * a bare trajectory whose ``"steps"`` value is a list, labelled
      ``trajectory=<position>``.

    List elements matching neither shape are skipped. A single dict payload is
    treated as position 0.

    Raises:
        ValueError: If the payload is neither a list nor a recognised dict.
    """

    def normalize(entry: Any, position: int) -> tuple[str, dict[str, Any]] | None:
        # Returns the labelled trajectory for one entry, or None when the
        # entry has neither supported shape.
        if not isinstance(entry, dict):
            return None
        nested = entry.get("trajectory")
        if isinstance(nested, dict):
            return f"item={entry.get('item_id', position)}", nested
        if isinstance(entry.get("steps"), list):
            return f"trajectory={position}", entry
        return None

    if isinstance(payload, list):
        candidates = (normalize(entry, position) for position, entry in enumerate(payload))
        return [pair for pair in candidates if pair is not None]

    if isinstance(payload, dict):
        single = normalize(payload, 0)
        if single is not None:
            return [single]

    raise ValueError("Unsupported ATIF JSON shape. Expected trajectory or eval-sample payload.")
68+
69+
70+
def _add_ancestry(nodes: dict[str, NodeStats], fn: dict[str, Any], from_tool: bool) -> None:
    """Record one sighting of the function described by ``fn`` in ``nodes``.

    Entries lacking a non-empty ``function_id`` or ``function_name`` are
    ignored. The first sighting of an id creates its ``NodeStats`` (parent
    fields are stringified unless they are ``None``); every sighting then
    increments the tool counter when ``from_tool`` is true, otherwise the
    step counter.

    Args:
        nodes: Mapping of ``function_id`` to stats, updated in place.
        fn: One ``function_ancestry`` dict.
        from_tool: Whether the entry came from a tool-level ancestry list.
    """

    def text_or_none(value: Any) -> str | None:
        return None if value is None else str(value)

    key = str(fn.get("function_id") or "")
    name = str(fn.get("function_name") or "")
    if not (key and name):
        return

    if key not in nodes:
        nodes[key] = NodeStats(
            function_id=key,
            function_name=name,
            parent_id=text_or_none(fn.get("parent_id")),
            parent_name=text_or_none(fn.get("parent_name")),
        )

    record = nodes[key]
    if from_tool:
        record.seen_in_tool_ancestry += 1
    else:
        record.seen_in_step_ancestry += 1
90+
91+
92+
def _build_nodes(trajectory: dict[str, Any]) -> dict[str, NodeStats]:
    """Collect one ``NodeStats`` per function referenced by a trajectory's steps.

    For every step, ``extra["ancestry"]["function_ancestry"]`` is recorded as
    a step-level sighting, and each
    ``extra["tool_ancestry"][i]["function_ancestry"]`` as a tool-level
    sighting. Any value along those paths that does not have the expected
    dict/list shape is skipped instead of raising, so one malformed step does
    not abort the whole report.

    Args:
        trajectory: A single ATIF trajectory dict.

    Returns:
        Mapping of ``function_id`` to its stats, in first-seen order. Empty
        when no usable ancestry metadata is present.
    """
    nodes: dict[str, NodeStats] = {}

    # ``_iter_trajectories`` accepts a nested ``trajectory`` dict without
    # checking its ``steps`` value, so it may be missing, null or not a list.
    steps = trajectory.get("steps")
    if not isinstance(steps, list):
        return nodes

    for step in steps:
        if not isinstance(step, dict):
            continue
        extra = step.get("extra")
        if not isinstance(extra, dict):
            continue

        ancestry = extra.get("ancestry")
        if isinstance(ancestry, dict):
            step_fn = ancestry.get("function_ancestry")
            if isinstance(step_fn, dict):
                _add_ancestry(nodes, step_fn, from_tool=False)

        tool_entries = extra.get("tool_ancestry")
        if not isinstance(tool_entries, list):
            continue
        for tool_ancestry in tool_entries:
            if not isinstance(tool_ancestry, dict):
                continue
            tool_fn = tool_ancestry.get("function_ancestry")
            if isinstance(tool_fn, dict):
                _add_ancestry(nodes, tool_fn, from_tool=True)

    return nodes
103+
104+
105+
def _print_tree(nodes: dict[str, NodeStats]) -> None:
    """Print every collected function as an indented tree below a ``root`` header.

    A node is listed directly under the header when its ``parent_id`` is
    empty, is the literal ``"root"``, points at the node itself, or names a
    function that was never collected. The last case is tagged with a
    ``<parent not found: ...>`` note. Nodes reachable only through a parent
    cycle are also surfaced at the top level, where the loop shows up as a
    ``<cycle>`` line, so no collected node is silently omitted.

    A node whose own id is ``"root"`` is represented by the header line and is
    printed as an entry only when nothing else sits at the top level.

    Siblings are ordered by ``function_name``.

    Args:
        nodes: Mapping of ``function_id`` to its stats record.
    """
    root_label = "root"

    def by_name(function_id: str) -> str:
        return nodes[function_id].function_name

    children_of: dict[str, list[str]] = defaultdict(list)
    top_level: list[str] = []
    orphans: set[str] = set()
    for function_id, node in nodes.items():
        if function_id == root_label:
            # Stands for the header line; handled after the loop.
            continue
        parent = node.parent_id
        if not parent or parent == root_label or parent == function_id:
            # No parent, explicit root parent, or a malformed self-link.
            top_level.append(function_id)
        elif parent not in nodes:
            # Dangling link: keep the node visible instead of dropping it.
            orphans.add(function_id)
            top_level.append(function_id)
        else:
            children_of[parent].append(function_id)

    for siblings in children_of.values():
        siblings.sort(key=by_name)

    # Nodes that only reference each other have no path from the top level.
    # Promote one member per unreachable group so the group is still printed.
    reachable: set[str] = set()

    def mark_reachable(start: str) -> None:
        pending = [start]
        while pending:
            current = pending.pop()
            if current not in reachable:
                reachable.add(current)
                pending.extend(children_of.get(current, ()))

    for function_id in top_level:
        mark_reachable(function_id)
    for function_id in sorted((fid for fid in nodes if fid != root_label), key=by_name):
        if function_id not in reachable:
            top_level.append(function_id)
            mark_reachable(function_id)

    top_level.sort(key=by_name)
    if not top_level and root_label in nodes:
        top_level = [root_label]

    def render(function_id: str, prefix: str, is_last: bool, path: set[str]) -> None:
        branch = "└─ " if is_last else "├─ "
        if function_id in path:
            # Already on the current root-to-node path: stop to avoid looping.
            print(f"{prefix}{branch}<cycle> [{function_id}]")
            return

        node = nodes[function_id]
        counts = []
        if node.seen_in_step_ancestry:
            counts.append(f"steps={node.seen_in_step_ancestry}")
        if node.seen_in_tool_ancestry:
            counts.append(f"tools={node.seen_in_tool_ancestry}")
        counts_str = f" ({', '.join(counts)})" if counts else ""
        note = ""
        if function_id in orphans:
            note = f" <parent not found: {node.parent_name or '?'} [{node.parent_id}]>"
        print(f"{prefix}{branch}{node.function_name} [{node.function_id}]{counts_str}{note}")

        children = children_of.get(function_id, [])
        # Three columns wide so children line up under the branch marker above.
        child_prefix = prefix + ("   " if is_last else "│  ")
        path.add(function_id)
        for i, child_id in enumerate(children):
            render(child_id, child_prefix, i == len(children) - 1, path)
        path.discard(function_id)

    print("root")
    for i, root_id in enumerate(top_level):
        render(root_id, "", i == len(top_level) - 1, set())
151+
152+
153+
def main() -> None:
    """Command-line entry point: print one ancestry tree per trajectory.

    Parses the single positional ``input_json`` argument, loads the file,
    and for each trajectory found prints a header line followed by either the
    function tree or a notice that no ancestry metadata was present.
    """
    cli = argparse.ArgumentParser(description="Print ATIF function ancestry tree from workflow_output_atif.json")
    cli.add_argument("input_json", type=Path, help="Path to ATIF workflow output JSON")
    options = cli.parse_args()

    entries = _iter_trajectories(_load_json(options.input_json))
    for position, entry in enumerate(entries):
        label, trajectory = entry
        # A blank line separates consecutive reports.
        if position:
            print()

        session = trajectory.get("session_id", "unknown-session")
        print(f"=== {label} | mode=atif | session_id={session} ===")

        nodes = _build_nodes(trajectory)
        if nodes:
            _print_tree(nodes)
        else:
            print("No ancestry metadata found in step.extra.")
171+
172+
173+
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
#!/usr/bin/env python3
2+
# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
"""Print a readable function ancestry tree from legacy IST workflow output.
17+
18+
Example:
19+
python packages/nvidia_nat_eval/scripts/print_ist_function_tree.py \
20+
".tmp/nat/examples/evaluation_and_profiling/simple_web_query_eval/atif/workflow_output.json"
21+
"""
22+
23+
from __future__ import annotations
24+
25+
import argparse
26+
import json
27+
from collections import defaultdict
28+
from dataclasses import dataclass
29+
from pathlib import Path
30+
from typing import Any
31+
32+
33+
@dataclass
class NodeStats:
    """Per-function record accumulated while scanning legacy intermediate steps.

    ``_add_ancestry`` keeps one instance per distinct ``function_id`` and bumps
    one of the two counters each time that function is reported again.
    """

    # Identity of the function this record describes.
    function_id: str
    function_name: str
    # Caller link as reported by the first ancestry entry seen for this
    # function; ``None`` when the entry carried no parent value.
    parent_id: str | None
    parent_name: str | None
    # Sightings from steps whose payload event type does not contain "TOOL".
    seen_in_step_ancestry: int = 0
    # Sightings from steps whose payload event type contains "TOOL".
    seen_in_tool_ancestry: int = 0
41+
42+
43+
def _load_json(path: Path) -> Any:
    """Read ``path`` and return the decoded JSON document.

    The file is decoded as ``utf-8-sig`` so that a leading UTF-8 byte-order
    mark is dropped before parsing; ``json.loads`` rejects text that still
    starts with a BOM. Files without a BOM decode exactly as with ``utf-8``.

    Args:
        path: Location of the JSON file to read.

    Returns:
        Whatever top-level JSON value the file contains.

    Raises:
        OSError: If the file cannot be read.
        UnicodeDecodeError: If the file is not valid UTF-8.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    return json.loads(path.read_text(encoding="utf-8-sig"))
45+
46+
47+
def _iter_items(payload: Any) -> list[tuple[str, dict[str, Any]]]:
    """Turn a decoded legacy document into ``(label, item_dict)`` pairs.

    An item is a dict whose ``"intermediate_steps"`` value is a list. A list
    payload yields one pair per matching element, labelled ``item=<id>`` with
    the element's position as the fallback id; other elements are skipped. A
    single matching dict yields one pair with ``0`` as the fallback id.

    Raises:
        ValueError: If the payload is neither a list nor a matching dict.
    """

    def is_item(candidate: Any) -> bool:
        return isinstance(candidate, dict) and isinstance(candidate.get("intermediate_steps"), list)

    if isinstance(payload, list):
        return [
            (f"item={entry.get('id', position)}", entry)
            for position, entry in enumerate(payload)
            if is_item(entry)
        ]

    if is_item(payload):
        return [(f"item={payload.get('id', '0')}", payload)]

    raise ValueError("Unsupported legacy JSON shape. Expected item(s) with intermediate_steps.")
61+
62+
63+
def _add_ancestry(nodes: dict[str, NodeStats], fn: dict[str, Any], from_tool: bool) -> None:
    """Count one occurrence of the function described by ``fn``.

    Nothing happens unless both ``function_id`` and ``function_name`` are
    non-empty. A ``NodeStats`` is created the first time an id is seen, with
    the parent fields converted to strings unless they are ``None``. The tool
    counter is incremented when ``from_tool`` is true, else the step counter.

    Args:
        nodes: Mapping of ``function_id`` to stats, updated in place.
        fn: One ``function_ancestry`` dict.
        from_tool: Whether the occurrence came from a tool-related step.
    """
    function_id, function_name = (str(fn.get(field) or "") for field in ("function_id", "function_name"))
    if "" in (function_id, function_name):
        return

    if function_id not in nodes:
        raw_parent_id = fn.get("parent_id")
        raw_parent_name = fn.get("parent_name")
        nodes[function_id] = NodeStats(
            function_id=function_id,
            function_name=function_name,
            parent_id=None if raw_parent_id is None else str(raw_parent_id),
            parent_name=None if raw_parent_name is None else str(raw_parent_name),
        )

    counter = "seen_in_tool_ancestry" if from_tool else "seen_in_step_ancestry"
    stats = nodes[function_id]
    setattr(stats, counter, getattr(stats, counter) + 1)
83+
84+
85+
def _build_nodes(item: dict[str, Any]) -> dict[str, NodeStats]:
    """Collect one ``NodeStats`` per function referenced by an item's steps.

    Each intermediate step with a dict ``function_ancestry`` is recorded. It
    counts as a tool sighting when the step's ``payload["event_type"]``,
    rendered as a string, contains ``"TOOL"``; otherwise as a step sighting.
    Steps or payloads that do not have the expected dict shape are skipped or
    treated as having no event type instead of raising.

    Args:
        item: A single legacy workflow-output item.

    Returns:
        Mapping of ``function_id`` to its stats, in first-seen order. Empty
        when no usable ancestry metadata is present.
    """
    nodes: dict[str, NodeStats] = {}

    steps = item.get("intermediate_steps")
    if not isinstance(steps, list):
        return nodes

    for step in steps:
        if not isinstance(step, dict):
            continue
        fn = step.get("function_ancestry")
        if not isinstance(fn, dict):
            continue

        payload = step.get("payload")
        event_type = payload.get("event_type") if isinstance(payload, dict) else None
        _add_ancestry(nodes, fn, from_tool="TOOL" in str(event_type or ""))

    return nodes
94+
95+
96+
def _print_tree(nodes: dict[str, NodeStats]) -> None:
    """Print every collected function as an indented tree below a ``root`` header.

    A node is listed directly under the header when its ``parent_id`` is
    empty, is the literal ``"root"``, points at the node itself, or names a
    function that was never collected. The last case is tagged with a
    ``<parent not found: ...>`` note. Nodes reachable only through a parent
    cycle are also surfaced at the top level, where the loop shows up as a
    ``<cycle>`` line, so no collected node is silently omitted.

    A node whose own id is ``"root"`` is represented by the header line and is
    printed as an entry only when nothing else sits at the top level.

    Siblings are ordered by ``function_name``.

    Args:
        nodes: Mapping of ``function_id`` to its stats record.
    """
    root_label = "root"

    def by_name(function_id: str) -> str:
        return nodes[function_id].function_name

    children_of: dict[str, list[str]] = defaultdict(list)
    top_level: list[str] = []
    orphans: set[str] = set()
    for function_id, node in nodes.items():
        if function_id == root_label:
            # Stands for the header line; handled after the loop.
            continue
        parent = node.parent_id
        if not parent or parent == root_label or parent == function_id:
            # No parent, explicit root parent, or a malformed self-link.
            top_level.append(function_id)
        elif parent not in nodes:
            # Dangling link: keep the node visible instead of dropping it.
            orphans.add(function_id)
            top_level.append(function_id)
        else:
            children_of[parent].append(function_id)

    for siblings in children_of.values():
        siblings.sort(key=by_name)

    # Nodes that only reference each other have no path from the top level.
    # Promote one member per unreachable group so the group is still printed.
    reachable: set[str] = set()

    def mark_reachable(start: str) -> None:
        pending = [start]
        while pending:
            current = pending.pop()
            if current not in reachable:
                reachable.add(current)
                pending.extend(children_of.get(current, ()))

    for function_id in top_level:
        mark_reachable(function_id)
    for function_id in sorted((fid for fid in nodes if fid != root_label), key=by_name):
        if function_id not in reachable:
            top_level.append(function_id)
            mark_reachable(function_id)

    top_level.sort(key=by_name)
    if not top_level and root_label in nodes:
        top_level = [root_label]

    def render(function_id: str, prefix: str, is_last: bool, path: set[str]) -> None:
        branch = "└─ " if is_last else "├─ "
        if function_id in path:
            # Already on the current root-to-node path: stop to avoid looping.
            print(f"{prefix}{branch}<cycle> [{function_id}]")
            return

        node = nodes[function_id]
        counts = []
        if node.seen_in_step_ancestry:
            counts.append(f"steps={node.seen_in_step_ancestry}")
        if node.seen_in_tool_ancestry:
            counts.append(f"tools={node.seen_in_tool_ancestry}")
        counts_str = f" ({', '.join(counts)})" if counts else ""
        note = ""
        if function_id in orphans:
            note = f" <parent not found: {node.parent_name or '?'} [{node.parent_id}]>"
        print(f"{prefix}{branch}{node.function_name} [{node.function_id}]{counts_str}{note}")

        children = children_of.get(function_id, [])
        # Three columns wide so children line up under the branch marker above.
        child_prefix = prefix + ("   " if is_last else "│  ")
        path.add(function_id)
        for i, child_id in enumerate(children):
            render(child_id, child_prefix, i == len(children) - 1, path)
        path.discard(function_id)

    print("root")
    for i, root_id in enumerate(top_level):
        render(root_id, "", i == len(top_level) - 1, set())
142+
143+
144+
def main() -> None:
    """Command-line entry point: print one ancestry tree per legacy item.

    Parses the single positional ``input_json`` argument, loads the file,
    and for each item found prints a header line followed by either the
    function tree or a notice that no ancestry metadata was present.
    """
    cli = argparse.ArgumentParser(description="Print legacy IST function ancestry tree from workflow_output.json")
    cli.add_argument("input_json", type=Path, help="Path to legacy workflow output JSON")
    options = cli.parse_args()

    entries = _iter_items(_load_json(options.input_json))
    for position, entry in enumerate(entries):
        label, item = entry
        # A blank line separates consecutive reports.
        if position:
            print()

        session = item.get("session_id", "unknown-session")
        print(f"=== {label} | mode=legacy | session_id={session} ===")

        nodes = _build_nodes(item)
        if nodes:
            _print_tree(nodes)
        else:
            print("No function_ancestry metadata found in intermediate_steps.")
162+
163+
164+
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)