
Commit a6c3050

Anup Gangwar authored and silvasean committed
* [tosa] Support for Maximum and Minimum
Signed-off-by: Anup Gangwar <[email protected]>
1 parent 707c113 commit a6c3050

File tree

2 files changed: +72 -0 lines


lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 40 additions & 0 deletions
@@ -77,6 +77,39 @@ class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {
   }
 };
 
+// These binary op legalizations are identical for floating-point
+// or quantized types
+template <typename AtenOpT, typename TosaOpT>
+class ConvertAtenBinaryOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value lhs = adaptor.self();
+    auto lhsTy = lhs.getType().cast<TensorType>();
+    Value rhs = adaptor.other();
+    auto rhsTy = rhs.getType().cast<TensorType>();
+
+    if (!lhsTy || !rhsTy)
+      return op.emitError("Only Tensor types supported in TOSA");
+
+    auto lhsElemTy = lhsTy.getElementType();
+    auto rhsElemTy = rhsTy.getElementType();
+
+    if (lhsElemTy != rhsElemTy)
+      return op.emitError("Add: input datatypes mismatched");
+
+    rewriter.replaceOpWithNewOp<TosaOpT>(
+        op,
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()),
+        lhs, rhs);
+    return success();
+  }
+};
+
 // These binary op legalizations are specific to add/sub which have an
 // alpha multiplier.
 template <typename AtenOpT, typename TosaOpT>
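One subtlety in the pattern above for readers adapting it: in MLIR, Type::cast<TensorType>() asserts when the type is not a TensorType, so the `if (!lhsTy || !rhsTy)` guard is only reachable if the casts are switched to dyn_cast, which returns a null type on failure. A defensive variant of those lines, as a sketch rather than part of the commit, would read:

    // Sketch (not in the commit): dyn_cast yields a null TensorType on
    // failure, which makes the subsequent null check meaningful.
    auto lhsTy = lhs.getType().dyn_cast<TensorType>();
    auto rhsTy = rhs.getType().dyn_cast<TensorType>();
    if (!lhsTy || !rhsTy)
      return op.emitError("Only Tensor types supported in TOSA");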
@@ -538,6 +571,13 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
 #undef INSERT_UNARY_PATTERN
 
+#define INSERT_BINARY_PATTERN(AtenOp, TosaOp)                                  \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenBinaryOp<AtenOp, TosaOp>>(typeConverter, context);
+    INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp)
+    INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
+#undef INSERT_BINARY_PATTERN
+
 #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp)                           \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenAddSubOp<AtenOp, TosaOp>>(typeConverter, context);
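For readers unfamiliar with this registration idiom: each INSERT_BINARY_PATTERN invocation marks the Torch op as illegal for the conversion target and registers the templated pattern that rewrites it. Expanding the macro by hand for the AtenMaximumOp line gives the following (a sketch of the preprocessor output, not code from the commit):

    // Hand expansion of INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp);
    // shown only to illustrate the macro, not part of the commit.
    target.addIllegalOp<AtenMaximumOp>();
    patterns.add<ConvertAtenBinaryOp<AtenMaximumOp, tosa::MaximumOp>>(
        typeConverter, context);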

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 32 additions & 0 deletions
@@ -298,3 +298,35 @@ func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
   %0 = torch.aten.rsqrt %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func @torch.aten.maximum$basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:           %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:           %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:           %[[VAL_4:.*]] = "tosa.maximum"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:           return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
+// CHECK:         }
+func @torch.aten.maximum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func @torch.aten.minimum$basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:           %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:           %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:           %[[VAL_4:.*]] = "tosa.minimum"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:           return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
+// CHECK:         }
+func @torch.aten.minimum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.minimum %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
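The `// -----` separators split basic.mlir into independent test cases, so the pass driver must be run with split-input-file handling. The RUN line at the top of the file is outside this diff; it presumably looks like the following (the exact driver name and flag spellings are an assumption, not shown in the commit):

    // RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file | FileCheck %s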
