
Commit dc9ea08

[MLIR][ONNX] Add OnnxToTorch support for atan and bitwise ops
This commit adds the OnnxToTorch support for the Atan, BitShift, BitwiseAnd, and BitwiseNot ops. It also adds the TorchToLinalg support for AtenBitwiseLeftShiftTensorOp. Signed-off-by: [email protected]
1 parent 53fc995 commit dc9ea08
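
For context, a minimal sketch of the new BitShift lowering in IR form (the operand names and si32 tensor types are illustrative, not taken from this commit's tests):

    // ONNX op as imported into the torch dialect, carrying the
    // "direction" attribute under the torch.onnx. prefix.
    %shl = torch.operator "onnx.BitShift"(%x, %y) {torch.onnx.direction = "LEFT"}
        : (!torch.vtensor<[3,4],si32>, !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32>

    // After the OnnxToTorch pattern: rewritten to the new ATen op, which
    // TorchToLinalg then lowers to an elementwise arith.shli payload.
    %shl = torch.aten.bitwise_left_shift.Tensor %x, %y
        : !torch.vtensor<[3,4],si32>, !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],si32>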

9 files changed: +337 additions, −21 deletions

include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h

Lines changed: 17 additions & 0 deletions
@@ -17,6 +17,7 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/SmallVector.h"
+#include <string>
 
 namespace mlir::torch::onnx_c {
 
@@ -103,6 +104,22 @@ struct OpBinder {
     return failure();
   }
 
+  ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix,
+                                     std::string defaultValue = "") {
+    SmallString<64> name("torch.onnx.");
+    name.append(nameSuffix);
+    auto attr = op->getAttr(name);
+    if (!attr) {
+      value = defaultValue;
+      return success();
+    }
+    if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
+      value = stringAttr.str();
+      return success();
+    }
+    return failure();
+  }
+
   Torch::ValueTensorType toValidTensorType(Type t) {
     auto tt = dyn_cast<Torch::ValueTensorType>(t);
     if (tt && tt.hasSizes())
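
The new helper mirrors the existing binder methods, but for string attributes: it looks up torch.onnx.<nameSuffix> on the op, writes the default when the attribute is absent, and fails only on a non-string attribute. A sketch of an imported op it would read (operands and types assumed for illustration):

    %r = torch.operator "onnx.BitShift"(%a, %b) {torch.onnx.direction = "RIGHT"}
        : (!torch.vtensor<[2],ui8>, !torch.vtensor<[2],ui8>) -> !torch.vtensor<[2],ui8>

Here customOpNameStringAttr(direction, "direction", "") binds direction = "RIGHT"; on an op with no such attribute it succeeds with the empty default instead of failing the match.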

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 47 additions & 0 deletions
@@ -2891,6 +2891,53 @@ def Torch_AtenBitwiseXor_TensorOp : Torch_Op<"aten.bitwise_xor_.Tensor", [
   }];
 }
 
+def Torch_AtenBitwiseLeftShiftTensorOp : Torch_Op<"aten.bitwise_left_shift.Tensor", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::bitwise_left_shift.Tensor : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBitwiseLeftShiftTensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenBitwiseLeftShiftTensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenBitwiseLeftShift_TensorOp : Torch_Op<"aten.bitwise_left_shift_.Tensor", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::bitwise_left_shift_.Tensor : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self,
+    Torch_NonValueTensorType:$other
+  );
+  let results = (outs
+    Torch_NonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBitwiseLeftShift_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenBitwiseLeftShift_TensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenBitwiseRightShiftTensorOp : Torch_Op<"aten.bitwise_right_shift.Tensor", [
     AllowsTypeRefinement,
     HasValueSemantics,
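
Both definitions use the default torch assembly format (two operands, one result). A sketch of the value-semantics op in IR, with shapes and dtype assumed for illustration:

    %out = torch.aten.bitwise_left_shift.Tensor %self, %other
        : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64>

The in-place variant prints the same way, but as torch.aten.bitwise_left_shift_.Tensor on non-value !torch.tensor operands.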

lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

Lines changed: 53 additions & 1 deletion
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -141,6 +142,57 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
   });
   // TODO: Asin unimplemented in torch-mlir
   // TODO: Asinh unimplemented in torch-mlir
-  // TODO: Atan unimplemented in torch-mlir
   // TODO: Atanh unimplemented in torch-mlir
+  patterns.onOp("Atan", 7,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value operand;
+                  if (binder.tensorOperand(operand) ||
+                      binder.tensorResultType(resultType))
+                    return failure();
+                  rewriter.replaceOpWithNewOp<Torch::AtenAtanOp>(
+                      binder.op, resultType, operand);
+                  return success();
+                });
+  patterns.onOp(
+      "BitShift", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value lhs, rhs;
+        std::string direction;
+        if (binder.tensorOperands(lhs, rhs) ||
+            binder.tensorResultType(resultType) ||
+            binder.customOpNameStringAttr(direction, "direction", ""))
+          return failure();
+        if (direction == "LEFT") {
+          rewriter.replaceOpWithNewOp<Torch::AtenBitwiseLeftShiftTensorOp>(
+              binder.op, resultType, lhs, rhs);
+        } else {
+          rewriter.replaceOpWithNewOp<Torch::AtenBitwiseRightShiftTensorOp>(
+              binder.op, resultType, lhs, rhs);
+        }
+        return success();
+      });
+  patterns.onOp(
+      "BitwiseAnd", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value lhs, rhs;
+        if (binder.tensorOperands(lhs, rhs) ||
+            binder.tensorResultType(resultType))
+          return failure();
+        rewriter.replaceOpWithNewOp<Torch::AtenBitwiseAndTensorOp>(
+            binder.op, resultType, lhs, rhs);
+        return success();
+      });
+  patterns.onOp("BitwiseNot", 18,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value operand;
+                  if (binder.tensorOperand(operand) ||
+                      binder.tensorResultType(resultType))
+                    return failure();
+                  rewriter.replaceOpWithNewOp<Torch::AtenBitwiseNotOp>(
+                      binder.op, resultType, operand);
+                  return success();
+                });
 }
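
The Atan pattern is a straight one-to-one rewrite. A sketch of the before/after IR (shapes assumed for illustration):

    // Imported ONNX op:
    %0 = torch.operator "onnx.Atan"(%arg0)
        : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
    // After OnnxToTorch:
    %0 = torch.aten.atan %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>

Note that in the BitShift pattern any direction value other than "LEFT" (including the empty default) falls through to the right-shift op; the pattern does not reject an unexpected direction string.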

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 35 additions & 20 deletions
@@ -366,6 +366,20 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
     return b.create<arith::ShRSIOp>(loc, lhs, rhs);
   }
+  if (auto bitwiseLeftShiftTensor =
+          dyn_cast<AtenBitwiseLeftShiftTensorOp>(op)) {
+    Type dtype = converter->convertType(bitwiseLeftShiftTensor.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    if (!dtype.isa<mlir::IntegerType>()) {
+      bitwiseLeftShiftTensor.emitError(
+          "Bitwise_Left_Shift op does not support non-integer input dtype.");
+      return nullptr;
+    }
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+    return b.create<arith::ShLIOp>(loc, lhs, rhs);
+  }
   if (isa<AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp>(op)) {
     MLIRContext *context = op->getContext();
     Type floatDtype = mlir::FloatType::getF64(context);
@@ -1252,16 +1266,17 @@ class ConvertElementwiseOp : public ConversionPattern {
           AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp,
           AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
           AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
-          AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
-          AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
-          AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
-          AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
-          AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp,
-          AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
-          AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
-          AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp,
-          AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
-          AtenFillTensorOp, AtenAtanOp, AtenRealOp, AtenImagOp>(op))
+          AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp,
+          AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
+          AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
+          AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp,
+          AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
+          AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp,
+          AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
+          AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
+          AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
+          AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
+          AtenAtanOp, AtenRealOp, AtenImagOp>(op))
     return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
   if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1900,16 +1915,16 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
       AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
       AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
-      AtenBitwiseXorTensorOp, AtenBitwiseRightShiftTensorOp, AtenGtScalarOp,
-      AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
-      AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
-      AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp,
-      AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp,
-      AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
-      AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp, AtenLogicalNotOp,
-      AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp, AtenBitwiseNotOp,
-      AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenRealOp,
-      AtenImagOp>();
+      AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
+      AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
+      AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
+      AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
+      AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
+      AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
+      AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
+      AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
+      AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
+      AtenFillTensorOp, AtenRealOp, AtenImagOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp<AtenNllLossForwardOp>();
   patterns.add<ConvertAtenDetachOp>(typeConverter, context);
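
The new payload slots into the shared elementwise lowering. A self-contained sketch of the kind of linalg.generic the converter produces for a rank-1 case (the function name, shapes, and i64 element type are assumptions for illustration):

    #map = affine_map<(d0) -> (d0)>
    func.func @left_shift(%lhs: tensor<4xi64>, %rhs: tensor<4xi64>,
                          %init: tensor<4xi64>) -> tensor<4xi64> {
      %res = linalg.generic
          {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
          ins(%lhs, %rhs : tensor<4xi64>, tensor<4xi64>)
          outs(%init : tensor<4xi64>) {
      ^bb0(%l: i64, %r: i64, %out: i64):
        // Per-element payload emitted by
        // createLinalgPayloadCalculationForElementwiseOp.
        %s = arith.shli %l, %r : i64
        linalg.yield %s : i64
      } -> tensor<4xi64>
      return %res : tensor<4xi64>
    }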

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 12 additions & 0 deletions
@@ -7680,6 +7680,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.bitwise_left_shift.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.bitwise_not\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -9560,6 +9564,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
 "    return %4 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.bitwise_left_shift.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 11 additions & 0 deletions
@@ -908,6 +908,9 @@ def aten〇bitwise_xor〇Tensor〡shape(self: List[int], other: List[int]) -> Li
 def aten〇bitwise_right_shift〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(self, other)
 
+def aten〇bitwise_left_shift〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
+    return upstream_shape_functions.broadcast(self, other)
+
 def aten〇bitwise_not〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -2454,6 +2457,14 @@ def aten〇bitwise_right_shift〇Tensor〡dtype(self_rank_dtype: Tuple[int, int]
     dtypes = [self_dtype, other_dtype]
     return promote_dtypes(ranks, dtypes)
 
+@check_dtype_function(_check_two_tensor_op())
+def aten〇bitwise_left_shift〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
+    other_rank, other_dtype = other_rank_dtype
+    self_rank, self_dtype = self_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, other_rank]
+    dtypes = [self_dtype, other_dtype]
+    return promote_dtypes(ranks, dtypes)
+
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) +
     # Different width

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -317,6 +317,7 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)",
         "aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)",
         "aten::bitwise_xor.Tensor : (Tensor, Tensor) -> (Tensor)",
+        "aten::bitwise_left_shift.Tensor : (Tensor, Tensor) -> (Tensor)",
         "aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)",
         "aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
         "aten::square : (Tensor) -> (Tensor)",

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 63 additions & 0 deletions
@@ -3649,6 +3649,69 @@ def ElementwiseBitwiseRightShiftInt8Module_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseBitwiseLeftShiftInt64Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.bitwise_left_shift(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: ElementwiseBitwiseLeftShiftInt64Module())
+def ElementwiseBitwiseLeftShiftInt64Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-1000, high=1000), tu.randint(3, 4, low=0, high=64))
+
+
+class ElementwiseBitwiseLeftShiftInt32Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, 4], torch.int32, True),
+        ([-1, 1], torch.int32, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.bitwise_left_shift(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: ElementwiseBitwiseLeftShiftInt32Module())
+def ElementwiseBitwiseLeftShiftInt32Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int32), tu.randint(3, 1, low=0, high=32).to(torch.int32))
+
+
+class ElementwiseBitwiseLeftShiftInt8Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int8, True),
+        ([-1, -1], torch.int8, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.bitwise_left_shift(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: ElementwiseBitwiseLeftShiftInt8Module())
+def ElementwiseBitwiseLeftShiftInt8Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-100, high=100).to(torch.int8), tu.randint(3, 4, low=0, high=8).to(torch.int8))
+
+
+# ==============================================================================
+
+
 class ElementwiseBitwiseAndScalarInt64Module(torch.nn.Module):
 
     def __init__(self):
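
One note on the random inputs in the tests above: the shift amounts are drawn from [0, width) for each dtype (high=64, 32, and 8; tu.randint's upper bound is exclusive). This matters because arith.shli, used by the linalg payload, yields a poison value when the shift amount equals or exceeds the operand's bit width. A scalar sketch of that boundary (constant values assumed):

    %one   = arith.constant 1 : i8
    %three = arith.constant 3 : i8
    %ok    = arith.shli %one, %three : i8   // 1 << 3 == 8: well defined
    %eight = arith.constant 8 : i8
    %bad   = arith.shli %one, %eight : i8   // shift >= bit width: poison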
