|
| 1 | +""" |
| 2 | +Base utilities for code generation |
| 3 | +Shared functions for both PyTorch and TensorFlow code generation |
| 4 | +""" |
| 5 | + |
| 6 | +from collections import deque |
| 7 | +from typing import List, Dict, Any |
| 8 | + |
| 9 | + |
| 10 | +def topological_sort(nodes: List[Dict], edges: List[Dict]) -> List[Dict]: |
| 11 | + """ |
| 12 | + Sort nodes in topological order based on edges using Kahn's algorithm. |
| 13 | +
|
| 14 | + Args: |
| 15 | + nodes: List of node definitions |
| 16 | + edges: List of edge definitions |
| 17 | +
|
| 18 | + Returns: |
| 19 | + List of nodes in topological order |
| 20 | + """ |
| 21 | + node_map = {node['id']: node for node in nodes} |
| 22 | + |
| 23 | + # Build adjacency list and in-degree count |
| 24 | + graph = {node['id']: [] for node in nodes} |
| 25 | + in_degree = {node['id']: 0 for node in nodes} |
| 26 | + |
| 27 | + for edge in edges: |
| 28 | + source = edge.get('source') |
| 29 | + target = edge.get('target') |
| 30 | + if source in graph and target in graph: |
| 31 | + graph[source].append(target) |
| 32 | + in_degree[target] += 1 |
| 33 | + |
| 34 | + # Kahn's algorithm |
| 35 | + queue = deque([node_id for node_id, degree in in_degree.items() if degree == 0]) |
| 36 | + sorted_ids = [] |
| 37 | + |
| 38 | + while queue: |
| 39 | + node_id = queue.popleft() |
| 40 | + sorted_ids.append(node_id) |
| 41 | + |
| 42 | + for neighbor in graph[node_id]: |
| 43 | + in_degree[neighbor] -= 1 |
| 44 | + if in_degree[neighbor] == 0: |
| 45 | + queue.append(neighbor) |
| 46 | + |
| 47 | + # Return nodes in sorted order |
| 48 | + return [node_map[node_id] for node_id in sorted_ids if node_id in node_map] |
| 49 | + |
| 50 | + |
| 51 | +def get_input_variable(incoming: List[str], var_map: Dict[str, str]) -> str: |
| 52 | + """ |
| 53 | + Determine input variable name based on incoming connections. |
| 54 | +
|
| 55 | + Args: |
| 56 | + incoming: List of incoming node IDs |
| 57 | + var_map: Map of node ID to variable name |
| 58 | +
|
| 59 | + Returns: |
| 60 | + Variable name or list of variable names for multiple inputs |
| 61 | + """ |
| 62 | + if not incoming: |
| 63 | + return 'x' |
| 64 | + elif len(incoming) == 1: |
| 65 | + return var_map.get(incoming[0], 'x') |
| 66 | + else: |
| 67 | + # Multiple inputs (for concat, add, etc.) |
| 68 | + input_vars = [var_map.get(src, 'x') for src in incoming] |
| 69 | + return f"[{', '.join(input_vars)}]" |
| 70 | + |
| 71 | + |
| 72 | +def get_node_type(node: Dict[str, Any]) -> str: |
| 73 | + """Extract node type from node definition""" |
| 74 | + return node.get('data', {}).get('blockType', 'unknown') |
| 75 | + |
| 76 | + |
| 77 | +def get_node_config(node: Dict[str, Any]) -> Dict[str, Any]: |
| 78 | + """Extract configuration from node definition""" |
| 79 | + return node.get('data', {}).get('config', {}) |
0 commit comments