Skip to content

Commit 6c62f81

Browse files
authored
Merge pull request #46 from ForgeOpus/claude/refactor-pytorch-codegen-ZROiF
Claude/refactor pytorch codegen zr oi f
2 parents ecfe249 + df2f1b5 commit 6c62f81

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

74 files changed

+6487
-2592
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""Code generation orchestration package"""
2+
3+
from .pytorch_orchestrator import PyTorchCodeOrchestrator
4+
from .tensorflow_orchestrator import TensorFlowCodeOrchestrator
5+
6+
__all__ = ['PyTorchCodeOrchestrator', 'TensorFlowCodeOrchestrator']
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""
2+
Base utilities for code generation
3+
Shared functions for both PyTorch and TensorFlow code generation
4+
"""
5+
6+
from collections import deque
7+
from typing import List, Dict, Any
8+
9+
10+
def topological_sort(nodes: List[Dict], edges: List[Dict]) -> List[Dict]:
11+
"""
12+
Sort nodes in topological order based on edges using Kahn's algorithm.
13+
14+
Args:
15+
nodes: List of node definitions
16+
edges: List of edge definitions
17+
18+
Returns:
19+
List of nodes in topological order
20+
"""
21+
node_map = {node['id']: node for node in nodes}
22+
23+
# Build adjacency list and in-degree count
24+
graph = {node['id']: [] for node in nodes}
25+
in_degree = {node['id']: 0 for node in nodes}
26+
27+
for edge in edges:
28+
source = edge.get('source')
29+
target = edge.get('target')
30+
if source in graph and target in graph:
31+
graph[source].append(target)
32+
in_degree[target] += 1
33+
34+
# Kahn's algorithm
35+
queue = deque([node_id for node_id, degree in in_degree.items() if degree == 0])
36+
sorted_ids = []
37+
38+
while queue:
39+
node_id = queue.popleft()
40+
sorted_ids.append(node_id)
41+
42+
for neighbor in graph[node_id]:
43+
in_degree[neighbor] -= 1
44+
if in_degree[neighbor] == 0:
45+
queue.append(neighbor)
46+
47+
# Return nodes in sorted order
48+
return [node_map[node_id] for node_id in sorted_ids if node_id in node_map]
49+
50+
51+
def get_input_variable(incoming: List[str], var_map: Dict[str, str]) -> str:
52+
"""
53+
Determine input variable name based on incoming connections.
54+
55+
Args:
56+
incoming: List of incoming node IDs
57+
var_map: Map of node ID to variable name
58+
59+
Returns:
60+
Variable name or list of variable names for multiple inputs
61+
"""
62+
if not incoming:
63+
return 'x'
64+
elif len(incoming) == 1:
65+
return var_map.get(incoming[0], 'x')
66+
else:
67+
# Multiple inputs (for concat, add, etc.)
68+
input_vars = [var_map.get(src, 'x') for src in incoming]
69+
return f"[{', '.join(input_vars)}]"
70+
71+
72+
def get_node_type(node: Dict[str, Any]) -> str:
73+
"""Extract node type from node definition"""
74+
return node.get('data', {}).get('blockType', 'unknown')
75+
76+
77+
def get_node_config(node: Dict[str, Any]) -> Dict[str, Any]:
78+
"""Extract configuration from node definition"""
79+
return node.get('data', {}).get('config', {})

0 commit comments

Comments
 (0)