|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | +# |
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +# you may not use this file except in compliance with the License. |
| 7 | +# You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | +"""Print a readable function ancestry tree from ATIF workflow output. |
| 17 | +
|
| 18 | +Example: |
| 19 | + python packages/nvidia_nat_eval/scripts/print_atif_function_tree.py \ |
| 20 | + ".tmp/nat/examples/evaluation_and_profiling/simple_web_query_eval/atif/workflow_output_atif.json" |
| 21 | +""" |
| 22 | + |
| 23 | +from __future__ import annotations |
| 24 | + |
| 25 | +import argparse |
| 26 | +import json |
| 27 | +from collections import defaultdict |
| 28 | +from dataclasses import dataclass |
| 29 | +from pathlib import Path |
| 30 | +from typing import Any |
| 31 | + |
| 32 | + |
@dataclass
class NodeStats:
    """Record for a single function observed in ATIF ancestry metadata.

    Holds the function's identity, the identity of its parent (when one was
    supplied), and two counters tracking how many step-level and tool-level
    ancestry entries referred to the function.
    """

    # Identity of the function itself.
    function_id: str
    function_name: str

    # Identity of the parent function; ``None`` when not provided.
    parent_id: str | None
    parent_name: str | None

    # Occurrence counters; both start at zero for a fresh record.
    seen_in_step_ancestry: int = 0
    seen_in_tool_ancestry: int = 0
| 41 | + |
| 42 | + |
def _load_json(path: Path) -> Any:
    """Read the file at *path* as UTF-8 text and return the decoded JSON value."""
    raw_text = path.read_text(encoding="utf-8")
    return json.loads(raw_text)
| 45 | + |
| 46 | + |
def _iter_trajectories(payload: Any) -> list[tuple[str, dict[str, Any]]]:
    """Turn a decoded ATIF payload into a list of ``(label, trajectory)`` pairs.

    Accepted shapes:

    * a dict with a dict-valued ``"trajectory"`` key (labelled ``item=<item_id>``),
    * a dict with a list-valued ``"steps"`` key (labelled ``trajectory=0``),
    * a list whose dict elements follow either shape above; elements that
      match neither shape, or that are not dicts, are skipped.

    Raises:
        ValueError: if *payload* is neither a list nor a dict of a supported shape.
    """
    if isinstance(payload, dict):
        wrapped = payload.get("trajectory")
        if isinstance(wrapped, dict):
            return [(f"item={payload.get('item_id', '0')}", wrapped)]
        if isinstance(payload.get("steps"), list):
            return [("trajectory=0", payload)]
    elif isinstance(payload, list):
        labelled: list[tuple[str, dict[str, Any]]] = []
        for index, entry in enumerate(payload):
            if not isinstance(entry, dict):
                continue
            wrapped = entry.get("trajectory")
            if isinstance(wrapped, dict):
                # Fall back to the list position when the entry has no item_id.
                labelled.append((f"item={entry.get('item_id', index)}", wrapped))
            elif isinstance(entry.get("steps"), list):
                labelled.append((f"trajectory={index}", entry))
        return labelled

    raise ValueError("Unsupported ATIF JSON shape. Expected trajectory or eval-sample payload.")
| 68 | + |
| 69 | + |
def _add_ancestry(nodes: dict[str, NodeStats], fn: dict[str, Any], from_tool: bool) -> None:
    """Record one ``function_ancestry`` entry in *nodes*.

    Args:
        nodes: Mapping of function id to its record; updated in place.
        fn: Ancestry entry; ``function_id``, ``function_name``, ``parent_id``
            and ``parent_name`` are read from it.
        from_tool: When true the tool counter is bumped, otherwise the step
            counter.

    Entries with a missing or empty id or name are ignored. A record is
    created the first time an id is seen; later entries for the same id only
    increment a counter and never overwrite the stored name or parent.
    """
    fid = str(fn.get("function_id") or "")
    fname = str(fn.get("function_name") or "")
    if not (fid and fname):
        return

    if fid not in nodes:
        raw_parent_id = fn.get("parent_id")
        raw_parent_name = fn.get("parent_name")
        nodes[fid] = NodeStats(
            function_id=fid,
            function_name=fname,
            parent_id=None if raw_parent_id is None else str(raw_parent_id),
            parent_name=None if raw_parent_name is None else str(raw_parent_name),
        )

    stats = nodes[fid]
    if not from_tool:
        stats.seen_in_step_ancestry += 1
        return
    stats.seen_in_tool_ancestry += 1
| 90 | + |
| 91 | + |
def _build_nodes(trajectory: dict[str, Any]) -> dict[str, NodeStats]:
    """Collect per-function records from every step of *trajectory*.

    For each step, ``extra["ancestry"]["function_ancestry"]`` is counted as a
    step-level occurrence and each ``extra["tool_ancestry"][i]["function_ancestry"]``
    as a tool-level occurrence.

    The input is decoded JSON of unknown quality, so anything that does not
    have the expected container type (a null or non-list ``steps``, a non-dict
    step, ``extra`` or ``function_ancestry``, a non-list ``tool_ancestry``) is
    skipped instead of raising.

    Args:
        trajectory: A trajectory dict, normally carrying a ``"steps"`` list.

    Returns:
        Mapping of function id to its record; empty when no usable ancestry
        metadata is present.
    """
    nodes: dict[str, NodeStats] = {}

    steps = trajectory.get("steps")
    if not isinstance(steps, (list, tuple)):
        # Covers both a missing key and an explicit ``"steps": null``.
        return nodes

    for step in steps:
        if not isinstance(step, dict):
            continue
        extra = step.get("extra")
        if not isinstance(extra, dict):
            continue

        ancestry = extra.get("ancestry")
        if isinstance(ancestry, dict):
            step_fn = ancestry.get("function_ancestry")
            if isinstance(step_fn, dict):
                _add_ancestry(nodes, step_fn, from_tool=False)

        tool_entries = extra.get("tool_ancestry")
        if not isinstance(tool_entries, (list, tuple)):
            continue
        for tool_entry in tool_entries:
            if not isinstance(tool_entry, dict):
                continue
            tool_fn = tool_entry.get("function_ancestry")
            if isinstance(tool_fn, dict):
                _add_ancestry(nodes, tool_fn, from_tool=True)

    return nodes
| 103 | + |
| 104 | + |
def _print_tree(nodes: dict[str, NodeStats]) -> None:
    """Print *nodes* to stdout as a box-drawing tree under a ``root`` header.

    Every node in *nodes* is printed exactly once as a real entry:

    * Nodes whose parent is ``None``, ``""``, ``"root"`` or themselves are
      listed directly under the header.
    * Nodes whose parent id is not a key of *nodes* are also listed under the
      header, annotated with ``missing parent=<id>`` so they are not lost.
    * Nodes that form a parent cycle cannot be reached from the header; each
      such group is entered at one node on the loop, and the point where the
      loop closes is shown as ``<cycle> [<id>]``.

    Siblings are ordered by function name. A node whose id is the literal
    ``"root"`` is represented by the header line; it is printed as an entry
    only when nothing else would appear under the header.

    Args:
        nodes: Mapping of function id to its record.
    """
    root_id = "root"

    def by_name(fid: str) -> str:
        return nodes[fid].function_name

    # Partition nodes into top-level entries and children of a recorded parent.
    children: dict[str, list[str]] = defaultdict(list)
    top_level: list[str] = []
    orphans: set[str] = set()
    for fid, node in nodes.items():
        if fid == root_id:
            continue
        parent = node.parent_id
        if parent in (None, "", root_id) or parent == fid:
            top_level.append(fid)
        elif parent not in nodes:
            top_level.append(fid)
            orphans.add(fid)
        else:
            children[parent].append(fid)

    for sibling_ids in children.values():
        sibling_ids.sort(key=by_name)
    top_level.sort(key=by_name)

    reachable: set[str] = set()

    def mark_subtree(start: str) -> None:
        # Iterative walk so very deep trees cannot exhaust the call stack here.
        pending = [start]
        while pending:
            current = pending.pop()
            if current in reachable:
                continue
            reachable.add(current)
            pending.extend(children.get(current, ()))

    for fid in top_level:
        mark_subtree(fid)

    # Anything still unreached has a recorded parent at every hop, so walking
    # parent links must eventually repeat a node; that node lies on the loop
    # and is used as the entry point, which keeps tails hanging off the loop
    # from being printed twice.
    cycle_entries: list[str] = []
    for fid in sorted(nodes, key=by_name):
        if fid == root_id or fid in reachable:
            continue
        walked: set[str] = set()
        current = fid
        while current not in walked:
            walked.add(current)
            current = nodes[current].parent_id
        cycle_entries.append(current)
        mark_subtree(current)

    entries = top_level + cycle_entries
    if not entries and root_id in nodes:
        entries = [root_id]

    def render(fid: str, prefix: str, is_last: bool, path: set[str]) -> None:
        branch = "└─ " if is_last else "├─ "
        if fid in path:
            print(f"{prefix}{branch}<cycle> [{fid}]")
            return

        node = nodes[fid]
        notes = []
        if node.seen_in_step_ancestry:
            notes.append(f"steps={node.seen_in_step_ancestry}")
        if node.seen_in_tool_ancestry:
            notes.append(f"tools={node.seen_in_tool_ancestry}")
        if fid in orphans:
            notes.append(f"missing parent={node.parent_id}")
        notes_str = f" ({', '.join(notes)})" if notes else ""
        print(f"{prefix}{branch}{node.function_name} [{node.function_id}]{notes_str}")

        kids = children.get(fid, [])
        # Same width as the branch marker so nested guides line up.
        child_prefix = prefix + ("   " if is_last else "│  ")
        path.add(fid)
        for i, kid in enumerate(kids):
            render(kid, child_prefix, i == len(kids) - 1, path)
        path.discard(fid)

    print("root")
    for i, entry_id in enumerate(entries):
        render(entry_id, "", i == len(entries) - 1, set())
| 151 | + |
| 152 | + |
def main() -> None:
    """Command-line entry point.

    Parses the single positional ``input_json`` argument, loads the file,
    and prints one header line plus one ancestry tree per trajectory found in
    it. Consecutive trajectories are separated by a blank line.
    """
    cli = argparse.ArgumentParser(description="Print ATIF function ancestry tree from workflow_output_atif.json")
    cli.add_argument("input_json", type=Path, help="Path to ATIF workflow output JSON")
    input_path = cli.parse_args().input_json

    document = _load_json(input_path)

    is_first = True
    for label, trajectory in _iter_trajectories(document):
        # Blank separator before every section except the first.
        if not is_first:
            print()
        is_first = False

        session_id = trajectory.get("session_id", "unknown-session")
        print(f"=== {label} | mode=atif | session_id={session_id} ===")

        nodes = _build_nodes(trajectory)
        if nodes:
            _print_tree(nodes)
        else:
            print("No ancestry metadata found in step.extra.")
| 171 | + |
| 172 | + |
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
0 commit comments