Skip to content

Commit 0c0f415

Browse files
committed
Fix strides used by AMX lowering.
Lowering of tile stores and load uses the size of the last memref dimension as a stride and ignores actual strides specified in the memref. This causes unexpected results when actual stride doesn't match the last dimension size. Signed-off-by: Ilya Enkovich <[email protected]>
1 parent a4ace3d commit 0c0f415

File tree

2 files changed

+38
-6
lines changed

2 files changed

+38
-6
lines changed

mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,21 +55,25 @@ Value getStride(ConversionPatternRewriter &rewriter,
5555
const LLVMTypeConverter &typeConverter, MemRefType mType,
5656
Value base, Location loc) {
5757
assert(mType.getRank() >= 2);
58-
int64_t last = mType.getRank() - 1;
58+
int64_t preLast = mType.getRank() - 2;
5959
Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
6060
unsigned width = mType.getElementType().getIntOrFloatBitWidth();
6161
assert(llvm::isPowerOf2_64(width) && width >= 8);
6262
unsigned bytes = width >> 3;
63-
if (mType.isDynamicDim(last)) {
64-
// Dynamic size needs code to compute the stride at runtime.
63+
int64_t offset;
64+
SmallVector<int64_t, 4> strides;
65+
getStridesAndOffset(mType, strides, offset);
66+
if (strides[preLast] == ShapedType::kDynamic) {
67+
// Dynamic stride needs code to compute the stride at runtime.
6568
MemRefDescriptor memrefDescriptor(base);
6669
auto attr = rewriter.getI64IntegerAttr(bytes);
6770
Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
6871
return rewriter.create<LLVM::MulOp>(
69-
loc, llvmInt64Type, scale, memrefDescriptor.size(rewriter, loc, last));
72+
loc, llvmInt64Type, scale,
73+
memrefDescriptor.stride(rewriter, loc, preLast));
7074
}
71-
// Use direct constant for static size.
72-
auto attr = rewriter.getI64IntegerAttr(mType.getDimSize(last) * bytes);
75+
// Use direct constant for static stride.
76+
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
7377
return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
7478
}
7579

mlir/test/Dialect/AMX/legalize-for-llvm.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,31 @@ func.func @mulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
4343
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, vector<16x16xf32>
4444
return
4545
}
46+
47+
// CHECK-LABEL: strides(
48+
// CHECK: %[[CST_64_1:.+]] = llvm.mlir.constant(64 : i64) : i64
49+
// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_1]]
50+
// CHECK: %[[CST_128_1:.+]] = llvm.mlir.constant(128 : i64) : i64
51+
// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_1]]
52+
// CHECK: llvm.mlir.constant(2 : i64) : i64
53+
// CHECK: llvm.extractvalue %{{.+}}[4, 0]
54+
// CHECK: %[[STRIDE_1:.+]] = llvm.mul
55+
// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_1]]
56+
// CHECK: %[[CST_64_2:.+]] = llvm.mlir.constant(64 : i64) : i64
57+
// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_2]]
58+
// CHECK: %[[CST_128_2:.+]] = llvm.mlir.constant(128 : i64) : i64
59+
// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_2]]
60+
// CHECK: llvm.mlir.constant(2 : i64) : i64
61+
// CHECK: llvm.extractvalue %{{.+}}[4, 0]
62+
// CHECK: %[[STRIDE_2:.+]] = llvm.mul
63+
// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]]
64+
func.func @strides(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xbf16, strided<[?, 1]>>) {
65+
%0 = arith.constant 0 : index
66+
%1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
67+
%2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16, strided<[64, 1]>> into vector<16x32xbf16>
68+
%3 = amx.tile_load %arg2[%0, %0] : memref<16x32xbf16, strided<[?, 1]>> into vector<16x32xbf16>
69+
amx.tile_store %arg0[%0, %0], %3 : memref<16x32xbf16>, vector<16x32xbf16>
70+
amx.tile_store %arg1[%0, %0], %1 : memref<16x32xbf16, strided<[64, 1]>>, vector<16x32xbf16>
71+
amx.tile_store %arg2[%0, %0], %2 : memref<16x32xbf16, strided<[?, 1]>>, vector<16x32xbf16>
72+
return
73+
}

0 commit comments

Comments
 (0)