Skip to content

Commit 41f66cd

Browse files
committed
Fix review comments.
Signed-off-by: Ilya Enkovich <[email protected]>
1 parent 0c0f415 commit 41f66cd

File tree

1 file changed

+23
-29
lines changed

1 file changed

+23
-29
lines changed

mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 23 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -37,44 +37,38 @@ std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
3737
rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
3838
}
3939

40-
/// Verifies if the stride matches proper tile access.
41-
LogicalResult verifyStride(MemRefType mType) {
42-
if (mType.getRank() < 2)
43-
return failure();
44-
int64_t last = mType.getRank() - 1;
45-
int64_t offset;
46-
SmallVector<int64_t, 4> strides;
47-
if (failed(getStridesAndOffset(mType, strides, offset)) || strides[last] != 1)
48-
return failure();
49-
return success();
50-
}
51-
5240
/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
5341
/// shape may "envelop" the actual tile shape, and may be dynamically sized.
54-
Value getStride(ConversionPatternRewriter &rewriter,
55-
const LLVMTypeConverter &typeConverter, MemRefType mType,
56-
Value base, Location loc) {
57-
assert(mType.getRank() >= 2);
42+
/// Returns failure if a proper stride could not be found.
43+
FailureOr<Value> getStride(ConversionPatternRewriter &rewriter,
44+
const LLVMTypeConverter &typeConverter,
45+
MemRefType mType, Value base, Location loc) {
46+
if (mType.getRank() < 2)
47+
return failure();
5848
int64_t preLast = mType.getRank() - 2;
5949
Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
6050
unsigned width = mType.getElementType().getIntOrFloatBitWidth();
6151
assert(llvm::isPowerOf2_64(width) && width >= 8);
6252
unsigned bytes = width >> 3;
6353
int64_t offset;
6454
SmallVector<int64_t, 4> strides;
65-
getStridesAndOffset(mType, strides, offset);
55+
if (failed(getStridesAndOffset(mType, strides, offset)) ||
56+
strides.back() != 1)
57+
return failure();
6658
if (strides[preLast] == ShapedType::kDynamic) {
6759
// Dynamic stride needs code to compute the stride at runtime.
6860
MemRefDescriptor memrefDescriptor(base);
6961
auto attr = rewriter.getI64IntegerAttr(bytes);
7062
Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
71-
return rewriter.create<LLVM::MulOp>(
72-
loc, llvmInt64Type, scale,
73-
memrefDescriptor.stride(rewriter, loc, preLast));
63+
return rewriter
64+
.create<LLVM::MulOp>(loc, llvmInt64Type, scale,
65+
memrefDescriptor.stride(rewriter, loc, preLast))
66+
.getResult();
7467
}
7568
// Use direct constant for static stride.
7669
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
77-
return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
70+
return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
71+
.getResult();
7872
}
7973

8074
struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
@@ -106,16 +100,16 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
106100
std::pair<Value, Value> tsz =
107101
getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
108102
// Determine stride.
109-
if (failed(verifyStride(mType)))
103+
auto stride = getStride(rewriter, *getTypeConverter(), mType,
104+
adaptor.getBase(), op.getLoc());
105+
if (failed(stride))
110106
return failure();
111-
Value stride = getStride(rewriter, *getTypeConverter(), mType,
112-
adaptor.getBase(), op.getLoc());
113107
// Replace operation with intrinsic.
114108
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
115109
adaptor.getIndices(), rewriter);
116110
Type resType = typeConverter->convertType(vType);
117111
rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
118-
op, resType, tsz.first, tsz.second, ptr, stride);
112+
op, resType, tsz.first, tsz.second, ptr, stride.value());
119113
return success();
120114
}
121115
};
@@ -132,15 +126,15 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
132126
std::pair<Value, Value> tsz =
133127
getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
134128
// Determine stride.
135-
if (failed(verifyStride(mType)))
129+
auto stride = getStride(rewriter, *getTypeConverter(), mType,
130+
adaptor.getBase(), op.getLoc());
131+
if (failed(stride))
136132
return failure();
137-
Value stride = getStride(rewriter, *getTypeConverter(), mType,
138-
adaptor.getBase(), op.getLoc());
139133
// Replace operation with intrinsic.
140134
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
141135
adaptor.getIndices(), rewriter);
142136
rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
143-
op, tsz.first, tsz.second, ptr, stride, adaptor.getVal());
137+
op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal());
144138
return success();
145139
}
146140
};

0 commit comments

Comments
 (0)