Skip to content

Commit ca662dc

Browse files
[MLIR][TORCH] Add E2E support for aten.threshold, aten.threshold_backward op
This commit adds lowering of the `aten.threshold` op. This commit also adds lowering of the `aten.threshold_backward` op. Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent 7cf7b91 commit ca662dc

File tree

6 files changed

+392
-5
lines changed

6 files changed

+392
-5
lines changed

e2e_testing/torchscript/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from . import index_select
4848
from . import arange
4949
from . import constant_alloc
50+
from . import threshold
5051

5152
def _get_argparse():
5253
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
e2e_testing/torchscript/threshold.py (new file — path inferred from the `from . import threshold` hunk above; verify)

Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
# Also available under a BSD-style license. See LICENSE.
5+
6+
import torch
7+
8+
from torch_mlir_e2e_test.torchscript.framework import TestUtils
9+
from torch_mlir_e2e_test.torchscript.registry import register_test_case
10+
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
11+
12+
# ==============================================================================
13+
14+
15+
class Threshold1dIntModule(torch.nn.Module):
    """E2E check for aten.threshold on a rank-1 int64 tensor
    (int threshold=1, int replacement value=2)."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.int64, True),
    ])
    def forward(self, input):
        # Elements <= 1 are replaced with 2; others pass through.
        return torch.ops.aten.threshold(input, 1, 2)


@register_test_case(module_factory=lambda: Threshold1dIntModule())
def Threshold1dIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (4,)))
31+
32+
33+
class Threshold2dIntModule(torch.nn.Module):
    """E2E check for aten.threshold on a rank-2 int64 tensor
    (float threshold=0.5, int replacement value=2)."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, input):
        return torch.ops.aten.threshold(input, 0.5, 2)


@register_test_case(module_factory=lambda: Threshold2dIntModule())
def Threshold2dIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (4, 5)))
49+
50+
51+
class Threshold3dIntModule(torch.nn.Module):
    """E2E check for aten.threshold on a rank-3 int64 tensor
    (int threshold=1, float replacement value=2.2)."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.int64, True),
    ])
    def forward(self, input):
        return torch.ops.aten.threshold(input, 1, 2.2)


@register_test_case(module_factory=lambda: Threshold3dIntModule())
def Threshold3dIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (4, 5, 6)))
67+
68+
69+
class Threshold1dFloatModule(torch.nn.Module):
    """E2E check for aten.threshold on a rank-1 float32 tensor
    (threshold=1, replacement value=2)."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
    ])
    def forward(self, input):
        return torch.ops.aten.threshold(input, 1, 2)


@register_test_case(module_factory=lambda: Threshold1dFloatModule())
def Threshold1dFloatModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4))
85+
86+
87+
class Threshold2dFloatModule(torch.nn.Module):
    """E2E check for aten.threshold on a rank-2 float32 tensor
    (threshold=0.5, replacement value=2)."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, input):
        return torch.ops.aten.threshold(input, 0.5, 2)


@register_test_case(module_factory=lambda: Threshold2dFloatModule())
def Threshold2dFloatModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5))
103+
104+
105+
class Threshold3dFloatModule(torch.nn.Module):
    """E2E check for aten.threshold on a rank-3 float32 tensor
    (threshold=1.4, replacement value=2.0)."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, input):
        return torch.ops.aten.threshold(input, 1.4, 2.0)


@register_test_case(module_factory=lambda: Threshold3dFloatModule())
def Threshold3dFloatModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6))
121+
122+
123+
class ThresholdBackward1dIntModule(torch.nn.Module):
    """E2E check for aten.threshold_backward with rank-1 int64
    grad and input tensors (threshold=1)."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.int64, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, grad, input):
        # Gradient flows only where input > threshold; zero elsewhere.
        return torch.ops.aten.threshold_backward(grad, input, 1)


@register_test_case(module_factory=lambda: ThresholdBackward1dIntModule())
def ThresholdBackward1dIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (4,)), torch.randint(8, (4,)))
140+
141+
142+
class ThresholdBackward2dIntModule(torch.nn.Module):
    """E2E check for aten.threshold_backward with rank-2 int64
    grad and input tensors (float threshold=0.5)."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, grad, input):
        return torch.ops.aten.threshold_backward(grad, input, 0.5)


@register_test_case(module_factory=lambda: ThresholdBackward2dIntModule())
def ThresholdBackward2dIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (4, 5)), torch.randint(8, (4, 5)))
159+
160+
161+
class ThresholdBackward3dIntModule(torch.nn.Module):
    """E2E check for aten.threshold_backward with rank-3 int64
    grad and input tensors (threshold=1)."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.int64, True),
        ([-1, -1, -1], torch.int64, True),
    ])
    def forward(self, grad, input):
        return torch.ops.aten.threshold_backward(grad, input, 1)


@register_test_case(module_factory=lambda: ThresholdBackward3dIntModule())
def ThresholdBackward3dIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (4, 5, 6)), torch.randint(8, (4, 5, 6)))
178+
179+
180+
class ThresholdBackward1dFloatModule(torch.nn.Module):
    """E2E check for aten.threshold_backward with rank-1 float32
    grad and input tensors (threshold=1)."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
        ([-1], torch.float32, True),
    ])
    def forward(self, grad, input):
        return torch.ops.aten.threshold_backward(grad, input, 1)


@register_test_case(module_factory=lambda: ThresholdBackward1dFloatModule())
def ThresholdBackward1dFloatModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4), torch.randn(4))
197+
198+
199+
class ThresholdBackward2dFloatModule(torch.nn.Module):
    """E2E check for aten.threshold_backward with rank-2 float32
    grad and input tensors (threshold=0.5)."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, grad, input):
        return torch.ops.aten.threshold_backward(grad, input, 0.5)


@register_test_case(module_factory=lambda: ThresholdBackward2dFloatModule())
def ThresholdBackward2dFloatModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5), torch.randn(4, 5))
216+
217+
218+
class ThresholdBackward3dFloatModule(torch.nn.Module):
    """E2E check for aten.threshold_backward with rank-3 float32
    grad and input tensors (threshold=1.4)."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, grad, input):
        return torch.ops.aten.threshold_backward(grad, input, 1.4)


@register_test_case(module_factory=lambda: ThresholdBackward3dFloatModule())
def ThresholdBackward3dFloatModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.randn(4, 5, 6))
235+
236+
237+
class ThresholdBackward1dMixedModule(torch.nn.Module):
    """E2E check for aten.threshold_backward with mixed dtypes:
    rank-1 float32 grad and int64 input (threshold=1)."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, grad, input):
        return torch.ops.aten.threshold_backward(grad, input, 1)


@register_test_case(module_factory=lambda: ThresholdBackward1dMixedModule())
def ThresholdBackward1dMixedModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4), torch.randint(10, (4,)))
254+
255+
256+
class ThresholdBackward2dMixedModule(torch.nn.Module):
    """E2E check for aten.threshold_backward with mixed dtypes:
    rank-2 int64 grad and float32 input (threshold=0.5)."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, grad, input):
        return torch.ops.aten.threshold_backward(grad, input, 0.5)


@register_test_case(module_factory=lambda: ThresholdBackward2dMixedModule())
def ThresholdBackward2dMixedModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(20, (4, 5)), torch.randn(4, 5))
273+
274+
275+
class ThresholdBackward3dMixedModule(torch.nn.Module):
    """E2E check for aten.threshold_backward with mixed dtypes:
    rank-3 float32 grad and int64 input (threshold=1.4)."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.int64, True),
    ])
    def forward(self, grad, input):
        return torch.ops.aten.threshold_backward(grad, input, 1.4)


@register_test_case(module_factory=lambda: ThresholdBackward3dMixedModule())
def ThresholdBackward3dMixedModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.randint(10, (4, 5, 6)))

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,38 @@ def Torch_AtenBitwiseAnd_TensorOp : Torch_Op<"aten.bitwise_and_.Tensor", [
11401140
let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
11411141
}
11421142

1143+
// ODS definition for `aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)`.
// NOTE(review): generated op — regenerate via the ODS generator rather than
// hand-editing; comments here are for the reader only.
def Torch_AtenThresholdOp : Torch_Op<"aten.threshold", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchScalarType:$threshold,
    AnyTorchScalarType:$value
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $threshold `,` $value attr-dict `:` type($self) `,` type($threshold) `,` type($value) `->` type($result)";
}
1158+
1159+
// ODS definition for the in-place variant
// `aten::threshold_ : (Tensor, Scalar, Scalar) -> (Tensor)`.
// NOTE(review): generated op — regenerate via the ODS generator.
def Torch_AtenThreshold_Op : Torch_Op<"aten.threshold_", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::threshold_ : (Tensor, Scalar, Scalar) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchScalarType:$threshold,
    AnyTorchScalarType:$value
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $threshold `,` $value attr-dict `:` type($self) `,` type($threshold) `,` type($value) `->` type($result)";
}
1174+
11431175
def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
11441176
AllowsTypeRefinement,
11451177
HasValueSemantics
@@ -1249,6 +1281,22 @@ def Torch_AtenPowTensorScalarOp : Torch_Op<"aten.pow.Tensor_Scalar", [
12491281
let assemblyFormat = "$self `,` $exponent attr-dict `:` type($self) `,` type($exponent) `->` type($result)";
12501282
}
12511283

1284+
// ODS definition for
// `aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)`.
// NOTE(review): generated op — regenerate via the ODS generator.
def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$grad_output,
    AnyTorchTensorType:$self,
    AnyTorchScalarType:$threshold
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$grad_output `,` $self `,` $threshold attr-dict `:` type($grad_output) `,` type($self) `,` type($threshold) `->` type($result)";
}
1299+
12521300
def Torch_AtenTriuOp : Torch_Op<"aten.triu", [
12531301
AllowsTypeRefinement,
12541302
HasValueSemantics

0 commit comments

Comments
 (0)