From 0c0f415e4b3b34768102dbea162bcb6e29f53de1 Mon Sep 17 00:00:00 2001
From: Ilya Enkovich
Date: Wed, 23 Oct 2024 16:59:28 +0000
Subject: [PATCH 1/2] Fix strides used by AMX lowering.

Lowering of tile stores and loads uses the size of the last memref
dimension as a stride and ignores the actual strides specified in the
memref. This causes unexpected results when the actual stride doesn't
match the last dimension size.

Signed-off-by: Ilya Enkovich
---
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  | 16 +++++++----
 mlir/test/Dialect/AMX/legalize-for-llvm.mlir  | 28 +++++++++++++++++++
 2 files changed, 38 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index c8cfcc3d945be..14f28c436a1ff 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -55,21 +55,25 @@ Value getStride(ConversionPatternRewriter &rewriter,
                 const LLVMTypeConverter &typeConverter, MemRefType mType,
                 Value base, Location loc) {
   assert(mType.getRank() >= 2);
-  int64_t last = mType.getRank() - 1;
+  int64_t preLast = mType.getRank() - 2;
   Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
   unsigned width = mType.getElementType().getIntOrFloatBitWidth();
   assert(llvm::isPowerOf2_64(width) && width >= 8);
   unsigned bytes = width >> 3;
-  if (mType.isDynamicDim(last)) {
-    // Dynamic size needs code to compute the stride at runtime.
+  int64_t offset;
+  SmallVector<int64_t> strides;
+  getStridesAndOffset(mType, strides, offset);
+  if (strides[preLast] == ShapedType::kDynamic) {
+    // Dynamic stride needs code to compute the stride at runtime.
     MemRefDescriptor memrefDescriptor(base);
     auto attr = rewriter.getI64IntegerAttr(bytes);
     Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
     return rewriter.create<LLVM::MulOp>(
-        loc, llvmInt64Type, scale, memrefDescriptor.size(rewriter, loc, last));
+        loc, llvmInt64Type, scale,
+        memrefDescriptor.stride(rewriter, loc, preLast));
   }
-  // Use direct constant for static size.
-  auto attr = rewriter.getI64IntegerAttr(mType.getDimSize(last) * bytes);
+  // Use direct constant for static stride.
+  auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
   return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
 }
 
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index 992203153939f..3cacbd0044f82 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -43,3 +43,31 @@ func.func @mulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
   amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, vector<16x16xf32>
   return
 }
+
+// CHECK-LABEL: strides(
+// CHECK: %[[CST_64_1:.+]] = llvm.mlir.constant(64 : i64) : i64
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_1]]
+// CHECK: %[[CST_128_1:.+]] = llvm.mlir.constant(128 : i64) : i64
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_1]]
+// CHECK: llvm.mlir.constant(2 : i64) : i64
+// CHECK: llvm.extractvalue %{{.+}}[4, 0]
+// CHECK: %[[STRIDE_1:.+]] = llvm.mul
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_1]]
+// CHECK: %[[CST_64_2:.+]] = llvm.mlir.constant(64 : i64) : i64
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_2]]
+// CHECK: %[[CST_128_2:.+]] = llvm.mlir.constant(128 : i64) : i64
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_2]]
+// CHECK: llvm.mlir.constant(2 : i64) : i64
+// CHECK: llvm.extractvalue %{{.+}}[4, 0]
+// CHECK: %[[STRIDE_2:.+]] = llvm.mul
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]]
+func.func @strides(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xbf16, strided<[?, 1]>>) {
+  %0 = arith.constant 0 : index
+  %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16, strided<[64, 1]>> into vector<16x32xbf16>
+  %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xbf16, strided<[?, 1]>> into vector<16x32xbf16>
+  amx.tile_store %arg0[%0, %0], %3 : memref<16x32xbf16>, vector<16x32xbf16>
+  amx.tile_store %arg1[%0, %0], %1 : memref<16x32xbf16, strided<[64, 1]>>, vector<16x32xbf16>
+  amx.tile_store %arg2[%0, %0], %2 : memref<16x32xbf16, strided<[?, 1]>>, vector<16x32xbf16>
+  return
+}

From 41f66cd2c60ab3af59178b23f5038cd8f6fd9074 Mon Sep 17 00:00:00 2001
From: Ilya Enkovich
Date: Mon, 28 Oct 2024 19:59:43 +0000
Subject: [PATCH 2/2] Fix review comments.

Signed-off-by: Ilya Enkovich
---
 .../AMX/Transforms/LegalizeForLLVMExport.cpp | 52 ++++++++-----------
 1 file changed, 23 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 14f28c436a1ff..46c7bfbf3ffcc 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -37,24 +37,14 @@ std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
       rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
 }
 
-/// Verifies if the stride matches proper tile access.
-LogicalResult verifyStride(MemRefType mType) {
-  if (mType.getRank() < 2)
-    return failure();
-  int64_t last = mType.getRank() - 1;
-  int64_t offset;
-  SmallVector<int64_t> strides;
-  if (failed(getStridesAndOffset(mType, strides, offset)) || strides[last] != 1)
-    return failure();
-  return success();
-}
-
 /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
 /// shape may "envelop" the actual tile shape, and may be dynamically sized.
-Value getStride(ConversionPatternRewriter &rewriter,
-                const LLVMTypeConverter &typeConverter, MemRefType mType,
-                Value base, Location loc) {
-  assert(mType.getRank() >= 2);
+/// Returns failure if proper stride couldn't be found.
+FailureOr<Value> getStride(ConversionPatternRewriter &rewriter,
+                           const LLVMTypeConverter &typeConverter,
+                           MemRefType mType, Value base, Location loc) {
+  if (mType.getRank() < 2)
+    return failure();
   int64_t preLast = mType.getRank() - 2;
   Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
   unsigned width = mType.getElementType().getIntOrFloatBitWidth();
@@ -62,19 +52,23 @@ Value getStride(ConversionPatternRewriter &rewriter,
   unsigned bytes = width >> 3;
   int64_t offset;
   SmallVector<int64_t> strides;
-  getStridesAndOffset(mType, strides, offset);
+  if (failed(getStridesAndOffset(mType, strides, offset)) ||
+      strides.back() != 1)
+    return failure();
   if (strides[preLast] == ShapedType::kDynamic) {
     // Dynamic stride needs code to compute the stride at runtime.
     MemRefDescriptor memrefDescriptor(base);
     auto attr = rewriter.getI64IntegerAttr(bytes);
     Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
-    return rewriter.create<LLVM::MulOp>(
-        loc, llvmInt64Type, scale,
-        memrefDescriptor.stride(rewriter, loc, preLast));
+    return rewriter
+        .create<LLVM::MulOp>(loc, llvmInt64Type, scale,
+                             memrefDescriptor.stride(rewriter, loc, preLast))
+        .getResult();
   }
   // Use direct constant for static stride.
   auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
-  return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
+  return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
+      .getResult();
 }
 
 struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
@@ -106,16 +100,16 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
     std::pair<Value, Value> tsz =
         getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
     // Determine stride.
-    if (failed(verifyStride(mType)))
+    auto stride = getStride(rewriter, *getTypeConverter(), mType,
+                            adaptor.getBase(), op.getLoc());
+    if (failed(stride))
       return failure();
-    Value stride = getStride(rewriter, *getTypeConverter(), mType,
-                             adaptor.getBase(), op.getLoc());
     // Replace operation with intrinsic.
     Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
     Type resType = typeConverter->convertType(vType);
     rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
-        op, resType, tsz.first, tsz.second, ptr, stride);
+        op, resType, tsz.first, tsz.second, ptr, stride.value());
     return success();
   }
 };
@@ -132,15 +126,15 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
     std::pair<Value, Value> tsz =
         getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
     // Determine stride.
-    if (failed(verifyStride(mType)))
+    auto stride = getStride(rewriter, *getTypeConverter(), mType,
+                            adaptor.getBase(), op.getLoc());
+    if (failed(stride))
       return failure();
-    Value stride = getStride(rewriter, *getTypeConverter(), mType,
-                             adaptor.getBase(), op.getLoc());
     // Replace operation with intrinsic.
     Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
     rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
-        op, tsz.first, tsz.second, ptr, stride, adaptor.getVal());
+        op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal());
     return success();
   }
 };