
Commit fb07dc6

Authored and committed by krzysz00 and krishna2803
[mlir][Vector] Allow elementwise/broadcast swap to handle mixed types (llvm#151274)
This patch extends the pattern that rewrites elementwise operations whose inputs are all broadcast from the same shape so that it handles mixed types, such as when the result and input types don't match, or when the inputs have multiple types. PR llvm#150867 failed to check for the possibility of type mismatches when rewriting splat constants. To fix that issue, we add support for mixed-type operations more generally.
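As an illustration, here is the before/after IR for one of the tests added in this patch (broadcast_scalar_mixed_type); the "after" form is reconstructed from that test's CHECK lines, not taken verbatim from the pattern's output:

// Input IR: the broadcast feeds a mixed-type arith.extf.
func.func @broadcast_scalar_mixed_type(%arg0: f16) -> vector<1x4xf32> {
  %0 = vector.broadcast %arg0 : f16 to vector<1x4xf16>
  %1 = arith.extf %0 : vector<1x4xf16> to vector<1x4xf32>
  return %1 : vector<1x4xf32>
}

// After the rewrite: extf runs on the f16 scalar, and a single broadcast
// produces the final vector<1x4xf32>.
func.func @broadcast_scalar_mixed_type(%arg0: f16) -> vector<1x4xf32> {
  %0 = arith.extf %arg0 : f16 to f32
  %1 = vector.broadcast %0 : f32 to vector<1x4xf32>
  return %1 : vector<1x4xf32>
}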
1 parent f511445 commit fb07dc6

File tree

2 files changed: +160, -79 lines changed


mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 45 additions & 25 deletions
@@ -965,6 +965,28 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
   std::function<bool(BitCastOp)> controlFn;
 };
 
+static bool haveSameShapeAndScaling(Type t, Type u) {
+  auto tVec = dyn_cast<VectorType>(t);
+  auto uVec = dyn_cast<VectorType>(u);
+  if (!tVec) {
+    return !uVec;
+  }
+  if (!uVec) {
+    return false;
+  }
+  return tVec.getShape() == uVec.getShape() &&
+         tVec.getScalableDims() == uVec.getScalableDims();
+}
+
+/// If `type` is shaped, clone it with `newElementType`. Otherwise,
+/// return `newElementType`.
+static Type cloneOrReplace(Type type, Type newElementType) {
+  if (auto shapedType = dyn_cast<ShapedType>(type)) {
+    return shapedType.clone(newElementType);
+  }
+  return newElementType;
+}
+
 /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
 ///
 /// Example:
@@ -988,23 +1010,22 @@ struct ReorderElementwiseOpsOnBroadcast final
                                 PatternRewriter &rewriter) const override {
     if (op->getNumResults() != 1)
       return failure();
-    if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
+    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
+    if (!resultType)
       return failure();
     if (!OpTrait::hasElementwiseMappableTraits(op))
       return rewriter.notifyMatchFailure(
           op, "Op doesn't have ElementwiseMappableTraits");
     if (op->getNumOperands() == 0)
       return failure();
-    if (op->getResults()[0].getType() != op->getOperand(0).getType())
-      return rewriter.notifyMatchFailure(op,
-                                         "result and operand type mismatch");
     if (isa<vector::FMAOp>(op)) {
       return rewriter.notifyMatchFailure(
           op,
           "Op only accepts vector types - not supported as broadcast source "
           "might be a scalar");
     }
 
+    Type resultElemType = resultType.getElementType();
     // Get the type of the first non-constant operand
     Operation *firstBroadcastOrSplat = nullptr;
     for (Value operand : op->getOperands()) {
@@ -1020,24 +1041,23 @@ struct ReorderElementwiseOpsOnBroadcast final
     }
     if (!firstBroadcastOrSplat)
       return failure();
-    Type firstBroadcastOrSplatType =
-        firstBroadcastOrSplat->getOperand(0).getType();
+    Type unbroadcastResultType = cloneOrReplace(
+        firstBroadcastOrSplat->getOperand(0).getType(), resultElemType);
 
-    // Make sure that all operands are broadcast from identical types:
+    // Make sure that all operands are broadcast from identically-shaped types:
     // * scalar (`vector.broadcast` + `vector.splat`), or
     // * vector (`vector.broadcast`).
     // Otherwise the re-ordering wouldn't be safe.
-    if (!llvm::all_of(
-            op->getOperands(), [&firstBroadcastOrSplatType](Value val) {
-              if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>())
-                return (bcastOp.getOperand().getType() ==
-                        firstBroadcastOrSplatType);
-              if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
-                return (splatOp.getOperand().getType() ==
-                        firstBroadcastOrSplatType);
-              SplatElementsAttr splatConst;
-              return matchPattern(val, m_Constant(&splatConst));
-            })) {
+    if (!llvm::all_of(op->getOperands(), [&unbroadcastResultType](Value val) {
+          if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>())
+            return haveSameShapeAndScaling(bcastOp.getOperand().getType(),
+                                           unbroadcastResultType);
+          if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
+            return haveSameShapeAndScaling(splatOp.getOperand().getType(),
+                                           unbroadcastResultType);
+          SplatElementsAttr splatConst;
+          return matchPattern(val, m_Constant(&splatConst));
+        })) {
       return failure();
     }
 
@@ -1048,15 +1068,16 @@ struct ReorderElementwiseOpsOnBroadcast final
       SplatElementsAttr splatConst;
       if (matchPattern(operand, m_Constant(&splatConst))) {
         Attribute newConst;
-        if (auto shapedTy = dyn_cast<ShapedType>(firstBroadcastOrSplatType)) {
-          newConst = splatConst.resizeSplat(shapedTy);
+        Type elementType = getElementTypeOrSelf(operand.getType());
+        Type newType = cloneOrReplace(unbroadcastResultType, elementType);
+        if (auto newTypeShaped = dyn_cast<ShapedType>(newType)) {
+          newConst = splatConst.resizeSplat(newTypeShaped);
         } else {
           newConst = splatConst.getSplatValue<Attribute>();
         }
         Operation *newConstOp =
             operand.getDefiningOp()->getDialect()->materializeConstant(
-                rewriter, newConst, firstBroadcastOrSplatType,
-                operand.getLoc());
+                rewriter, newConst, newType, operand.getLoc());
         srcValues.push_back(newConstOp->getResult(0));
       } else {
         srcValues.push_back(operand.getDefiningOp()->getOperand(0));
@@ -1066,12 +1087,11 @@ struct ReorderElementwiseOpsOnBroadcast final
     // Create the "elementwise" Op
     Operation *elementwiseOp =
         rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
-                        firstBroadcastOrSplatType, op->getAttrs());
+                        unbroadcastResultType, op->getAttrs());
 
     // Replace the original Op with the elementwise Op
-    auto vectorType = op->getResultTypes()[0];
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-        op, vectorType, elementwiseOp->getResults());
+        op, resultType, elementwiseOp->getResults());
 
     return success();
   }
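The splat-constant path is where PR llvm#150867 went wrong: the rebuilt constant must keep its own element type while taking the unbroadcast shape, which is what the cloneOrReplace(unbroadcastResultType, elementType) call above computes. A sketch of the effect, mirroring the broadcast_scalar_and_splat_const_mixed_type test added below (the "after" form is reconstructed from that test's CHECK lines):

// Input IR: an f32 broadcast plus an i32 splat constant feeding math.fpowi.
func.func @broadcast_scalar_and_splat_const_mixed_type(%arg0: f32) -> vector<1x4xf32> {
  %0 = vector.broadcast %arg0 : f32 to vector<1x4xf32>
  %cst = arith.constant dense<3> : vector<1x4xi32>
  %2 = math.fpowi %0, %cst : vector<1x4xf32>, vector<1x4xi32>
  return %2 : vector<1x4xf32>
}

// After the rewrite: the constant is rebuilt as a scalar i32, fpowi runs on
// scalars, and a single broadcast produces the vector result.
func.func @broadcast_scalar_and_splat_const_mixed_type(%arg0: f32) -> vector<1x4xf32> {
  %cst = arith.constant 3 : i32
  %0 = math.fpowi %arg0, %cst : f32, i32
  %1 = vector.broadcast %0 : f32 to vector<1x4xf32>
  return %1 : vector<1x4xf32>
}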

mlir/test/Dialect/Vector/vector-sink.mlir

Lines changed: 115 additions & 54 deletions
@@ -180,13 +180,14 @@ func.func @negative_not_elementwise() -> vector<2x2xf32> {
 
 // -----
 
-// The source and the result for arith.cmp have different types - not supported
-
-// CHECK-LABEL: func.func @negative_source_and_result_mismatch
-// CHECK: %[[BROADCAST:.+]] = vector.broadcast
-// CHECK: %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]]
-// CHECK: return %[[RETURN]]
-func.func @negative_source_and_result_mismatch(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
+// The source and the result for arith.cmp have different types
+
+// CHECK-LABEL: func.func @source_and_result_mismatch(
+// CHECK-SAME: %[[ARG0:.+]]: f32)
+// CHECK: %[[COMPARE:.+]] = arith.cmpf uno, %[[ARG0]], %[[ARG0]]
+// CHECK: %[[BROADCAST:.+]] = vector.broadcast %[[COMPARE]] : i1 to vector<1xi1>
+// CHECK: return %[[BROADCAST]]
+func.func @source_and_result_mismatch(%arg0 : f32) -> vector<1xi1> {
   %0 = vector.broadcast %arg0 : f32 to vector<1xf32>
   %1 = arith.cmpf uno, %0, %0 : vector<1xf32>
   return %1 : vector<1xi1>
@@ -210,53 +211,6 @@ func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> {
   return %1 : vector<1xf32>
 }
 
-//===----------------------------------------------------------------------===//
-// [Pattern: ReorderCastOpsOnBroadcast]
-//
-// Reorder casting ops and vector ops. The casting ops have almost identical
-// pattern, so only arith.extsi op is tested.
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
-  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
-  // CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32>
-  %b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
-  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
-  return %r : vector<2x4xi32>
-}
-
-// -----
-
-func.func @broadcast_vector_extsi_scalable(%a : vector<[4]xi8>) -> vector<2x[4]xi32> {
-  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<[4]xi8> to vector<[4]xi32>
-  // CHECK: vector.broadcast %[[EXT:.+]] : vector<[4]xi32> to vector<2x[4]xi32>
-  %b = vector.broadcast %a : vector<[4]xi8> to vector<2x[4]xi8>
-  %r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32>
-  return %r : vector<2x[4]xi32>
-}
-
-// -----
-
-func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
-  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
-  // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
-  %b = vector.broadcast %a : i8 to vector<2x4xi8>
-  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
-  return %r : vector<2x4xi32>
-}
-
-// -----
-
-func.func @broadcast_scalar_extsi_scalable(%a : i8) -> vector<2x[4]xi32> {
-  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
-  // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x[4]xi32>
-  %b = vector.broadcast %a : i8 to vector<2x[4]xi8>
-  %r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32>
-  return %r : vector<2x[4]xi32>
-}
-
 // -----
 
 // CHECK-LABEL: func.func @broadcast_scalar_and_splat_const(
@@ -321,6 +275,113 @@ func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xindex> {
   return %2 : vector<1x4xindex>
 }
 
+// -----
+
+// CHECK-LABEL: func.func @broadcast_scalar_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: f16) -> vector<1x4xf32> {
+// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : f16 to f32
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : f32 to vector<1x4xf32>
+// CHECK: return %[[BCAST]] : vector<1x4xf32>
+
+func.func @broadcast_scalar_mixed_type(%arg0: f16) -> vector<1x4xf32> {
+  %0 = vector.broadcast %arg0 : f16 to vector<1x4xf16>
+  %1 = arith.extf %0 : vector<1x4xf16> to vector<1x4xf32>
+  return %1 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_vector_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf16>) -> vector<3x4xf32> {
+// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : vector<4xf16> to vector<4xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : vector<4xf32> to vector<3x4xf32>
+// CHECK: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_mixed_type(%arg0: vector<4xf16>) -> vector<3x4xf32> {
+  %0 = vector.broadcast %arg0 : vector<4xf16> to vector<3x4xf16>
+  %1 = arith.extf %0 : vector<3x4xf16> to vector<3x4xf32>
+  return %1 : vector<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: f32) -> vector<1x4xf32> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant 3 : i32
+// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : f32, i32
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : f32 to vector<1x4xf32>
+// CHECK: return %[[BCAST]] : vector<1x4xf32>
+
+func.func @broadcast_scalar_and_splat_const_mixed_type(%arg0: f32) -> vector<1x4xf32> {
+  %0 = vector.broadcast %arg0 : f32 to vector<1x4xf32>
+  %cst = arith.constant dense<3> : vector<1x4xi32>
+  %2 = math.fpowi %0, %cst : vector<1x4xf32>, vector<1x4xi32>
+  return %2 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_vector_and_splat_const_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant dense<3> : vector<4xi32>
+// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>, vector<4xi32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : vector<4xf32> to vector<3x4xf32>
+// CHECK: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_and_splat_const_mixed_type(%arg0: vector<4xf32>) -> vector<3x4xf32> {
+  %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32>
+  %cst = arith.constant dense<3> : vector<3x4xi32>
+  %2 = math.fpowi %0, %cst : vector<3x4xf32>, vector<3x4xi32>
+  return %2 : vector<3x4xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// [Pattern: ReorderCastOpsOnBroadcast]
+//
+// Reorder casting ops and vector ops. The casting ops have almost identical
+// pattern, so only arith.extsi op is tested.
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
+  // CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32>
+  %b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
+  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+  return %r : vector<2x4xi32>
+}
+
+// -----
+
+func.func @broadcast_vector_extsi_scalable(%a : vector<[4]xi8>) -> vector<2x[4]xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<[4]xi8> to vector<[4]xi32>
+  // CHECK: vector.broadcast %[[EXT:.+]] : vector<[4]xi32> to vector<2x[4]xi32>
+  %b = vector.broadcast %a : vector<[4]xi8> to vector<2x[4]xi8>
+  %r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32>
+  return %r : vector<2x[4]xi32>
+}
+
+// -----
+
+func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
+  // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
+  %b = vector.broadcast %a : i8 to vector<2x4xi8>
+  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+  return %r : vector<2x4xi32>
+}
+
+// -----
+
+func.func @broadcast_scalar_extsi_scalable(%a : i8) -> vector<2x[4]xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
+  // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x[4]xi32>
+  %b = vector.broadcast %a : i8 to vector<2x[4]xi8>
+  %r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32>
+  return %r : vector<2x[4]xi32>
+}
+
 //===----------------------------------------------------------------------===//
 // [Pattern: ReorderElementwiseOpsOnTranspose]
 //===----------------------------------------------------------------------===//
