From 3e2e4b513241ef47405a252532ba1352f12df04a Mon Sep 17 00:00:00 2001 From: Alan Li Date: Wed, 29 Jan 2025 05:14:11 +0000 Subject: [PATCH 1/8] First commit --- .../Vector/Transforms/VectorRewritePatterns.h | 6 +- .../Transforms/VectorEmulateNarrowType.cpp | 59 +++++++++++++++++-- .../Dialect/MemRef/TestEmulateNarrowType.cpp | 8 ++- 3 files changed, 66 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index a59f06f3c1ef1..43478aacb50a1 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -364,10 +364,12 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); /// Appends patterns for emulating vector operations over narrow types with ops -/// over wider types. +/// over wider types. The `useAtomicWrites` indicates whether to use +/// op `memref.generic_atomic_rmw` to perform atomic subbyte storing, or just a +/// rmw sequence otherwise. void populateVectorNarrowTypeEmulationPatterns( const arith::NarrowTypeEmulationConverter &typeConverter, - RewritePatternSet &patterns); + RewritePatternSet &patterns, bool useAtomicWrites = true); /// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of /// vector operations comprising `shuffle` and `bitwise` ops. diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 7ca88f1e0a0df..8317317edb915 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -363,6 +363,29 @@ static void atomicStore(OpBuilder &builder, Location loc, builder.create(loc, scalarMaskedValue); } +/// Generate a non-atomic read-modify-write sequence for subbyte storing. +/// It has similar logic to `atomicStore`, but without the atomicity. +static void rmwStore(OpBuilder &builder, Location loc, + MemRefValue linearizedMemref, Value linearizedIndex, + VectorValue valueToStore, Value mask) { + assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector"); + + // Load the original value from memory, and cast it to the original element + // type. + auto oneElemVecType = + VectorType::get({1}, linearizedMemref.getType().getElementType()); + Value origVecValue = builder.create( + loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex}); + origVecValue = builder.create(loc, valueToStore.getType(), + origVecValue); + + // Construct the final masked value and yield it. + Value maskedValue = selectAndCast(builder, loc, oneElemVecType, mask, + origVecValue, valueToStore); + builder.create(loc, maskedValue, linearizedMemref, + linearizedIndex); +} + /// Extract `sliceNumElements` from source `vector` at `extractOffset`, /// and insert it into an empty vector at `insertOffset`. 
/// Inputs: @@ -405,6 +428,10 @@ namespace { struct ConvertVectorStore final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; + ConvertVectorStore(MLIRContext *context, bool useAtomicWrites) + : OpConversionPattern(context), + useAtomicWrites_(useAtomicWrites) {} + LogicalResult matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -611,13 +638,31 @@ struct ConvertVectorStore final : OpConversionPattern { auto backMask = rewriter.create( loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues)); - atomicStore(rewriter, loc, memrefBase, currentDestIndex, - cast(subWidthStorePart), backMask.getResult()); + subEmulatedWidthStore(rewriter, loc, memrefBase, currentDestIndex, + cast(subWidthStorePart), + backMask.getResult()); } rewriter.eraseOp(op); return success(); } + + /// Store a subbyte-sized value to memory, with a mask. Depending on the + /// configuration, it could be an atomic store or a non-atomic RMW sequence. + template + void subEmulatedWidthStore(Args &&...args) const { + static_assert( + std::is_same_v && + "`atomicStore` and `rmwStore` must have same signature, as per " + "the design to keep the code clean, which one to call is " + "determined by the `useAtomicWrites` flag."); + std::function storeFunc = + useAtomicWrites_ ? atomicStore : rmwStore; + storeFunc(std::forward(args)...); + } + +private: + const bool useAtomicWrites_; }; //===----------------------------------------------------------------------===// @@ -1930,12 +1975,18 @@ struct RewriteVectorTranspose : OpRewritePattern { void vector::populateVectorNarrowTypeEmulationPatterns( const arith::NarrowTypeEmulationConverter &typeConverter, - RewritePatternSet &patterns) { + RewritePatternSet &patterns, bool useAtomicWrites) { // Populate `vector.*` conversion patterns. - patterns.add( typeConverter, patterns.getContext()); + + // Populate `vector.*` store conversion patterns. The caller can choose + // to avoid emitting atomic operations and reduce it to load-modify-write + // sequence for stores if it is known there are no thread contentions. 
+ patterns.insert(patterns.getContext(), useAtomicWrites); } void vector::populateVectorNarrowTypeRewritePatterns( diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp index 7401e470ed4f2..9a3fac623fbd7 100644 --- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp +++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp @@ -99,7 +99,8 @@ struct TestEmulateNarrowTypePass arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns); memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns); - vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns); + vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns, + atomicStore); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); @@ -118,6 +119,11 @@ struct TestEmulateNarrowTypePass *this, "skip-memref-type-conversion", llvm::cl::desc("disable memref type conversion (to test failures)"), llvm::cl::init(false)}; + + Option atomicStore{ + *this, "atomic-store", + llvm::cl::desc("use atomic store instead of load-modify-write"), + llvm::cl::init(true)}; }; } // namespace From 66ecff4e0487f8520e1591db702e30dd8b732ca3 Mon Sep 17 00:00:00 2001 From: Alan Li Date: Wed, 29 Jan 2025 05:47:51 +0000 Subject: [PATCH 2/8] updates --- .../Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 8317317edb915..82d8a6ffcc17c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -380,8 +380,9 @@ static void rmwStore(OpBuilder &builder, Location loc, origVecValue); // Construct the final masked value and yield it. 
- Value maskedValue = selectAndCast(builder, loc, oneElemVecType, mask, - origVecValue, valueToStore); + Value maskedValue = + downcastSelectAndUpcast(builder, loc, valueToStore.getType(), + oneElemVecType, mask, valueToStore, origVecValue); builder.create(loc, maskedValue, linearizedMemref, linearizedIndex); } From f2d5e8ba19625a90d5f32cda3c9cd337a36c339d Mon Sep 17 00:00:00 2001 From: Alan Li Date: Wed, 29 Jan 2025 06:16:54 +0000 Subject: [PATCH 3/8] linting --- .../Transforms/VectorEmulateNarrowType.cpp | 6 +- ...late-narrow-type-unaligned-non-atomic.mlir | 119 ++++++++++++++++++ 2 files changed, 122 insertions(+), 3 deletions(-) create mode 100644 mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 82d8a6ffcc17c..c848d3c0ca98a 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -429,7 +429,7 @@ namespace { struct ConvertVectorStore final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; - ConvertVectorStore(MLIRContext *context, bool useAtomicWrites) + ConvertVectorStore(MLIRContext *context, bool useAtomicWrites) : OpConversionPattern(context), useAtomicWrites_(useAtomicWrites) {} @@ -583,8 +583,8 @@ struct ConvertVectorStore final : OpConversionPattern { extractSliceIntoByte(rewriter, loc, valueToStore, 0, frontSubWidthStoreElem, *foldedNumFrontPadElems); - atomicStore(rewriter, loc, memrefBase, currentDestIndex, - cast(value), frontMask.getResult()); + subEmulatedWidthStore(rewriter, loc, memrefBase, currentDestIndex, + cast(value), frontMask.getResult()); } if (currentSourceIndex >= origElements) { diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir new file mode 100644 index 0000000000000..79f8869d043ee --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir @@ -0,0 +1,119 @@ +// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8 atomic-store=false" --cse --split-input-file %s | FileCheck %s + +// TODO: remove memref.alloc() in the tests to eliminate noises. +// memref.alloc exists here because sub-byte vector data types such as i2 +// are currently not supported as input arguments. + +func.func @vector_store_i2_const_index_two_rmw(%arg0: vector<3xi2>) { + %0 = memref.alloc() : memref<3x3xi2> + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2> + return +} +// Load from bit [12:18), byte [1:2] of total 3 bytes, both bytes needs rmw. 
+ +// CHECK: func @vector_store_i2_const_index_two_rmw( +// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>) +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8> +// CHECK: %[[C1:.+]] = arith.constant 1 : index + +// Part 1 RMW sequence +// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]> +// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2> +// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]] +// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2> +// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]] +// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2> +// CHECK: %[[LOAD:.+]] = vector.load +// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2> +// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]] +// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]] +// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]] + +// Part 2 RMW sequence +// CHECK: %[[OFFSET:.+]] = arith.addi %[[C1]], %[[C1]] : index +// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]] +// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2> +// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]] +// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2> +// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1> +// CHECK: %[[LOAD2:.+]] = vector.load +// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]] : vector<1xi8> to vector<4xi2> +// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST2]] +// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]] +// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[OFFSET]]] + + +// ----- + +func.func @vector_store_i2_rmw(%arg0: vector<7xi2>) { + %0 = memref.alloc() : memref<3x7xi2> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + vector.store %arg0, %0[%c1, %c0] :memref<3x7xi2>, vector<7xi2> + return +} + +// CHECK: func @vector_store_i2_rmw( +// CHECK-SAME: %[[ARG0:.+]]: +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8> +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]> +// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2> +// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]] +// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]} +// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]] +// CHECK-SAME: {offsets = [3], strides = [1]} +// First sub-width RMW: +// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C1]]] +// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2> +// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]] +// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]] +// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]] + +// Full-width store: +// CHECK: %[[INDEX:.+]] = arith.addi %[[C1]], %[[C1]] +// CHECK: %[[EXTRACT1:.+]] = vector.extract_strided_slice %[[ARG0]] +// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]} +// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EXTRACT1]] +// CHECK: vector.store %[[BITCAST]], %[[ALLOC]][%[[INDEX]]] + +// Second sub-width RMW: +// CHECK: %[[INDEX2:.+]] = arith.addi %[[INDEX]], %[[C1]] +// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]] +// CHECK-SAME: {offsets = [5], sizes = [2], strides 
= [1]} +// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]] +// CHECK-SAME: {offsets = [0], strides = [1]} +// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]> +// CHECK: %[[LOAD1:.+]] = vector.load %[[ALLOC]][%[[INDEX2]]] +// CHECK: %[[UPCAST1:.+]] = vector.bitcast %[[LOAD1]] +// CHECK: %[[SELECT1:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST1]] +// CHECK: %[[DOWNCAST1:.+]] = vector.bitcast %[[SELECT1]] +// CHECK: vector.store %[[DOWNCAST1]], %[[ALLOC]][%[[INDEX2]]] + +// ----- + +func.func @vector_store_i2_single_rmw(%arg0: vector<1xi2>) { + %0 = memref.alloc() : memref<4x1xi2> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + vector.store %arg0, %0[%c1, %c0] :memref<4x1xi2>, vector<1xi2> + return +} + +// in this test, only emit 1 rmw store +// CHECK: func @vector_store_i2_single_rmw( +// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>) +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8> +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]> +// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2> +// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]] +// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2> +// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C0]]] : memref<1xi8>, vector<1xi8> +// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2> +// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]] +// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]] +// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C0]]] + From c75f899316d9dd5bbd362569b145fcdd54fafeab Mon Sep 17 00:00:00 2001 From: Alan Li Date: Wed, 29 Jan 2025 09:20:33 +0000 Subject: [PATCH 4/8] update comments --- .../Transforms/VectorEmulateNarrowType.cpp | 5 +---- ...late-narrow-type-unaligned-non-atomic.mlir | 22 +++++++++++-------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index c848d3c0ca98a..00019d8c2d4bc 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -364,14 +364,12 @@ static void atomicStore(OpBuilder &builder, Location loc, } /// Generate a non-atomic read-modify-write sequence for subbyte storing. -/// It has similar logic to `atomicStore`, but without the atomicity. +/// It has similar logic to `atomicStore`, but without atomicity. static void rmwStore(OpBuilder &builder, Location loc, MemRefValue linearizedMemref, Value linearizedIndex, VectorValue valueToStore, Value mask) { assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector"); - // Load the original value from memory, and cast it to the original element - // type. auto oneElemVecType = VectorType::get({1}, linearizedMemref.getType().getElementType()); Value origVecValue = builder.create( @@ -379,7 +377,6 @@ static void rmwStore(OpBuilder &builder, Location loc, origVecValue = builder.create(loc, valueToStore.getType(), origVecValue); - // Construct the final masked value and yield it. 
Value maskedValue = downcastSelectAndUpcast(builder, loc, valueToStore.getType(), oneElemVecType, mask, valueToStore, origVecValue); diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir index 79f8869d043ee..84cae7d922b38 100644 --- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir +++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir @@ -4,16 +4,18 @@ // memref.alloc exists here because sub-byte vector data types such as i2 // are currently not supported as input arguments. -func.func @vector_store_i2_const_index_two_rmw(%arg0: vector<3xi2>) { +func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) { %0 = memref.alloc() : memref<3x3xi2> %c0 = arith.constant 0 : index %c2 = arith.constant 2 : index vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2> return } -// Load from bit [12:18), byte [1:2] of total 3 bytes, both bytes needs rmw. +// In this example, emit two RMW stores without full-width store. +// Store bit [12:18), byte [1:2] to a 3-byte vector, both bytes are +// accessed partially. -// CHECK: func @vector_store_i2_const_index_two_rmw( +// CHECK: func @vector_store_i2_const_index_two_partial_stores( // CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>) // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8> // CHECK: %[[C1:.+]] = arith.constant 1 : index @@ -47,7 +49,7 @@ func.func @vector_store_i2_const_index_two_rmw(%arg0: vector<3xi2>) { // ----- -func.func @vector_store_i2_rmw(%arg0: vector<7xi2>) { +func.func @vector_store_i2_two_partial_one_full_stores(%arg0: vector<7xi2>) { %0 = memref.alloc() : memref<3x7xi2> %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -55,7 +57,9 @@ func.func @vector_store_i2_rmw(%arg0: vector<7xi2>) { return } -// CHECK: func @vector_store_i2_rmw( +// In this example, emit two RMW stores and one full-width store. + +// CHECK: func @vector_store_i2_two_partial_one_full_stores( // CHECK-SAME: %[[ARG0:.+]]: // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8> // CHECK: %[[C1:.+]] = arith.constant 1 : index @@ -94,7 +98,7 @@ func.func @vector_store_i2_rmw(%arg0: vector<7xi2>) { // ----- -func.func @vector_store_i2_single_rmw(%arg0: vector<1xi2>) { +func.func @vector_store_i2_one_partial_store(%arg0: vector<1xi2>) { %0 = memref.alloc() : memref<4x1xi2> %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -102,8 +106,9 @@ func.func @vector_store_i2_single_rmw(%arg0: vector<1xi2>) { return } -// in this test, only emit 1 rmw store -// CHECK: func @vector_store_i2_single_rmw( +// in this test, only emit partial RMW store as the store is within one byte. 
+ +// CHECK: func @vector_store_i2_one_partial_store( // CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>) // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8> // CHECK: %[[C0:.+]] = arith.constant 0 : index @@ -116,4 +121,3 @@ func.func @vector_store_i2_single_rmw(%arg0: vector<1xi2>) { // CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]] // CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]] // CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C0]]] - From 562d87e3ce1fe0f2279ee1ca4e74d8f140670597 Mon Sep 17 00:00:00 2001 From: Alan Li Date: Mon, 3 Feb 2025 21:39:27 -0800 Subject: [PATCH 5/8] Rename --- .../Transforms/VectorEmulateNarrowType.cpp | 35 +++++++------------ 1 file changed, 12 insertions(+), 23 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 00019d8c2d4bc..edc8881d6919e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -334,9 +334,9 @@ static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc, /// /// Result: /// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>) -static void atomicStore(OpBuilder &builder, Location loc, - MemRefValue linearizedMemref, Value storeIdx, - VectorValue valueToStore, Value mask) { +static void atomicRMWStore(OpBuilder &builder, Location loc, + MemRefValue linearizedMemref, Value storeIdx, + VectorValue valueToStore, Value mask) { assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector"); // Create an atomic load-modify-write region using @@ -364,10 +364,11 @@ static void atomicStore(OpBuilder &builder, Location loc, } /// Generate a non-atomic read-modify-write sequence for subbyte storing. -/// It has similar logic to `atomicStore`, but without atomicity. -static void rmwStore(OpBuilder &builder, Location loc, - MemRefValue linearizedMemref, Value linearizedIndex, - VectorValue valueToStore, Value mask) { +/// It has similar logic to `atomicRMWStore`, but without atomicity. +static void nonAtomicRMWStore(OpBuilder &builder, Location loc, + MemRefValue linearizedMemref, + Value linearizedIndex, VectorValue valueToStore, + Value mask) { assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector"); auto oneElemVecType = @@ -580,8 +581,10 @@ struct ConvertVectorStore final : OpConversionPattern { extractSliceIntoByte(rewriter, loc, valueToStore, 0, frontSubWidthStoreElem, *foldedNumFrontPadElems); - subEmulatedWidthStore(rewriter, loc, memrefBase, currentDestIndex, - cast(value), frontMask.getResult()); + auto storeFunc = useAtomicWrites_ ? atomicRMWStore : nonAtomicRMWStore; + + storeFunc(rewriter, loc, memrefBase, currentDestIndex, + cast(value), frontMask.getResult()); } if (currentSourceIndex >= origElements) { @@ -645,20 +648,6 @@ struct ConvertVectorStore final : OpConversionPattern { return success(); } - /// Store a subbyte-sized value to memory, with a mask. Depending on the - /// configuration, it could be an atomic store or a non-atomic RMW sequence. - template - void subEmulatedWidthStore(Args &&...args) const { - static_assert( - std::is_same_v && - "`atomicStore` and `rmwStore` must have same signature, as per " - "the design to keep the code clean, which one to call is " - "determined by the `useAtomicWrites` flag."); - std::function storeFunc = - useAtomicWrites_ ? 
atomicStore : rmwStore; - storeFunc(std::forward(args)...); - } - private: const bool useAtomicWrites_; }; From c9e2754f04f5dbfad846fc0a047433929acd29cb Mon Sep 17 00:00:00 2001 From: Alan Li Date: Tue, 4 Feb 2025 05:15:27 -0800 Subject: [PATCH 6/8] Update name --- .../Vector/Transforms/VectorRewritePatterns.h | 8 ++--- .../Transforms/VectorEmulateNarrowType.cpp | 36 +++++++++---------- .../Dialect/MemRef/TestEmulateNarrowType.cpp | 11 +++--- 3 files changed, 27 insertions(+), 28 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index 43478aacb50a1..7de4a6a315750 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -364,12 +364,12 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); /// Appends patterns for emulating vector operations over narrow types with ops -/// over wider types. The `useAtomicWrites` indicates whether to use -/// op `memref.generic_atomic_rmw` to perform atomic subbyte storing, or just a -/// rmw sequence otherwise. +/// over wider types. The `disableAtomicRMW` indicates whether to use a normal +/// read-modify-write sequence instead of using `memref.generic_atomic_rmw` to +/// perform subbyte storing. void populateVectorNarrowTypeEmulationPatterns( const arith::NarrowTypeEmulationConverter &typeConverter, - RewritePatternSet &patterns, bool useAtomicWrites = true); + RewritePatternSet &patterns, bool disableAtomicRMW = false); /// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of /// vector operations comprising `shuffle` and `bitwise` ops. diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index edc8881d6919e..cef7f0cde4d7c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -334,9 +334,9 @@ static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc, /// /// Result: /// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>) -static void atomicRMWStore(OpBuilder &builder, Location loc, - MemRefValue linearizedMemref, Value storeIdx, - VectorValue valueToStore, Value mask) { +static void atomicRMW(OpBuilder &builder, Location loc, + MemRefValue linearizedMemref, Value storeIdx, + VectorValue valueToStore, Value mask) { assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector"); // Create an atomic load-modify-write region using @@ -363,12 +363,11 @@ static void atomicRMWStore(OpBuilder &builder, Location loc, builder.create(loc, scalarMaskedValue); } -/// Generate a non-atomic read-modify-write sequence for subbyte storing. -/// It has similar logic to `atomicRMWStore`, but without atomicity. -static void nonAtomicRMWStore(OpBuilder &builder, Location loc, - MemRefValue linearizedMemref, - Value linearizedIndex, VectorValue valueToStore, - Value mask) { +/// Generate a non-atomic read-modify-write sequence for storing to the emulated +/// type. It has similar logic to `atomicRMWStore`, but without atomicity. 
+static void nonAtomicRMW(OpBuilder &builder, Location loc, + MemRefValue linearizedMemref, Value linearizedIndex, + VectorValue valueToStore, Value mask) { assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector"); auto oneElemVecType = @@ -427,9 +426,9 @@ namespace { struct ConvertVectorStore final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; - ConvertVectorStore(MLIRContext *context, bool useAtomicWrites) + ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW) : OpConversionPattern(context), - useAtomicWrites_(useAtomicWrites) {} + disableAtomicRMW(disableAtomicRMW) {} LogicalResult matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor, @@ -557,6 +556,8 @@ struct ConvertVectorStore final : OpConversionPattern { auto subWidthStoreMaskType = VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type()); + auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW; + // 1. Partial width store for the leading byte. // When the store address is not aligned to emulated width boundary, deal // with the unaligned part so that the rest elements are aligned to width @@ -581,8 +582,6 @@ struct ConvertVectorStore final : OpConversionPattern { extractSliceIntoByte(rewriter, loc, valueToStore, 0, frontSubWidthStoreElem, *foldedNumFrontPadElems); - auto storeFunc = useAtomicWrites_ ? atomicRMWStore : nonAtomicRMWStore; - storeFunc(rewriter, loc, memrefBase, currentDestIndex, cast(value), frontMask.getResult()); } @@ -639,9 +638,8 @@ struct ConvertVectorStore final : OpConversionPattern { auto backMask = rewriter.create( loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues)); - subEmulatedWidthStore(rewriter, loc, memrefBase, currentDestIndex, - cast(subWidthStorePart), - backMask.getResult()); + storeFunc(rewriter, loc, memrefBase, currentDestIndex, + cast(subWidthStorePart), backMask.getResult()); } rewriter.eraseOp(op); @@ -649,7 +647,7 @@ struct ConvertVectorStore final : OpConversionPattern { } private: - const bool useAtomicWrites_; + const bool disableAtomicRMW; }; //===----------------------------------------------------------------------===// @@ -1962,7 +1960,7 @@ struct RewriteVectorTranspose : OpRewritePattern { void vector::populateVectorNarrowTypeEmulationPatterns( const arith::NarrowTypeEmulationConverter &typeConverter, - RewritePatternSet &patterns, bool useAtomicWrites) { + RewritePatternSet &patterns, bool disableAtomicRMW) { // Populate `vector.*` conversion patterns. // TODO: #119553 support atomicity @@ -1973,7 +1971,7 @@ void vector::populateVectorNarrowTypeEmulationPatterns( // Populate `vector.*` store conversion patterns. The caller can choose // to avoid emitting atomic operations and reduce it to load-modify-write // sequence for stores if it is known there are no thread contentions. 
- patterns.insert(patterns.getContext(), useAtomicWrites); + patterns.insert(patterns.getContext(), disableAtomicRMW); } void vector::populateVectorNarrowTypeRewritePatterns( diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp index 9a3fac623fbd7..ba2ea40e83d96 100644 --- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp +++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp @@ -100,7 +100,7 @@ struct TestEmulateNarrowTypePass arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns); memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns); vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns, - atomicStore); + disableAtomicRMW); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); @@ -120,10 +120,11 @@ struct TestEmulateNarrowTypePass llvm::cl::desc("disable memref type conversion (to test failures)"), llvm::cl::init(false)}; - Option atomicStore{ - *this, "atomic-store", - llvm::cl::desc("use atomic store instead of load-modify-write"), - llvm::cl::init(true)}; + Option disableAtomicRMW{ + *this, "disable-atomic-rmw", + llvm::cl::desc("disable atomic read-modify-write and prefer generating " + "normal sequence"), + llvm::cl::init(false)}; }; } // namespace From b63c9fe163539bd37e8d003ce45a2e3f2342d5da Mon Sep 17 00:00:00 2001 From: Alan Li Date: Thu, 6 Feb 2025 17:24:49 +0000 Subject: [PATCH 7/8] update according to comments --- .../Transforms/VectorEmulateNarrowType.cpp | 2 +- ...late-narrow-type-unaligned-non-atomic.mlir | 31 +++++++++++-------- .../vector-emulate-narrow-type-unaligned.mlir | 7 ++--- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index cef7f0cde4d7c..acd4ac3496789 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -1969,7 +1969,7 @@ void vector::populateVectorNarrowTypeEmulationPatterns( typeConverter, patterns.getContext()); // Populate `vector.*` store conversion patterns. The caller can choose - // to avoid emitting atomic operations and reduce it to load-modify-write + // to avoid emitting atomic operations and reduce it to read-modify-write // sequence for stores if it is known there are no thread contentions. patterns.insert(patterns.getContext(), disableAtomicRMW); } diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir index 84cae7d922b38..71143ad908895 100644 --- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir +++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir @@ -1,9 +1,13 @@ -// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8 atomic-store=false" --cse --split-input-file %s | FileCheck %s +// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8 disable-atomic-rmw=true" --cse --split-input-file %s | FileCheck %s // TODO: remove memref.alloc() in the tests to eliminate noises. // memref.alloc exists here because sub-byte vector data types such as i2 // are currently not supported as input arguments. 
+///---------------------------------------------------------------------------------------- +/// vector.store +///---------------------------------------------------------------------------------------- + func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) { %0 = memref.alloc() : memref<3x3xi2> %c0 = arith.constant 0 : index @@ -11,9 +15,10 @@ func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) { vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2> return } -// In this example, emit two RMW stores without full-width store. -// Store bit [12:18), byte [1:2] to a 3-byte vector, both bytes are -// accessed partially. + +// Emit two non-atomic RMW partial stores. Store 6 bits from the input vector (bits [12:18)), +// into bytes [1:2] from a 3-byte output memref. Due to partial storing, +// both bytes are accessed partially through masking. // CHECK: func @vector_store_i2_const_index_two_partial_stores( // CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>) @@ -28,10 +33,10 @@ func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) { // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]] // CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2> // CHECK: %[[LOAD:.+]] = vector.load -// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2> -// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]] -// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]] -// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]] +// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2> +// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[DOWNCAST]] +// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[SELECT]] +// CHECK: vector.store %[[UPCAST]], %[[ALLOC]][%[[C1]]] // Part 2 RMW sequence // CHECK: %[[OFFSET:.+]] = arith.addi %[[C1]], %[[C1]] : index @@ -90,11 +95,11 @@ func.func @vector_store_i2_two_partial_one_full_stores(%arg0: vector<7xi2>) { // CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]] // CHECK-SAME: {offsets = [0], strides = [1]} // CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]> -// CHECK: %[[LOAD1:.+]] = vector.load %[[ALLOC]][%[[INDEX2]]] -// CHECK: %[[UPCAST1:.+]] = vector.bitcast %[[LOAD1]] -// CHECK: %[[SELECT1:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST1]] -// CHECK: %[[DOWNCAST1:.+]] = vector.bitcast %[[SELECT1]] -// CHECK: vector.store %[[DOWNCAST1]], %[[ALLOC]][%[[INDEX2]]] +// CHECK: %[[LOAD2:.+]] = vector.load %[[ALLOC]][%[[INDEX2]]] +// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]] +// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST2]] +// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]] +// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[INDEX2]]] // ----- diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir index 89cb8e0bde875..6fc974200c6f3 100644 --- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir +++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir @@ -369,10 +369,9 @@ func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) { return } -// In this example, emit 2 atomic RMWs. -// -// Note, sizeof(%src) = 18 bits. 
This is modelled as %src_as_bytes: -// <3xi8> (bits [0, 18) with the input values from %src, and [18, 24) are masked out) +// Emit two atomic RMW partial stores. Store 6 bits from the input vector (bits [12:18)), +// into bytes [1:2] from a 3-byte output memref. Due to partial storing, +// both bytes are accessed partially through masking. // CHECK-LABEL: func @vector_store_i2_const_index_two_partial_stores( // CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>) From 0083eebe220a00c227faf145db1f2734a90b12cf Mon Sep 17 00:00:00 2001 From: Alan Li Date: Thu, 6 Feb 2025 20:39:56 +0000 Subject: [PATCH 8/8] update --- .../vector-emulate-narrow-type-unaligned-non-atomic.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir index 71143ad908895..1d6263535ae80 100644 --- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir +++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir @@ -103,7 +103,7 @@ func.func @vector_store_i2_two_partial_one_full_stores(%arg0: vector<7xi2>) { // ----- -func.func @vector_store_i2_one_partial_store(%arg0: vector<1xi2>) { +func.func @vector_store_i2_const_index_one_partial_store(%arg0: vector<1xi2>) { %0 = memref.alloc() : memref<4x1xi2> %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -113,7 +113,7 @@ func.func @vector_store_i2_one_partial_store(%arg0: vector<1xi2>) { // in this test, only emit partial RMW store as the store is within one byte. -// CHECK: func @vector_store_i2_one_partial_store( +// CHECK: func @vector_store_i2_const_index_one_partial_store( // CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>) // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8> // CHECK: %[[C0:.+]] = arith.constant 0 : index
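
A minimal usage sketch of the flag introduced by this series, assuming a caller similar to TestEmulateNarrowTypePass: the pass boilerplate, the #include set, the ConversionTarget/legality setup, and the NarrowTypeEmulationConverter constructor argument are assumptions and are elided or simplified; only the populate* entry points and the applyPartialConversion call are taken from the patches above.

// Populate narrow-type emulation patterns, preferring a plain
// load / arith.select / store sequence over memref.generic_atomic_rmw for
// partial sub-byte stores. Only valid when the caller knows there is no
// concurrent access to the partially written bytes.
// Assumes `using namespace mlir;` and the usual emulation headers.
static LogicalResult emulateNarrowTypesNonAtomic(Operation *op,
                                                 ConversionTarget &target,
                                                 unsigned loadStoreBitwidth) {
  MLIRContext *ctx = op->getContext();

  // Assumed: the converter is constructed from the wide load/store bitwidth
  // (e.g. 8), as in the memref-load-bitwidth=8 test configuration.
  arith::NarrowTypeEmulationConverter typeConverter(loadStoreBitwidth);

  RewritePatternSet patterns(ctx);
  arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
  memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
  // The new trailing flag: passing true disables the atomic RMW path and
  // emits the non-atomic read-modify-write sequence instead.
  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
                                                    /*disableAtomicRMW=*/true);

  return applyPartialConversion(op, target, std::move(patterns));
}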