Skip to content

Commit 157688a

Browse files
authored
[Op-Authoring] Adding mapping from torch ops to ExprHandles (#205)
* [Op-Authoring] Adding mapping from torch ops to ExprHandles Dependent on PR - pytorch/pytorch#66612 * Cleaing up and bug fix * Retrigger cI * Retirgger CI * Retrigger CI
1 parent 79d31a7 commit 157688a

File tree

2 files changed

+183
-3
lines changed

2 files changed

+183
-3
lines changed

functorch/_src/operator_authoring.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import inspect
44
import itertools
55
from typing import Callable, List, Union, Tuple, Optional
6+
import operator
67

78
import torch
89
from torch import fx
@@ -12,6 +13,51 @@
1213
FOLD_ALIASES = True
1314
_SHAPE_TYPES = {"one", "other"}
1415
_STRIDE_TYPES = {"zero", "one", "contiguous", "transposed_contiguous", "as_arg"}
16+
_identity = lambda x: x
17+
_TORCH_TO_EXPR_MAP = {
18+
"sin": _te.sin,
19+
"cos": _te.cos,
20+
"tan": _te.tan,
21+
"asin": _te.asin,
22+
"acos": _te.acos,
23+
"atan": _te.atan,
24+
"sinh": _te.sinh,
25+
"cosh": _te.cosh,
26+
"tanh": _te.tanh,
27+
"sigmoid": _te.sigmoid,
28+
"exp": _te.exp,
29+
"expm1": _te.expm1,
30+
"abs": _te.abs,
31+
"log": _te.log,
32+
"log2": _te.log2,
33+
"log10": _te.log10,
34+
"log1p": _te.log1p,
35+
"erf": _te.erf,
36+
"erfc": _te.erfc,
37+
"sqrt": _te.sqrt,
38+
"rsqrt": _te.rsqrt,
39+
"ceil": _te.ceil,
40+
"floor": _te.floor,
41+
"round": _te.round,
42+
"trunc": _te.trunc,
43+
"frac": _te.frac,
44+
"lgamma": _te.lgamma,
45+
"isnan": _te.isnan,
46+
"add": operator.add,
47+
"sub": operator.sub,
48+
"subtract": operator.sub,
49+
"mul": operator.mul,
50+
"multiply": operator.mul,
51+
"divide": operator.truediv,
52+
"div": operator.truediv,
53+
"remainder": _te.remainder,
54+
"fmod": _te.fmod,
55+
"pow": _te.pow,
56+
"atan2": _te.atan2,
57+
"detach": _identity,
58+
"neg": lambda x: _create_constant(0.0, torch.float32) - x,
59+
}
60+
1561
_int = _te.ExprHandle.int
1662

1763

@@ -40,8 +86,8 @@ def _combine_dtype(a: torch.dtype, b: torch.dtype):
4086
).dtype
4187

4288

43-
def _fx_replace_constants(fn: Callable, dtype: torch.dtype):
44-
"""Convert the constants in the user function to TensorExpr constants"""
89+
def _fx_to_expr(fn: Callable, dtype: torch.dtype):
90+
"""Convert the fx graph to equivalent Tensor Expr"""
4591

4692
def apply(arg):
4793
if isinstance(arg, (int, float)):
@@ -52,6 +98,27 @@ def apply(arg):
5298
for node in list(gm.graph.nodes):
5399
with gm.graph.inserting_before(node):
54100
node.args = tuple(apply(a) for a in node.args)
101+
if node.op == "call_function":
102+
if node.target.__name__ not in _TORCH_TO_EXPR_MAP:
103+
raise NotImplementedError(
104+
"Missing mapping from op ",
105+
node.target.__name__,
106+
" to Tensor Expr",
107+
)
108+
109+
# Get the parser function to parse the torch op to tensor expr handle
110+
111+
def _parser(*args, op_name):
112+
return _TORCH_TO_EXPR_MAP[op_name](*args)
113+
114+
new_node = gm.graph.create_node(
115+
"call_function",
116+
_parser,
117+
node.args,
118+
{"op_name": node.target.__name__},
119+
)
120+
node.replace_all_uses_with(new_node)
121+
gm.graph.erase_node(node)
55122
gm.recompile()
56123
return gm
57124

@@ -298,7 +365,7 @@ def compute_code(self):
298365
_te.Cast.make(self.dtype, buf.load(self.indexing(stride)))
299366
for buf, stride in zip(input_bufs, input_strides)
300367
]
301-
val = _fx_replace_constants(self.pointwise_fn, self.dtype)(*inputs)
368+
val = _fx_to_expr(self.pointwise_fn, self.dtype)(*inputs)
302369
out = _te.Block(
303370
[
304371
buf.store(self.indexing(stride), val)

test/test_operator_authoring.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,119 @@ def example(x):
126126
x = torch.randn(8, device=self.device)
127127
torch.testing.assert_allclose(x + 3, graph(x))
128128

129+
def test_unary_ops(self):
130+
unary_operators = [
131+
torch.sin,
132+
torch.cos,
133+
torch.tan,
134+
torch.asin,
135+
torch.acos,
136+
torch.atan,
137+
torch.sinh,
138+
torch.cosh,
139+
torch.tanh,
140+
torch.sigmoid,
141+
torch.exp,
142+
torch.expm1,
143+
torch.abs,
144+
torch.log,
145+
torch.log2,
146+
torch.log10,
147+
torch.log1p,
148+
torch.erf,
149+
torch.erfc,
150+
torch.sqrt,
151+
torch.rsqrt,
152+
torch.ceil,
153+
torch.floor,
154+
torch.round,
155+
torch.trunc,
156+
torch.lgamma,
157+
torch.ops.aten.sin,
158+
torch.ops.aten.cos,
159+
torch.ops.aten.tan,
160+
torch.ops.aten.asin,
161+
torch.ops.aten.acos,
162+
torch.ops.aten.atan,
163+
torch.ops.aten.sinh,
164+
torch.ops.aten.cosh,
165+
torch.ops.aten.tanh,
166+
torch.ops.aten.sigmoid,
167+
torch.ops.aten.exp,
168+
torch.ops.aten.expm1,
169+
torch.ops.aten.abs,
170+
torch.ops.aten.log,
171+
torch.ops.aten.log2,
172+
torch.ops.aten.log10,
173+
torch.ops.aten.log1p,
174+
torch.ops.aten.erf,
175+
torch.ops.aten.erfc,
176+
torch.ops.aten.sqrt,
177+
torch.ops.aten.rsqrt,
178+
torch.ops.aten.ceil,
179+
torch.ops.aten.floor,
180+
torch.ops.aten.round,
181+
torch.ops.aten.trunc,
182+
torch.ops.aten.lgamma,
183+
# TODO - Failure in generating loop nests here for the following ops
184+
# torch.frac,
185+
# torch.isnan,
186+
]
187+
188+
for unary_op in unary_operators:
189+
fn = lambda x: unary_op(x)
190+
pointwise_fn = pointwise_operator(fn)
191+
a = torch.rand(2, 3)
192+
ref = fn(a)
193+
res = pointwise_fn(a)
194+
assert torch.allclose(ref, res, atol=1e-3, rtol=1e-3)
195+
196+
def test_binary_ops(self):
197+
binary_operators = [
198+
torch.add,
199+
torch.sub,
200+
torch.subtract,
201+
torch.mul,
202+
torch.multiply,
203+
torch.divide,
204+
torch.div,
205+
torch.fmod,
206+
torch.pow,
207+
torch.atan2,
208+
# torch.remainder, #TODO - Fails allclose check
209+
torch.ops.aten.add,
210+
torch.ops.aten.sub,
211+
torch.ops.aten.subtract,
212+
torch.ops.aten.mul,
213+
torch.ops.aten.multiply,
214+
torch.ops.aten.divide,
215+
torch.ops.aten.div,
216+
torch.ops.aten.fmod,
217+
torch.ops.aten.pow,
218+
torch.ops.aten.atan2,
219+
]
220+
for binary_op in binary_operators:
221+
fn = lambda x, y: binary_op(x, y)
222+
pointwise_fn = pointwise_operator(fn)
223+
a = torch.rand(2, 3)
224+
b = torch.rand(2, 3)
225+
ref = fn(a, b)
226+
res = pointwise_fn(a, b)
227+
assert torch.allclose(ref, res, atol=1e-3, rtol=1e-3)
228+
229+
def test_bias_gelu(self):
230+
def bias_gelu(bias, y):
231+
x = bias + y
232+
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
233+
234+
bias = torch.rand(1, 768)
235+
y = torch.rand(64, 768)
236+
ref = bias_gelu(bias, y)
237+
238+
pointwise_fn = pointwise_operator(bias_gelu)
239+
res = pointwise_fn(bias, y)
240+
assert torch.allclose(ref, res, atol=1e-3, rtol=1e-3)
241+
129242

130243
@unittest.skipIf(not HAS_CUDA, "GPU tests require CUDA")
131244
class TestOperatorAuthoringGPU(TestOperatorAuthoring):

0 commit comments

Comments
 (0)