Skip to content

Commit 612a4a0

Browse files
tatwaichong authored and FranklandJack committed
[mlir][tosa] Make TOSA MUL's Shift an Input
The TOSA-v1.0 specification makes the shift attribute of the MUL (Hadamard product) operator an input. Move the `shift` parameter of the MUL operator in the MLIR TOSA dialect from an attribute to an input and update any lit tests appropriately. Expand the verifier of the `tosa::MulOp` operation to check the various constraints defined in the TOSA-v1.0 specification. Specifically, ensure that all input operands (excluding the optional shift) are of the same rank. This means that broadcasting tests which previously checked rank-0 tensors would be broadcast are no longer valid and are removed. Signed-off-by: Jack Frankland <[email protected]>
1 parent 998bdae commit 612a4a0

File tree

14 files changed

+212
-83
lines changed

14 files changed

+212
-83
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -800,7 +800,7 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
800800
let arguments = (ins
801801
Tosa_Tensor:$input1,
802802
Tosa_Tensor:$input2,
803-
I8Attr:$shift
803+
Optional<TosaTensorRankOf<[Tosa_Int8], [0]>>:$shift
804804
);
805805

806806
let results = (outs

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 55 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -90,43 +90,58 @@ static Value createLinalgBodyCalculationForElementwiseOp(
9090
}
9191

9292
// tosa::MulOp
93-
if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
94-
return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
95-
96-
if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
97-
Value a = args[0];
98-
Value b = args[1];
99-
auto shift =
100-
cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
101-
if (shift > 0) {
102-
auto shiftConst =
103-
rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
104-
if (!a.getType().isInteger(32))
105-
a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
106-
107-
if (!b.getType().isInteger(32))
108-
b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
109-
110-
auto result = rewriter.create<tosa::ApplyScaleOp>(
111-
loc, rewriter.getI32Type(), a, b, shiftConst,
112-
rewriter.getBoolAttr(false));
113-
114-
if (elementTy.isInteger(32))
115-
return result;
116-
117-
return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
93+
if (isa<tosa::MulOp>(op)) {
94+
auto shift_val = cast<tosa::MulOp>(op).getShift();
95+
if (!elementTy.isInteger(32) && shift_val.getImpl()) {
96+
(void)rewriter.notifyMatchFailure(op,
97+
"Cannot have shift value for non i32 output");
98+
return nullptr;
99+
};
100+
101+
if (isa<FloatType>(elementTy)) {
102+
return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
118103
}
119104

120-
int aWidth = a.getType().getIntOrFloatBitWidth();
121-
int bWidth = b.getType().getIntOrFloatBitWidth();
122-
int cWidth = resultTypes[0].getIntOrFloatBitWidth();
105+
if (isa<IntegerType>(elementTy)) {
106+
int32_t shift = 0;
107+
ElementsAttr shift_elem;
108+
if (shift_val.getImpl() && matchPattern(shift_val, m_Constant(&shift_elem))) {
109+
// Explicit shift is set.
110+
shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
111+
}
112+
113+
Value a = args[0];
114+
Value b = args[1];
115+
if (shift > 0) {
116+
auto shiftConst =
117+
rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
118+
if (!a.getType().isInteger(32))
119+
a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
120+
121+
if (!b.getType().isInteger(32))
122+
b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
123+
124+
auto result = rewriter.create<tosa::ApplyScaleOp>(
125+
loc, rewriter.getI32Type(), a, b, shiftConst,
126+
rewriter.getBoolAttr(false));
123127

124-
if (aWidth < cWidth)
125-
a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
126-
if (bWidth < cWidth)
127-
b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
128+
if (elementTy.isInteger(32))
129+
return result;
128130

129-
return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
131+
return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
132+
}
133+
134+
int aWidth = a.getType().getIntOrFloatBitWidth();
135+
int bWidth = b.getType().getIntOrFloatBitWidth();
136+
int cWidth = resultTypes[0].getIntOrFloatBitWidth();
137+
138+
if (aWidth < cWidth)
139+
a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
140+
if (bWidth < cWidth)
141+
b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
142+
143+
return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
144+
}
130145
}
131146

132147
// tosa::NegateOp
@@ -931,7 +946,13 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
931946
auto loc = operation->getLoc();
932947
auto rank =
933948
cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
934-
auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
949+
// For the mul op we need to avoid expanding the rank of the optional shift
950+
// input.
951+
auto operandsToExpand =
952+
isa<tosa::MulOp>(operation) ? operands.take_front(2) : operands;
953+
954+
auto expandedOperands =
955+
expandInputRanks(rewriter, loc, operandsToExpand, rank);
935956
auto [targetShape, masterOperands] =
936957
computeTargetShape(rewriter, loc, indexPool, expandedOperands);
937958
auto broadcastOperands = broadcastDynamicDimensions(

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,18 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
614614
auto rhsAttr =
615615
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
616616

617-
const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
617+
// Result right shift on i32_t data type only. For simplification, synthesize a zero
618+
// shift for other data types.
619+
int32_t shift = 0;
620+
if (resultETy.isInteger(32)) {
621+
ElementsAttr shift_elem;
622+
if (getShift().getImpl()) {
623+
if (!matchPattern(getShift(), m_Constant(&shift_elem)))
624+
// cannot be folded when the shift value is unknown.
625+
return {};
626+
shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
627+
}
628+
}
618629

619630
if (rhsTy == resultTy) {
620631
if (isSplatZero(resultETy, lhsAttr))
@@ -629,7 +640,7 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
629640
return lhs;
630641
}
631642

632-
return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
643+
return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
633644
}
634645

635646
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -866,9 +866,84 @@ LogicalResult tosa::SliceOp::verify() {
866866
}
867867

868868
LogicalResult tosa::MulOp::verify() {
869-
Type elementTy = getInput1().getType().getElementType();
870-
if (isa<FloatType>(elementTy) && getShift() != 0)
871-
return emitOpError() << "require shift to be 0 for float type";
869+
auto resElemType = getElementTypeOrSelf(getOutput());
870+
871+
// Verify if the element type among operands and result match tosa
872+
// specification.
873+
if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
874+
IntegerType lhsIntType =
875+
cast<IntegerType>(getElementTypeOrSelf(getInput1()));
876+
IntegerType rhsIntType =
877+
cast<IntegerType>(getElementTypeOrSelf(getInput2()));
878+
if (lhsIntType != rhsIntType)
879+
return emitOpError(
880+
"requires the same element type for all operands");
881+
882+
// Though the spec requires the element type of result to be i32, a more
883+
// relaxed way is provided at dialect level for easier cooperating with
884+
// other dialects.
885+
if (lhsIntType.getWidth() > resIntType.getWidth())
886+
return emitOpError("invalid data type size for operands or result");
887+
888+
} else {
889+
// For other supported types, the spec requires the same element
890+
// type for all operands (excludes `shift` operand) and results.
891+
for (int i = 0; i < 2; ++i) {
892+
if (getElementTypeOrSelf(getOperand(i)) != resElemType)
893+
return emitOpError(
894+
"requires the same element type for all operands and results");
895+
}
896+
}
897+
898+
// Check if the shift value apply to non-i32 output type as that is not
899+
// allowed in the spec.
900+
if (!(llvm::isa<IntegerType>(resElemType) && resElemType.isInteger(32)))
901+
if (getShift().getImpl())
902+
return emitOpError(
903+
"right shift output only on i32 data type");
904+
905+
// Verify the op has same ranks for all main operands (excludes extra operands
906+
// such as shift of mul op, so this is the only difference with the built-in
907+
// `SameOperandsAndResultRank` trait) and results types, if known.
908+
909+
// delegate function that returns true if type is a shaped type with known
910+
// rank
911+
auto hasRank = [](const Type type) {
912+
if (auto shaped_type = dyn_cast<ShapedType>(type))
913+
return shaped_type.hasRank();
914+
915+
return false;
916+
};
917+
918+
auto rankedOperandTypes =
919+
llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));
920+
921+
auto rankedResultTypes =
922+
llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);
923+
924+
// If all operands and results are unranked, then no further verification.
925+
if (rankedOperandTypes.empty() && rankedResultTypes.empty())
926+
return success();
927+
928+
// delegate function that returns rank of shaped type with known rank
929+
auto getRank = [](const Type type) {
930+
return cast<ShapedType>(type).getRank();
931+
};
932+
933+
auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
934+
: getRank(*rankedResultTypes.begin());
935+
936+
for (size_t i = 0; i < 2; ++i) {
937+
if (rank != getRank(rankedOperandTypes[i])) {
938+
return emitOpError("operands don't have matching ranks");
939+
}
940+
}
941+
942+
for (const auto type : rankedResultTypes) {
943+
if (rank != getRank(type)) {
944+
return emitOpError("result type has different rank than operands");
945+
}
946+
}
872947

873948
return success();
874949
}

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
137137

138138
Value mulValue = rewriter
139139
.create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
140-
weight, /*shift=*/0)
140+
weight, Value{} /* zero_shift */)
141141
.getResult();
142142

143143
// Reshape output to [N, H, W, C * M].

mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
113113

114114
Value input1 = tosaBinaryOp.getInput1();
115115
Value input2 = tosaBinaryOp.getInput2();
116-
int32_t shift = tosaBinaryOp.getShift();
116+
Value shift = tosaBinaryOp.getShift();
117117
Value output = tosaBinaryOp.getResult();
118118
auto outputType = dyn_cast<RankedTensorType>(output.getType());
119119
if (!outputType)

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
451451

452452
// CHECK: linalg.generic
453453
// CHECK: arith.mulf
454-
%4 = tosa.mul %0, %1 {shift = 0 : i8} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
454+
%4 = tosa.mul %0, %1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
455455

456456
// CHECK: linalg.generic
457457
// CHECK: arith.negf
@@ -597,7 +597,7 @@ func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
597597
// CHECK: arith.extsi
598598
// CHECK: arith.extsi
599599
// CHECK: arith.muli
600-
%0 = tosa.mul %arg0, %arg0 {shift = 0 : i8} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
600+
%0 = tosa.mul %arg0, %arg0 : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
601601

602602
return
603603
}
@@ -625,12 +625,14 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
625625

626626
// CHECK: linalg.generic
627627
// CHECK: arith.muli
628-
%2 = tosa.mul %arg0, %arg0 {shift = 0 : i8} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
628+
%shift1 = "tosa.const"() <{value = dense<0> : tensor<i8>}> : () -> tensor<i8>
629+
%2 = tosa.mul %arg0, %arg0, %shift1 : (tensor<1xi32>, tensor<1xi32>, tensor<i8>) -> tensor<1xi32>
629630

630631
// CHECK: linalg.generic
631632
// CHECK: arith.constant 2
632633
// CHECK: apply_scale
633-
%3 = tosa.mul %arg0, %arg0 {shift = 2 : i8} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
634+
%shift2 = "tosa.const"() <{value = dense<2> : tensor<i8>}> : () -> tensor<i8>
635+
%3 = tosa.mul %arg0, %arg0, %shift2: (tensor<1xi32>, tensor<1xi32>, tensor<i8>) -> tensor<1xi32>
634636

635637
// CHECK: linalg.generic
636638
// CHECK: arith.divsi

mlir/test/Dialect/Tosa/broadcast.mlir

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -169,15 +169,6 @@ func.func @test_broadcast20(%arg0: tensor<3x3x4x1xf32>, %arg1: tensor<4x5xf32>)
169169
return %0 : tensor<3x3x4x5xf32>
170170
}
171171

172-
// -----
173-
// CHECK-LABEL: broadcast_mul
174-
func.func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
175-
// CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 14>}
176-
// CHECK: %[[VAR1:.*]] = tosa.mul %[[VAR0]], %arg1
177-
%0 = tosa.mul %arg0, %arg1 {shift = 1 : i8 } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
178-
return %0 : tensor<17x16x15x14xi32>
179-
}
180-
181172
// -----
182173
// CHECK-LABEL: broadcast_arithmetic_right_shift
183174
func.func @test_broadcast_arithmetic_right_shift(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
280280
// CHECK: return %arg0
281281
// CHECK-NOT: tosa.mul
282282
%ones = "tosa.const"() {value = dense<1.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
283-
%1 = tosa.mul %arg0, %ones {shift = 0 : i8} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
283+
%1 = tosa.mul %arg0, %ones : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
284284
return %1 : tensor<2x3xf32>
285285
}
286286

@@ -291,7 +291,7 @@ func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
291291
// CHECK: return %arg0
292292
// CHECK-NOT: tosa.mul
293293
%ones = "tosa.const"() {value = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
294-
%1 = tosa.mul %ones, %arg0 {shift = 0 : i8} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
294+
%1 = tosa.mul %ones, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
295295
return %1 : tensor<2x3xf32>
296296
}
297297

@@ -302,7 +302,20 @@ func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
302302
// CHECK: return %arg0
303303
// CHECK-NOT: tosa.mul
304304
%ones = "tosa.const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
305-
%1 = tosa.mul %arg0, %ones {shift = 0 : i8} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
305+
%1 = tosa.mul %arg0, %ones : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
306+
return %1 : tensor<2x3xi32>
307+
}
308+
309+
// -----
310+
311+
// CHECK-LABEL: @mul_one_int_and_shift
312+
func.func @mul_one_int_and_shift(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
313+
// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x3xi32>}>
314+
// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<31> : tensor<i8>}>
315+
// CHECK: %[[VAL_3:.*]] = tosa.mul %arg0, %[[VAL_1]], %[[VAL_2]] : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<i8>)
316+
%ones = "tosa.const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
317+
%shift = "tosa.const"() <{value = dense<31> : tensor<i8>}> : () -> tensor<i8>
318+
%1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<i8>) -> tensor<2x3xi32>
306319
return %1 : tensor<2x3xi32>
307320
}
308321

@@ -313,11 +326,11 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
313326
// CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>}
314327
// CHECK-NOT: tosa.mul
315328
%zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
316-
%1 = tosa.mul %arg0, %zeros {shift = 0 : i8} : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
329+
%1 = tosa.mul %arg0, %zeros : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
317330

318331
// CHECK-NOT: tosa.mul
319332
// CHECK: return %[[ZERO]], %[[ZERO]]
320-
%2 = tosa.mul %zeros, %arg0 {shift = 0 : i8} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
333+
%2 = tosa.mul %zeros, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
321334
return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
322335
}
323336

@@ -872,7 +885,7 @@ func.func @mul_quant_nofold() -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899
872885
// CHECK: tosa.mul
873886
%0 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
874887
%1 = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
875-
%2 = tosa.mul %0, %1 { shift = 0 : i8} : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
888+
%2 = tosa.mul %0, %1 : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
876889
return %2 : tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
877890
}
878891

0 commit comments

Comments
 (0)