Skip to content

Commit 4e2c09c

Browse files
committed
Revert "[mlir][vector] Avoid setting padding by default to 0 in vector.transfer_read prefer ub.poison (llvm#146088)"
This reverts commit 878d359.
1 parent d4129d3 commit 4e2c09c

File tree

15 files changed

+78
-107
lines changed

15 files changed

+78
-107
lines changed

mlir/include/mlir/Dialect/Arith/IR/Arith.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -154,11 +154,6 @@ Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
154154
Value lhs, Value rhs);
155155

156156
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred);
157-
158-
/// Creates an `arith.constant` operation with a zero value of type `type`. This
159-
/// method asserts if `type` is invalid for representing zero with
160-
/// `arith.constant`.
161-
Value getZeroConstant(OpBuilder &builder, Location loc, Type type);
162157
} // namespace arith
163158
} // namespace mlir
164159

mlir/include/mlir/Dialect/Vector/IR/Vector.td

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,7 @@ def Vector_Dialect : Dialect {
2121

2222
let useDefaultAttributePrinterParser = 1;
2323
let hasConstantMaterializer = 1;
24-
let dependentDialects = [
25-
"arith::ArithDialect",
26-
"ub::UBDialect"
27-
];
24+
let dependentDialects = ["arith::ArithDialect"];
2825
}
2926

3027
// Base class for Vector dialect ops.

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1543,29 +1543,30 @@ def Vector_TransferReadOp :
15431543
}];
15441544

15451545
let builders = [
1546-
/// 1. Builder that sets padding to `padding` or poison if not provided and
1547-
/// an empty mask (variant with attrs).
1546+
/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
15481547
OpBuilder<(ins "VectorType":$vectorType,
15491548
"Value":$source,
15501549
"ValueRange":$indices,
1551-
"std::optional<Value>":$padding,
15521550
"AffineMapAttr":$permutationMapAttr,
15531551
"ArrayAttr":$inBoundsAttr)>,
1554-
/// 2. Builder that sets padding to `padding` or poison if not provided and
1555-
/// an empty mask (variant without attrs).
1552+
/// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
15561553
OpBuilder<(ins "VectorType":$vectorType,
15571554
"Value":$source,
15581555
"ValueRange":$indices,
1559-
"std::optional<Value>":$padding,
15601556
"AffineMap":$permutationMap,
15611557
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
1562-
/// 3. Builder that sets padding to `padding` or poison if not provided and
1563-
/// permutation map to 'getMinorIdentityMap'.
1558+
/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
15641559
OpBuilder<(ins "VectorType":$vectorType,
15651560
"Value":$source,
15661561
"ValueRange":$indices,
1567-
"std::optional<Value>":$padding,
1568-
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>
1562+
"Value":$padding,
1563+
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
1564+
/// 4. Builder that sets padding to zero and permutation map to
1565+
/// 'getMinorIdentityMap'.
1566+
OpBuilder<(ins "VectorType":$vectorType,
1567+
"Value":$source,
1568+
"ValueRange":$indices,
1569+
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
15691570
];
15701571

15711572
let extraClassDeclaration = [{

mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,8 +1257,7 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
12571257
LLVM_DEBUG(permutationMap.print(dbgs()));
12581258

12591259
auto transfer = state.builder.create<vector::TransferReadOp>(
1260-
loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices,
1261-
/*padding=*/std::nullopt, permutationMap);
1260+
loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, permutationMap);
12621261

12631262
// Register replacement for future uses in the scope.
12641263
state.registerOpVectorReplacement(loadOp, transfer);

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -292,16 +292,6 @@ bool arith::ConstantIndexOp::classof(Operation *op) {
292292
return false;
293293
}
294294

295-
Value mlir::arith::getZeroConstant(OpBuilder &builder, Location loc,
296-
Type type) {
297-
// TODO: Incorporate this check to `FloatAttr::get*`.
298-
assert(!isa<Float8E8M0FNUType>(getElementTypeOrSelf(type)) &&
299-
"type doesn't have a zero representation");
300-
TypedAttr zeroAttr = builder.getZeroAttr(type);
301-
assert(zeroAttr && "unsupported type for zero attribute");
302-
return builder.create<arith::ConstantOp>(loc, zeroAttr);
303-
}
304-
305295
//===----------------------------------------------------------------------===//
306296
// AddIOp
307297
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,6 @@ struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
426426
// Create the new `transfer_read`.
427427
auto newReadOp = rewriter.create<vector::TransferReadOp>(
428428
readOp.getLoc(), collapsedVT, collapsedMem, indices,
429-
readOp.getPadding(),
430429
ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));
431430

432431
// Cast back to the original vector type.

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,18 +1183,14 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
11831183
auto srcRank = extractOp.getTensor().getType().getRank();
11841184
SmallVector<bool> inBounds(dstRank, true);
11851185

1186-
// Get the value to pad transfer reads with 0.
1187-
Value padding =
1188-
arith::getZeroConstant(rewriter, loc, resultType.getElementType());
1189-
11901186
// 2a. Handle scalar broadcast access.
11911187
if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
11921188
MLIRContext *ctx = rewriter.getContext();
11931189
SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
11941190
auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
11951191

11961192
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1197-
loc, resultType, extractOp.getTensor(), transferReadIdxs, padding,
1193+
loc, resultType, extractOp.getTensor(), transferReadIdxs,
11981194
permutationMap, inBounds);
11991195

12001196
// Mask this broadcasting xfer_read here rather than relying on the generic
@@ -1231,8 +1227,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
12311227
}
12321228

12331229
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1234-
loc, resultType, extractOp.getTensor(), transferReadIdxs, padding,
1235-
permutationMap, inBounds);
1230+
loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1231+
inBounds);
12361232

12371233
LDBG("Vectorised as contiguous load: " << extractOp);
12381234
return VectorizationHookResult{VectorizationHookStatus::NewOp,
@@ -1388,7 +1384,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
13881384
/// performed to the maximal common vector size implied by the `linalgOp`
13891385
/// iteration space. This eager broadcasting is introduced in the
13901386
/// permutation_map of the vector.transfer_read operations. The eager
1391-
/// broadcasting makes it trivial to determine where broadcast, transposes and
1387+
/// broadcasting makes it trivial to detrmine where broadcast, transposes and
13921388
/// reductions should occur, without any bookkeeping. The tradeoff is that, in
13931389
/// the absence of good canonicalizations, the amount of work increases.
13941390
/// This is not deemed a problem as we expect canonicalizations and foldings to
@@ -1443,8 +1439,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
14431439
SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
14441440

14451441
Operation *read = rewriter.create<vector::TransferReadOp>(
1446-
loc, readType, opOperand->get(), indices,
1447-
/*padding=*/arith::getZeroConstant(rewriter, loc, elemType), readMap);
1442+
loc, readType, opOperand->get(), indices, readMap);
14481443
read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
14491444
Value readValue = read->getResult(0);
14501445

@@ -2646,7 +2641,6 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
26462641

26472642
Value readValue = rewriter.create<vector::TransferReadOp>(
26482643
loc, readType, copyOp.getSource(), indices,
2649-
/*padding=*/arith::getZeroConstant(rewriter, loc, srcElementType),
26502644
rewriter.getMultiDimIdentityMap(srcType.getRank()));
26512645
if (cast<VectorType>(readValue.getType()).getRank() == 0) {
26522646
readValue =
@@ -3493,18 +3487,15 @@ struct Conv1DGenerator
34933487
SmallVector<Value> resPadding(resShape.size(), zero);
34943488

34953489
// Read the whole lhs, rhs and res in one shot (with zero padding).
3496-
Value lhs = rewriter.create<vector::TransferReadOp>(
3497-
loc, lhsType, lhsShaped, lhsPadding,
3498-
/*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3490+
Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
3491+
lhsPadding);
34993492
// This is needed only for Conv.
35003493
Value rhs = nullptr;
35013494
if (oper == ConvOperationKind::Conv)
3502-
rhs = rewriter.create<vector::TransferReadOp>(
3503-
loc, rhsType, rhsShaped, rhsPadding,
3504-
/*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3505-
Value res = rewriter.create<vector::TransferReadOp>(
3506-
loc, resType, resShaped, resPadding,
3507-
/*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3495+
rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3496+
rhsPadding);
3497+
Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
3498+
resPadding);
35083499

35093500
// The base vectorization case for channeled convolution is input:
35103501
// {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
@@ -3751,22 +3742,19 @@ struct Conv1DGenerator
37513742
// Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
37523743
// 0].
37533744
Value lhs = rewriter.create<vector::TransferReadOp>(
3754-
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
3755-
/*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3745+
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
37563746
auto maybeMaskedLhs = maybeMaskXferOp(
37573747
lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
37583748

37593749
// Read rhs slice of size {kw, c} @ [0, 0].
3760-
Value rhs = rewriter.create<vector::TransferReadOp>(
3761-
loc, rhsType, rhsShaped, ValueRange{zero, zero},
3762-
/*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3750+
Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3751+
ValueRange{zero, zero});
37633752
auto maybeMaskedRhs = maybeMaskXferOp(
37643753
rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
37653754

37663755
// Read res slice of size {n, w, c} @ [0, 0, 0].
37673756
Value res = rewriter.create<vector::TransferReadOp>(
3768-
loc, resType, resShaped, ValueRange{zero, zero, zero},
3769-
/*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3757+
loc, resType, resShaped, ValueRange{zero, zero, zero});
37703758
auto maybeMaskedRes = maybeMaskXferOp(
37713759
resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
37723760

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4291,39 +4291,33 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
42914291
/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
42924292
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
42934293
VectorType vectorType, Value source,
4294-
ValueRange indices, std::optional<Value> padding,
4295-
AffineMapAttr permutationMapAttr,
4294+
ValueRange indices, AffineMapAttr permutationMapAttr,
42964295
/*optional*/ ArrayAttr inBoundsAttr) {
4297-
42984296
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4299-
if (!padding)
4300-
padding = builder.create<ub::PoisonOp>(result.location, elemType);
4297+
Value padding = builder.create<arith::ConstantOp>(
4298+
result.location, elemType, builder.getZeroAttr(elemType));
43014299
build(builder, result, vectorType, source, indices, permutationMapAttr,
4302-
*padding, /*mask=*/Value(), inBoundsAttr);
4300+
padding, /*mask=*/Value(), inBoundsAttr);
43034301
}
43044302

43054303
/// 2. Builder that sets padding to zero an empty mask (variant without attrs).
43064304
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
43074305
VectorType vectorType, Value source,
4308-
ValueRange indices, std::optional<Value> padding,
4309-
AffineMap permutationMap,
4306+
ValueRange indices, AffineMap permutationMap,
43104307
std::optional<ArrayRef<bool>> inBounds) {
43114308
auto permutationMapAttr = AffineMapAttr::get(permutationMap);
43124309
auto inBoundsAttr = (inBounds && !inBounds.value().empty())
43134310
? builder.getBoolArrayAttr(inBounds.value())
43144311
: builder.getBoolArrayAttr(
43154312
SmallVector<bool>(vectorType.getRank(), false));
4316-
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4317-
if (!padding)
4318-
padding = builder.create<ub::PoisonOp>(result.location, elemType);
4319-
build(builder, result, vectorType, source, indices, *padding,
4320-
permutationMapAttr, inBoundsAttr);
4313+
build(builder, result, vectorType, source, indices, permutationMapAttr,
4314+
inBoundsAttr);
43214315
}
43224316

43234317
/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
43244318
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
43254319
VectorType vectorType, Value source,
4326-
ValueRange indices, std::optional<Value> padding,
4320+
ValueRange indices, Value padding,
43274321
std::optional<ArrayRef<bool>> inBounds) {
43284322
AffineMap permutationMap = getTransferMinorIdentityMap(
43294323
llvm::cast<ShapedType>(source.getType()), vectorType);
@@ -4332,14 +4326,23 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
43324326
? builder.getBoolArrayAttr(inBounds.value())
43334327
: builder.getBoolArrayAttr(
43344328
SmallVector<bool>(vectorType.getRank(), false));
4335-
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4336-
if (!padding)
4337-
padding = builder.create<ub::PoisonOp>(result.location, elemType);
43384329
build(builder, result, vectorType, source, indices, permutationMapAttr,
4339-
*padding,
4330+
padding,
43404331
/*mask=*/Value(), inBoundsAttr);
43414332
}
43424333

4334+
/// 4. Builder that sets padding to zero and permutation map to
4335+
/// 'getMinorIdentityMap'.
4336+
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4337+
VectorType vectorType, Value source,
4338+
ValueRange indices,
4339+
std::optional<ArrayRef<bool>> inBounds) {
4340+
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4341+
Value padding = builder.create<arith::ConstantOp>(
4342+
result.location, elemType, builder.getZeroAttr(elemType));
4343+
build(builder, result, vectorType, source, indices, padding, inBounds);
4344+
}
4345+
43434346
template <typename EmitFun>
43444347
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
43454348
EmitFun emitOpError) {

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ struct DistributedLoadStoreHelper {
173173
}
174174
SmallVector<bool> inBounds(indices.size(), true);
175175
return b.create<vector::TransferReadOp>(
176-
loc, cast<VectorType>(type), buffer, indices, /*padding=*/std::nullopt,
176+
loc, cast<VectorType>(type), buffer, indices,
177177
ArrayRef<bool>(inBounds.begin(), inBounds.end()));
178178
}
179179

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -660,8 +660,7 @@ class FlattenContiguousRowMajorTransferReadPattern
660660
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
661661
vectorType.getElementType());
662662
vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
663-
loc, flatVectorType, collapsedSource, collapsedIndices,
664-
transferReadOp.getPadding(), collapsedMap);
663+
loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
665664
flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
666665

667666
// 4. Replace the old transfer_read with the new one reading from the

0 commit comments

Comments
 (0)