Skip to content

Commit 86653a4

Browse files
[TK] Subgraph Tracing to support control flow (#350)
This PR supports tracing a function within a separate tracer as a "subgraph". This allows us to trace loop bodies instead of unrolling them. Along with subgraph tracing, this patch adds several features to enable tracing a gemm kernel: - Instead of doing slicing, add explicit tkl.load/tkl.store ops. Personally, I feel slicing may not be the way to go forward as we always have a constant sized output from the slice. If we go with slicing, we have to analyze if it's constant sized which is not worth it. - Add support for tkl.constant, tkl.dot, tkl.for_loop. - Add a new "Vector" class, which is a tensor like class supporting computations over it. I'm not a fan of using pytorch ops directly since I don't get control over the op signature. I did not add support for eager executing these operations, only compile mode. All of these newly added ops can be eagerly executed, just haven't added the support in this patch. --------- Co-authored-by: Stella Laurenzo <[email protected]>
1 parent 6b21267 commit 86653a4

File tree

18 files changed

+896
-65
lines changed

18 files changed

+896
-65
lines changed

python/shark_turbine/kernel/_support/indexing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class NotSetType:
4141
class ElementType(ABC):
4242
@staticmethod
4343
def cast(something) -> "ElementType":
44-
if isinstance(something, torch.dtyp):
44+
if isinstance(something, torch.dtype):
4545
return TorchElementType(something)
4646
else:
4747
raise TypeError(
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
from typing import (
2+
Optional,
3+
TypeVar,
4+
Callable,
5+
Type,
6+
assert_type,
7+
cast,
8+
List,
9+
Dict,
10+
Tuple,
11+
)
12+
import random
13+
import contextlib
14+
15+
import torch.fx as fx
16+
import torch.utils._pytree as pytree
17+
18+
19+
class RegionGraph:
20+
def __init__(self):
21+
self.tracers: List["SubgraphTracer"] = []
22+
self.subgraphs: Dict[str, fx.Graph] = dict()
23+
self.inner_freevars: Dict[fx.Graph, List[fx.Proxy]] = dict()
24+
25+
@property
26+
def root_tracer(self) -> "SubgraphTracer":
27+
return self.tracers[0]
28+
29+
@property
30+
def current_tracer(self) -> "SubgraphTracer":
31+
return self.tracers[-1]
32+
33+
def create_proxy(self, *args, **kwargs):
34+
return self.current_tracer.create_proxy(*args, **kwargs)
35+
36+
def create_node(self, *args, **kwargs):
37+
return self.current_tracer.create_node(*args, **kwargs)
38+
39+
def create_arg(self, *args, **kwargs):
40+
return self.current_tracer.create_arg(*args, **kwargs)
41+
42+
def new_subtracer(
43+
self, region_graph: "RegionGraph", parent: Optional["SubgraphTracer"] = None
44+
) -> "SubgraphTracer":
45+
...
46+
47+
### ========================================================================
48+
### Subgraph Tracing
49+
### ========================================================================
50+
def add_subgraph(
51+
self, name: str, graph: fx.Graph, inner_freevars: List[fx.Proxy]
52+
) -> str:
53+
i = 0
54+
while True:
55+
candidate_name = f"{name}_{i}"
56+
i += 1
57+
if candidate_name not in self.subgraphs:
58+
self.subgraphs[candidate_name] = graph
59+
self.inner_freevars[graph] = inner_freevars
60+
return candidate_name
61+
62+
@contextlib.contextmanager
63+
def subtracer(self):
64+
if self.tracers:
65+
new_tracer = self.new_subtracer(self, self.current_tracer)
66+
else:
67+
new_tracer = self.new_subtracer(self)
68+
self.tracers.append(new_tracer)
69+
yield new_tracer
70+
self.tracers.pop()
71+
72+
def __str__(self):
73+
out = ""
74+
for name, subgraph in self.subgraphs.items():
75+
out += f"{name}:"
76+
out += str(subgraph)
77+
out += "\n"
78+
return out
79+
80+
81+
class SubgraphTracer(fx.Tracer):
82+
def __init__(
83+
self, region_graph: RegionGraph, parent: Optional["SubgraphTracer"] = None
84+
):
85+
super().__init__()
86+
self.graph = fx.Graph()
87+
self.region_graph = region_graph
88+
self.parent = parent
89+
self.lifted_freevars: Dict[fx.Proxy, fx.Proxy] = {}
90+
91+
def trace(self, *args, **kwargs) -> Tuple[str, List[fx.Proxy]]:
92+
traced = super().trace(*args, **kwargs)
93+
inner_freevars = list(self.lifted_freevars.values())
94+
implicit_capture = list(self.lifted_freevars.keys())
95+
subgraph_name = self.region_graph.add_subgraph("region", traced, inner_freevars)
96+
return subgraph_name, implicit_capture
97+
98+
def _create_graph_input(self, name: str, type_expr=None) -> fx.Proxy:
99+
proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
100+
# Can use this to check where the freevar has been lifted from.
101+
proxy.node.meta["lifted"] = None
102+
return proxy
103+
104+
def _lift_tracked_freevar_to_input(self, proxy: fx.Proxy):
105+
# It makes no sense for the root graph to have free variables
106+
assert self.parent is not None, "Cannot lift freevars to input in root tracer"
107+
108+
# If the freevar has already been lifted, return the lifted version.
109+
if proxy in self.lifted_freevars:
110+
return self.lifted_freevars[proxy]
111+
112+
# Otherwise, create a new input and store it.
113+
new_proxy = self._create_graph_input(proxy.node.name, proxy.node.type)
114+
self.lifted_freevars[proxy] = new_proxy
115+
116+
# Propagate freevar usage upwards.
117+
if self.parent is not None and proxy.tracer != self.parent:
118+
self.parent._lift_tracked_freevar_to_input(proxy)
119+
return new_proxy
120+
121+
def _maybe_lift_tracked_freevar_to_input(self, arg):
122+
"""
123+
If arg is a free variable, then lift it to be an input.
124+
Returns the new lifted arg (if lifted), else the original arg.
125+
"""
126+
if not isinstance(arg, fx.Proxy):
127+
return arg
128+
elif arg.tracer == self:
129+
return arg
130+
else:
131+
return self._lift_tracked_freevar_to_input(arg)
132+
133+
def create_proxy(
134+
self,
135+
kind,
136+
target,
137+
args,
138+
kwargs,
139+
name=None,
140+
type_expr=None,
141+
proxy_factor_fn=None,
142+
):
143+
if self.parent is not None:
144+
flat_args, tree_spec = pytree.tree_flatten((args, kwargs))
145+
new_flat_args = []
146+
for arg in flat_args:
147+
maybe_new_arg = self._maybe_lift_tracked_freevar_to_input(arg)
148+
new_flat_args.append(maybe_new_arg)
149+
args, kwargs = pytree.tree_unflatten(new_flat_args, tree_spec)
150+
151+
rv = super().create_proxy(
152+
kind,
153+
target,
154+
args,
155+
kwargs,
156+
name,
157+
type_expr,
158+
proxy_factor_fn,
159+
)
160+
161+
return rv

python/shark_turbine/kernel/_support/tracing.py

Lines changed: 131 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,22 @@
11
from abc import ABC, abstractmethod
2-
from typing import Optional, TypeVar, Callable, Type, assert_type, cast
2+
from typing import (
3+
Optional,
4+
TypeVar,
5+
Callable,
6+
Type,
7+
assert_type,
8+
cast,
9+
List,
10+
Dict,
11+
Tuple,
12+
Any,
13+
)
314

415
import functools
516
import warnings
17+
import contextlib
18+
import torch.utils._pytree as pytree
19+
import random
620

721
import torch.fx as fx
822

@@ -15,8 +29,12 @@
1529

1630
from ..lang.types import (
1731
Index,
32+
Vector,
1833
)
1934

35+
from .regions import RegionGraph, SubgraphTracer
36+
37+
2038
from .. import ops
2139
from ..ops.base import (
2240
OpDispatcher,
@@ -26,6 +44,20 @@
2644

2745
TCallable = TypeVar("TCallable", bound=Callable)
2846

47+
###############################################################################
48+
# Kernel Region Graph
49+
###############################################################################
50+
51+
52+
class KernelRegionGraph(RegionGraph):
53+
def new_subtracer(
54+
self,
55+
region_graph: "RegionGraph",
56+
parent: Optional["SubgraphTracer"] = None,
57+
) -> "KernelTracer":
58+
return KernelTracer(region_graph, parent=parent)
59+
60+
2961
###############################################################################
3062
# Tracing machinery
3163
###############################################################################
@@ -35,7 +67,10 @@ class KernelBufferProxy(fx.Proxy):
3567
"""Custom proxy for KernelBuffer so that we can override special methods."""
3668

3769
def __init__(
38-
self, node: fx.Node, tracer: "KernelTracer", orig_type: Type[KernelBuffer]
70+
self,
71+
node: fx.Node,
72+
tracer: "KernelTracer",
73+
orig_type: Type[KernelBuffer],
3974
):
4075
super().__init__(node, tracer)
4176
self._orig_type = orig_type
@@ -50,9 +85,10 @@ def __setitem__(self, key, item):
5085
ops.kernel_buffer_setitem(self, key, item)
5186

5287

53-
class KernelTracer(fx.Tracer):
88+
class KernelTracer(SubgraphTracer):
5489
"""Custom Tracer for generating a trace of a kernel computation."""
5590

91+
# Register our custom proxies.
5692
def proxy(self, node: fx.Node) -> fx.Proxy:
5793
t = node.type
5894
if t is not None and issubclass(t, KernelBuffer):
@@ -61,8 +97,15 @@ def proxy(self, node: fx.Node) -> fx.Proxy:
6197

6298

6399
class CapturedTrace:
64-
def __init__(self, gm: fx.GraphModule):
65-
self.gm = gm
100+
def __init__(self, region_graph: RegionGraph, root_graph: str):
101+
self.region_graph = region_graph
102+
self.root_graph = root_graph
103+
104+
def get_subgraph(self, name: str) -> fx.Graph:
105+
return self.region_graph.subgraphs[name]
106+
107+
def get_root_graph(self) -> fx.Graph:
108+
return self.get_subgraph(self.root_graph)
66109

67110

68111
###############################################################################
@@ -109,18 +152,22 @@ def handle_kernel_buffer_setitem(self, op, kernel_buffer: KernelBuffer, key, ite
109152

110153

111154
class CompiledContext(BaseContext):
112-
def __init__(self, tracer: KernelTracer, *, grid_type: Type[Grid]):
155+
def __init__(self, region_graph: RegionGraph, *, grid_type: Type[Grid]):
113156
super().__init__(eager=False)
114-
self.tracer = tracer
157+
self.region_graph = region_graph
115158
self.grid_type = grid_type
116159

160+
### ========================================================================
161+
### Core Operations
162+
### ========================================================================
163+
117164
def handle_thread_program_id(self, op, axis: int) -> Index:
118165
grid_shape = self.grid_type.symbolic_shape
119166
if axis < 0 or axis >= len(grid_shape):
120167
raise IndexError(
121168
f"Illegal index into grid of rank {len(grid_shape)}: {axis}"
122169
)
123-
proxy = self.tracer.create_proxy(
170+
proxy = self.region_graph.create_proxy(
124171
"call_function",
125172
op,
126173
args=(axis,),
@@ -130,21 +177,95 @@ def handle_thread_program_id(self, op, axis: int) -> Index:
130177
return proxy
131178

132179
def handle_kernel_buffer_getitem(self, op, kernel_buffer: KernelBuffer, key):
133-
return self.tracer.create_proxy(
180+
return self.region_graph.create_proxy(
134181
"call_function",
135182
op,
136183
args=(kernel_buffer, key),
137184
kwargs={},
138185
)
139186

140187
def handle_kernel_buffer_setitem(self, op, kernel_buffer: KernelBuffer, key, item):
141-
self.tracer.create_proxy(
188+
self.region_graph.create_proxy(
142189
"call_function",
143190
target=op,
144191
args=(kernel_buffer, key, item),
145192
kwargs={},
146193
)
147194

195+
### ========================================================================
196+
### Memory Operations
197+
### ========================================================================
198+
def handle_kernel_buffer_load(self, op, kernel_buffer, multi_index, shape):
199+
return self.region_graph.create_proxy(
200+
"call_function",
201+
target=op,
202+
args=(kernel_buffer, multi_index, shape),
203+
kwargs={},
204+
)
205+
206+
def handle_kernel_buffer_store(self, op, kernel_buffer, multi_index, item):
207+
self.region_graph.create_proxy(
208+
"call_function",
209+
target=op,
210+
args=(kernel_buffer, multi_index, item),
211+
kwargs={},
212+
)
213+
214+
### ========================================================================
215+
### Control Flow Operations
216+
### ========================================================================
217+
218+
def handle_for_loop(self, op, start, stop=None, step=None, init_args=[]):
219+
if stop is None:
220+
stop = start
221+
start = 0
222+
if step is None:
223+
step = 1
224+
225+
def wrapper(f):
226+
with self.region_graph.subtracer() as subtracer:
227+
subgraph_name, implicit_capture = subtracer.trace(f)
228+
# Create a call to this subgraph
229+
ret = self.region_graph.create_proxy(
230+
"call_function",
231+
target=op,
232+
name="for_loop",
233+
args=(start, stop, step, init_args),
234+
kwargs={
235+
"subgraph": subgraph_name,
236+
"implicit_capture": implicit_capture,
237+
},
238+
)
239+
return ret
240+
241+
return wrapper
242+
243+
### ========================================================================
244+
### Math Operations
245+
### ========================================================================
246+
247+
def handle_vector_constant(
248+
self, op, shape: Tuple[int, ...], dtype, value: int | float
249+
):
250+
return self.region_graph.create_proxy(
251+
"call_function",
252+
target=op,
253+
args=(shape, dtype, value),
254+
kwargs={},
255+
)
256+
257+
### ========================================================================
258+
### Reduction Operations
259+
### ========================================================================
260+
261+
def handle_vector_dot(self, op, lhs, rhs, acc):
262+
return self.region_graph.create_proxy(
263+
"call_function",
264+
target=op,
265+
args=(lhs, rhs, acc),
266+
kwargs={},
267+
)
268+
148269

149270
###############################################################################
150271
# Launch context

0 commit comments

Comments
 (0)