
Commit 077e55d

ljfitz authored and cathyzhyi committed
Add support for constant_pad_nd
Note that to enable folding of the code coming from an example like the ConstantPad2dStaticModule e2e test, support for other operations had to be added/improved:
- aten::neg.int
- aten::eq.float
- aten::eq.str
- prim::Uninitialized
1 parent 35cf8d1 commit 077e55d

File tree

10 files changed: 391 additions, 42 deletions


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 *.swp
+.cache/
 .vscode
 .env
 *.code-workspace

e2e_testing/torchscript/basic.py

Lines changed: 80 additions & 25 deletions
@@ -11,7 +11,6 @@
 
 # ==============================================================================
 
-
 class MmModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -38,7 +37,6 @@ def MmModule_chained(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class BmmModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -57,10 +55,8 @@ def forward(self, lhs, rhs):
 def BmmModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))
 
-
 # ==============================================================================
 
-
 # A subgraph with multiple mm ops.
 class MmDagModule(torch.nn.Module):
     def __init__(self):
@@ -80,10 +76,8 @@ def forward(self, lhs, rhs):
 def MmDagModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 4), tu.rand(4, 4))
 
-
 # ==============================================================================
 
-
 class MmTanhModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -100,16 +94,13 @@ def forward(self, lhs, rhs):
     def matmul(self, lhs, rhs):
         return torch.mm(lhs, rhs)
 
-# ==============================================================================
-
 
 @register_test_case(module_factory=lambda: MmTanhModule())
 def MmTanhModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 2), tu.rand(2, 4))
 
 # ==============================================================================
 
-
 class AddmmModuleFloat(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -196,7 +187,6 @@ def AdaptiveAvgPool2dModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class FlattenStaticModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -217,7 +207,6 @@ def FlattenStaticModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class FlattenRank0Module(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -238,7 +227,6 @@ def FlattenRank0Module_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class FlattenDynamicModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -259,7 +247,6 @@ def FlattenDynamicModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class MaxPool2dModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -276,14 +263,86 @@ def __init__(self):
     def forward(self, x):
         return self.mp2d(x)
 
-# ==============================================================================
-
 
 @register_test_case(module_factory=lambda: MaxPool2dModule())
 def MaxPool2dModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 1, 20, 20) - 0.5)
 
 
+class ConstantPad2dStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.pad2d = torch.nn.ConstantPad2d((0, 1, 2, 3), -float('inf'))
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 1, 20, 20], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.pad2d(x)
+
+
+@register_test_case(module_factory=lambda: ConstantPad2dStaticModule())
+def ConstantPad2dStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 20, 20) - 0.5)
+
+# ==============================================================================
+
+class ConstantPadNdModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.constant_pad_nd(x, (0, 1), -float('inf'))
+
+
+@register_test_case(module_factory=lambda: ConstantPadNdModule())
+def ConstantPadNdModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)
+
+
+class ConstantPadNdStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 1, 20, 20, 4, 4], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.constant_pad_nd(x, (0, 1), -float('inf'))
+
+
+@register_test_case(module_factory=lambda: ConstantPadNdStaticModule())
+def ConstantPadNdStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)
+
+class ConstantPadNdPartialStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 1, 20, 20, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.constant_pad_nd(x, (0, 1, 2, 3), -float('inf'))
+
+
+@register_test_case(module_factory=lambda: ConstantPadNdPartialStaticModule())
+def ConstantPadNdPartialStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)
+
+# ==============================================================================
+
 class TransposeIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -296,13 +355,13 @@ def __init__(self):
     def forward(self, x):
         return torch.transpose(x, 0, 1)
 
-# ==============================================================================
-
 
 @register_test_case(module_factory=lambda: TransposeIntModule())
 def TransposeIntModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 2))
 
+# ==============================================================================
+
 class PermuteModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -333,13 +392,12 @@ def __init__(self):
     def forward(self, x):
         return torch.transpose(x, -1, -2)
 
-# ==============================================================================
-
 
 @register_test_case(module_factory=lambda: TransposeIntNegDimsModule())
 def TransposeIntNegDimsModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 2))
 
+# ==============================================================================
 
 class PermuteNegativeIndexModule(torch.nn.Module):
     def __init__(self):
@@ -353,11 +411,12 @@ def __init__(self):
     def forward(self, x):
         return x.permute(0, -1, 1)
 
-# ==============================================================================
-
 @register_test_case(module_factory=lambda: PermuteNegativeIndexModule())
 def PermuteNegativeIndexModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 2))
+
+# ==============================================================================
+
 class TensorsConcatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -379,7 +438,6 @@ def TensorsConcatModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class GatherModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -422,7 +480,6 @@ def AddSizeIntModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class AddSizeIntNegDimModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -505,7 +562,6 @@ def _SoftmaxModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class SoftmaxIntNegDimModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -527,7 +583,6 @@ def SoftmaxIntNegDimModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class SoftmaxIntArgTypeF64Module(torch.nn.Module):
     def __init__(self):
         super().__init__()

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 63 additions & 0 deletions
@@ -1778,6 +1778,22 @@ def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
   let assemblyFormat = "$self `,` $target `,` $weight `,` $reduction `,` $ignore_index attr-dict `:` type($self) `,` type($target) `,` type($weight) `,` type($reduction) `,` type($ignore_index) `->` type($output) `,` type($total_weight)";
 }
 
+def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    TorchIntListType:$pad,
+    AnyTorchScalarType:$value
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $pad `,` $value attr-dict `:` type($self) `,` type($pad) `,` type($value) `->` type($result)";
+}
+
 def Torch_AtenSqueezeDimOp : Torch_Op<"aten.squeeze.dim", [
     AllowsTypeRefinement
   ]> {
@@ -2915,6 +2931,22 @@ def Torch_AtenAddStrOp : Torch_Op<"aten.add.str", [
   let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
 }
 
+def Torch_AtenEqStrOp : Torch_Op<"aten.eq.str", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::eq.str : (str, str) -> (bool)`";
+  let arguments = (ins
+    Torch_StringType:$a,
+    Torch_StringType:$b
+  );
+  let results = (outs
+    Torch_BoolType:$result
+  );
+  let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
+  let hasFolder = 1;
+}
+
 def Torch_AtenStrOp : Torch_Op<"aten.str", [
     AllowsTypeRefinement,
     HasValueSemantics
@@ -3175,6 +3207,21 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [
   let hasFolder = 1;
 }
 
+def Torch_AtenNegIntOp : Torch_Op<"aten.neg.int", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::neg.int : (int) -> (int)`";
+  let arguments = (ins
+    Torch_IntType:$a
+  );
+  let results = (outs
+    Torch_IntType:$result
+  );
+  let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
+  let hasFolder = 1;
+}
+
 def Torch_AtenLogIntOp : Torch_Op<"aten.log.int", [
    AllowsTypeRefinement,
    HasValueSemantics
@@ -3248,6 +3295,22 @@ def Torch_AtenLtFloatIntOp : Torch_Op<"aten.lt.float_int", [
   let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
 }
 
+def Torch_AtenEqFloatOp : Torch_Op<"aten.eq.float", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::eq.float : (float, float) -> (bool)`";
+  let arguments = (ins
+    Torch_FloatType:$a,
+    Torch_FloatType:$b
+  );
+  let results = (outs
+    Torch_BoolType:$result
+  );
+  let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
+  let hasFolder = 1;
+}
+
 def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [
     AllowsTypeRefinement,
     HasValueSemantics

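Editor's note: the commit message explains that folding aten::neg.int (and the other scalar ops above) is what lets the pad list in the ConstantPad2dStaticModule example be computed at compile time. The `let hasFolder = 1` lines only declare the fold hooks; the hooks themselves are implemented elsewhere (TorchOps.cpp). As a rough, hypothetical sketch of what such a folder could look like for `aten.neg.int` — an assumption about the shape of the code, not the implementation this commit adds:

// Editor's sketch, not part of this commit: one plausible implementation of
// the fold hook declared by `hasFolder = 1` on Torch_AtenNegIntOp above.
// Assumes the usual torch-mlir setup in lib/Dialect/Torch/IR/TorchOps.cpp.
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch::Torch;

OpFoldResult AtenNegIntOp::fold(ArrayRef<Attribute> operands) {
  int64_t value;
  // `a()` is the accessor generated for the `$a` operand declared above;
  // m_TorchConstantInt is the existing matcher for torch.constant.int.
  if (matchPattern(a(), m_TorchConstantInt(&value)))
    return IntegerAttr::get(IntegerType::get(getContext(), 64), -value);
  return nullptr; // Operand is not a compile-time constant: nothing to fold.
}
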
include/torch-mlir/Dialect/Torch/IR/GeneratedPrimOps.td

Lines changed: 1 addition & 0 deletions
@@ -185,6 +185,7 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [
     AnyTorchType:$result
   );
   let assemblyFormat = " attr-dict `:` type($result)";
+  let hasCanonicalizer = 1;
 }
 
 def Torch_PrimUncheckedCastOp : Torch_Op<"prim.unchecked_cast", [

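Editor's note: setting `hasCanonicalizer = 1` declares a `getCanonicalizationPatterns` hook for `prim.Uninitialized`, one of the ops the commit message lists as needing improvement so the padding code can be folded away. A minimal, hypothetical sketch of that hook, assuming it simply deletes `prim.Uninitialized` values that end up unused; the pattern this commit actually registers in TorchOps.cpp may do something different:

// Editor's sketch, not part of this commit: the general shape of the
// canonicalization hook enabled by `hasCanonicalizer = 1`. The assumption
// here is that an unused prim.Uninitialized value can simply be erased.
// (Same headers/using-declarations as the aten.neg.int sketch above.)
void PrimUninitializedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](PrimUninitializedOp op, PatternRewriter &rewriter) {
    // Only fire once nothing uses the uninitialized value anymore.
    if (!op.use_empty())
      return failure();
    rewriter.eraseOp(op);
    return success();
  });
}
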
include/torch-mlir/Dialect/Torch/IR/TorchOps.h

Lines changed: 21 additions & 0 deletions
@@ -45,6 +45,21 @@ struct torch_constant_int_op_binder {
     return false;
   }
 };
+
+struct torch_constant_float_op_binder {
+  double *bind_value;
+
+  /// Creates a matcher instance that binds the value to bv if match succeeds.
+  torch_constant_float_op_binder(double *bv) : bind_value(bv) {}
+
+  bool match(Operation *op) {
+    if (auto constantFloat = dyn_cast<Torch::ConstantFloatOp>(op)) {
+      *bind_value = constantFloat.value().convertToDouble();
+      return true;
+    }
+    return false;
+  }
+};
 } // namespace detail
 
 /// Matches the integer stored in a `torch.constant.bool`.
@@ -53,6 +68,12 @@ m_TorchConstantInt(int64_t *bind_value) {
   return detail::torch_constant_int_op_binder(bind_value);
 }
 
+/// Matches the float value stored in a `torch.constant.float`.
+inline detail::torch_constant_float_op_binder
+m_TorchConstantFloat(double *bind_value) {
+  return detail::torch_constant_float_op_binder(bind_value);
+}
+
 namespace detail {
 /// Matches the bool stored in a `torch.constant.bool`.
 struct torch_constant_bool_op_binder {

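Editor's note: the new `m_TorchConstantFloat` matcher mirrors the existing `m_TorchConstantInt` and is the kind of helper a folder for the new `aten.eq.float` op would use. A hypothetical usage sketch — an assumption about how the pieces fit together, not the folder this commit actually adds:

// Editor's sketch, not part of this commit: folding aten.eq.float when both
// operands come from torch.constant.float ops, using the matcher added above.
// (Same headers/using-declarations as the aten.neg.int sketch earlier.)
OpFoldResult AtenEqFloatOp::fold(ArrayRef<Attribute> operands) {
  double lhs, rhs;
  // `a()`/`b()` are the accessors generated for the `$a`/`$b` operands.
  if (matchPattern(a(), m_TorchConstantFloat(&lhs)) &&
      matchPattern(b(), m_TorchConstantFloat(&rhs)))
    return BoolAttr::get(getContext(), lhs == rhs);
  return nullptr; // Operands not statically known: leave the op in place.
}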