Skip to content

Commit a4b6840

Browse files
tawsifkamalgithub-actions[bot]
authored andcommitted
Automated pre-commit update
1 parent 184ca87 commit a4b6840

File tree

2 files changed

+72
-92
lines changed

2 files changed

+72
-92
lines changed

examples/removing_import_loops_in_pytorch/import_loops.ipynb

Lines changed: 31 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
"source": [
3131
"from codegen import Codebase\n",
3232
"import networkx as nx\n",
33-
"from utils import visualize_graph # utility function to visualize a networkx graph"
33+
"from utils import visualize_graph # utility function to visualize a networkx graph"
3434
]
3535
},
3636
{
@@ -74,9 +74,9 @@
7474
"# Add all edges to the graph\n",
7575
"for imp in codebase.imports:\n",
7676
" if imp.from_file and imp.to_file:\n",
77-
" edge_color = 'red' if imp.is_dynamic else 'black'\n",
78-
" edge_label = 'dynamic' if imp.is_dynamic else 'static'\n",
79-
" \n",
77+
" edge_color = \"red\" if imp.is_dynamic else \"black\"\n",
78+
" edge_label = \"dynamic\" if imp.is_dynamic else \"static\"\n",
79+
"\n",
8080
" # Store the import statement and its metadata\n",
8181
" G.add_edge(\n",
8282
" imp.to_file.filepath,\n",
@@ -85,7 +85,7 @@
8585
" label=edge_label,\n",
8686
" is_dynamic=imp.is_dynamic,\n",
8787
" import_statement=imp, # Store the whole import object\n",
88-
" key=id(imp.import_statement)\n",
88+
" key=id(imp.import_statement),\n",
8989
" )\n",
9090
"# Find strongly connected components\n",
9191
"cycles = [scc for scc in nx.strongly_connected_components(G) if len(scc) > 1]\n",
@@ -97,17 +97,15 @@
9797
"\n",
9898
" # Create subgraph for this cycle to count edges\n",
9999
" cycle_subgraph = G.subgraph(cycle)\n",
100-
" \n",
100+
"\n",
101101
" # Count total edges\n",
102102
" total_edges = cycle_subgraph.number_of_edges()\n",
103103
" print(f\"Total number of imports in cycle: {total_edges}\")\n",
104-
" \n",
104+
"\n",
105105
" # Count dynamic and static imports separately\n",
106-
" dynamic_imports = sum(1 for u, v, data in cycle_subgraph.edges(data=True) \n",
107-
" if data.get('color') == 'red')\n",
108-
" static_imports = sum(1 for u, v, data in cycle_subgraph.edges(data=True) \n",
109-
" if data.get('color') == 'black')\n",
110-
" \n",
106+
" dynamic_imports = sum(1 for u, v, data in cycle_subgraph.edges(data=True) if data.get(\"color\") == \"red\")\n",
107+
" static_imports = sum(1 for u, v, data in cycle_subgraph.edges(data=True) if data.get(\"color\") == \"black\")\n",
108+
"\n",
111109
" print(f\"Number of dynamic imports: {dynamic_imports}\")\n",
112110
" print(f\"Number of static imports: {static_imports}\")"
113111
]
@@ -133,6 +131,7 @@
133131
"import_loop = cycles[0]\n",
134132
"cycle_list = list(import_loop)\n",
135133
"\n",
134+
"\n",
136135
"def create_single_loop_graph(cycle):\n",
137136
" cycle_graph = nx.MultiDiGraph() # Changed to MultiDiGraph to support multiple edges\n",
138137
" cycle = list(cycle)\n",
@@ -144,10 +143,10 @@
144143
" # For each edge between these nodes\n",
145144
" for edge_key, edge_data in edge_data_dict.items():\n",
146145
" # Add edge with all its attributes to cycle graph\n",
147-
" cycle_graph.add_edge(cycle[i], \n",
148-
" cycle[j], \n",
149-
" **edge_data)\n",
146+
" cycle_graph.add_edge(cycle[i], cycle[j], **edge_data)\n",
150147
" return cycle_graph\n",
148+
"\n",
149+
"\n",
151150
"cycle_graph = create_single_loop_graph(cycle_list)\n",
152151
"visualize_graph(cycle_graph)"
153152
]
@@ -239,54 +238,47 @@
239238
"def find_problematic_import_loops(G, sccs):\n",
240239
" \"\"\"Find cycles where files have both static and dynamic imports between them.\"\"\"\n",
241240
" problematic_cycles = []\n",
242-
" \n",
241+
"\n",
243242
" for i, scc in enumerate(sccs):\n",
244-
" if i == 2: # skipping the second import loop as it's incredibly long (it's also invalid)\n",
243+
" if i == 2: # skipping the second import loop as it's incredibly long (it's also invalid)\n",
245244
" continue\n",
246245
" mixed_import_files = {} # (from_file, to_file) -> {dynamic: count, static: count}\n",
247-
" \n",
246+
"\n",
248247
" # Check all file pairs in the cycle\n",
249248
" for from_file in scc:\n",
250249
" for to_file in scc:\n",
251250
" if G.has_edge(from_file, to_file):\n",
252251
" # Get all edges between these files\n",
253252
" edges = G.get_edge_data(from_file, to_file)\n",
254-
" \n",
253+
"\n",
255254
" # Count imports by type\n",
256-
" dynamic_count = sum(1 for e in edges.values() if e['color'] == 'red')\n",
257-
" static_count = sum(1 for e in edges.values() if e['color'] == 'black')\n",
258-
" \n",
255+
" dynamic_count = sum(1 for e in edges.values() if e[\"color\"] == \"red\")\n",
256+
" static_count = sum(1 for e in edges.values() if e[\"color\"] == \"black\")\n",
257+
"\n",
259258
" # If we have both types between same files, this is problematic\n",
260259
" if dynamic_count > 0 and static_count > 0:\n",
261-
" mixed_import_files[(from_file, to_file)] = {\n",
262-
" 'dynamic': dynamic_count,\n",
263-
" 'static': static_count,\n",
264-
" 'edges': edges\n",
265-
" }\n",
266-
" \n",
260+
" mixed_import_files[(from_file, to_file)] = {\"dynamic\": dynamic_count, \"static\": static_count, \"edges\": edges}\n",
261+
"\n",
267262
" if mixed_import_files:\n",
268-
" problematic_cycles.append({\n",
269-
" 'files': scc,\n",
270-
" 'mixed_imports': mixed_import_files,\n",
271-
" 'index': i\n",
272-
" })\n",
273-
" \n",
263+
" problematic_cycles.append({\"files\": scc, \"mixed_imports\": mixed_import_files, \"index\": i})\n",
264+
"\n",
274265
" # Print findings\n",
275266
" print(f\"Found {len(problematic_cycles)} cycles with mixed imports:\")\n",
276267
" for i, cycle in enumerate(problematic_cycles):\n",
277268
" print(f\"\\n⚠️ Problematic Cycle #{i + 1}:\")\n",
278-
" print(f\"\\n⚠️ Index #{cycle[\"index\"]}:\")\n",
269+
" print(f\"\\n⚠️ Index #{cycle['index']}:\")\n",
279270
" print(f\"Size: {len(cycle['files'])} files\")\n",
280-
" \n",
281-
" for (from_file, to_file), data in cycle['mixed_imports'].items():\n",
271+
"\n",
272+
" for (from_file, to_file), data in cycle[\"mixed_imports\"].items():\n",
282273
" print(\"\\n📁 Mixed imports detected:\")\n",
283274
" print(f\" From: {from_file}\")\n",
284275
" print(f\" To: {to_file}\")\n",
285276
" print(f\" Dynamic imports: {data['dynamic']}\")\n",
286277
" print(f\" Static imports: {data['static']}\")\n",
287-
" \n",
278+
"\n",
288279
" return problematic_cycles\n",
289280
"\n",
281+
"\n",
290282
"problematic_loops = find_problematic_import_loops(G, cycles)"
291283
]
292284
},
@@ -388,7 +380,7 @@
388380
" if imp.imported_symbol:\n",
389381
" symbols_to_move.add(imp.imported_symbol)\n",
390382
"\n",
391-
"#Move identified symbols to utils file\n",
383+
"# Move identified symbols to utils file\n",
392384
"for symbol in symbols_to_move:\n",
393385
" symbol.move_to_file(utils_file)\n",
394386
"\n",

examples/removing_import_loops_in_pytorch/utils.py

Lines changed: 41 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -3,75 +3,63 @@ def visualize_graph(graph):
33
Visualize SCC using Graphviz with a strictly enforced circular layout
44
"""
55
import pygraphviz as pgv
6-
6+
77
# Create a new pygraphviz graph directly (instead of converting)
88
A = pgv.AGraph(strict=False, directed=True)
9-
9+
1010
# Set graph attributes for strict circular layout
11-
A.graph_attr.update({
12-
'layout': 'circo',
13-
'root': 'circle',
14-
'splines': 'curved',
15-
'overlap': 'false',
16-
'sep': '+25,25',
17-
'pad': '0.5',
18-
'ranksep': '2.0',
19-
'nodesep': '0.8',
20-
'mindist': '2.0',
21-
'start': 'regular',
22-
'ordering': 'out',
23-
'concentrate': 'false',
24-
'ratio': '1.0',
25-
})
26-
11+
A.graph_attr.update(
12+
{
13+
"layout": "circo",
14+
"root": "circle",
15+
"splines": "curved",
16+
"overlap": "false",
17+
"sep": "+25,25",
18+
"pad": "0.5",
19+
"ranksep": "2.0",
20+
"nodesep": "0.8",
21+
"mindist": "2.0",
22+
"start": "regular",
23+
"ordering": "out",
24+
"concentrate": "false",
25+
"ratio": "1.0",
26+
}
27+
)
28+
2729
# Set node attributes for consistent sizing
28-
A.node_attr.update({
29-
'shape': 'circle',
30-
'fixedsize': 'true',
31-
'width': '1.5',
32-
'height': '1.5',
33-
'style': 'filled',
34-
'fillcolor': 'lightblue',
35-
'fontsize': '11',
36-
'fontname': 'Arial'
37-
})
38-
30+
A.node_attr.update({"shape": "circle", "fixedsize": "true", "width": "1.5", "height": "1.5", "style": "filled", "fillcolor": "lightblue", "fontsize": "11", "fontname": "Arial"})
31+
3932
# Set default edge attributes
40-
A.edge_attr.update({
41-
'penwidth': '1.5',
42-
'arrowsize': '0.8',
43-
'len': '2.0',
44-
'weight': '1',
45-
'dir': 'forward'
46-
})
47-
33+
A.edge_attr.update({"penwidth": "1.5", "arrowsize": "0.8", "len": "2.0", "weight": "1", "dir": "forward"})
34+
4835
# Add nodes first
4936
for node in graph.nodes():
50-
short_name = node.split('/')[-1]
37+
short_name = node.split("/")[-1]
5138
A.add_node(node, label=short_name)
52-
39+
5340
# Add edges with their attributes
5441
for u, v, key, data in graph.edges(data=True, keys=True):
5542
# Create a unique key for this edge
5643
edge_key = f"{u}_{v}_{key}"
57-
44+
5845
# Set edge attributes based on the data
5946
edge_attrs = {
60-
'key': edge_key, # Ensure unique edge
61-
'color': 'red' if data.get('color') == 'red' else '#666666',
62-
'style': 'dashed' if data.get('color') == 'red' else 'solid',
63-
'label': 'dynamic' if data.get('color') == 'red' else '',
64-
'fontcolor': 'red' if data.get('color') == 'red' else '#666666',
65-
'fontsize': '10'
47+
"key": edge_key, # Ensure unique edge
48+
"color": "red" if data.get("color") == "red" else "#666666",
49+
"style": "dashed" if data.get("color") == "red" else "solid",
50+
"label": "dynamic" if data.get("color") == "red" else "",
51+
"fontcolor": "red" if data.get("color") == "red" else "#666666",
52+
"fontsize": "10",
6653
}
67-
54+
6855
A.add_edge(u, v, **edge_attrs)
69-
56+
7057
# Force circo layout with specific settings
71-
A.layout(prog='circo')
72-
58+
A.layout(prog="circo")
59+
7360
# Save with a larger size
74-
A.draw('import_cycle.png', format='png', prog='circo', args='-Gsize=12,12!')
75-
61+
A.draw("import_cycle.png", format="png", prog="circo", args="-Gsize=12,12!")
62+
7663
from IPython.display import Image
77-
return Image('import_cycle.png')
64+
65+
return Image("import_cycle.png")

0 commit comments

Comments
 (0)