@@ -364,10 +364,12 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Appends patterns for emulating vector operations over narrow types with ops
/// over wider types.
/// over wider types. The `useAtomicWrites` flag indicates whether to use
/// `memref.generic_atomic_rmw` to perform atomic sub-byte stores, or a plain
/// read-modify-write sequence otherwise.
void populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns);
RewritePatternSet &patterns, bool useAtomicWrites = true);

/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
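For illustration only, a minimal caller-side sketch of the new parameter; the helper function name and the include paths are assumptions, not part of this patch. A pipeline that knows its partial-byte stores cannot race with other threads can opt out of the atomic lowering:

#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h" // assumed path
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"       // assumed path

using namespace mlir;

// Hypothetical helper, not part of this patch.
static void addNarrowTypeEmulation(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {
  // Pass false to lower partial-byte stores to a plain read-modify-write
  // sequence; omit the argument (or pass true) to keep the default
  // memref.generic_atomic_rmw-based lowering.
  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
                                                    /*useAtomicWrites=*/false);
}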
61 changes: 55 additions & 6 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -363,6 +363,27 @@ static void atomicStore(OpBuilder &builder, Location loc,
builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
}

/// Generate a non-atomic read-modify-write sequence for sub-byte storing.
/// It follows the same logic as `atomicStore`, but is not atomic.
static void rmwStore(OpBuilder &builder, Location loc,
MemRefValue linearizedMemref, Value linearizedIndex,
VectorValue valueToStore, Value mask) {
assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

auto oneElemVecType =
VectorType::get({1}, linearizedMemref.getType().getElementType());
Value origVecValue = builder.create<vector::LoadOp>(
loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
origVecValue);

Value maskedValue =
downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
oneElemVecType, mask, valueToStore, origVecValue);
builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
linearizedIndex);
}

/// Extract `sliceNumElements` from source `vector` at `extractOffset`,
/// and insert it into an empty vector at `insertOffset`.
/// Inputs:
@@ -405,6 +426,10 @@ namespace {
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;

ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
: OpConversionPattern<vector::StoreOp>(context),
useAtomicWrites_(useAtomicWrites) {}

LogicalResult
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -555,8 +580,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
extractSliceIntoByte(rewriter, loc, valueToStore, 0,
frontSubWidthStoreElem, *foldedNumFrontPadElems);

atomicStore(rewriter, loc, memrefBase, currentDestIndex,
cast<VectorValue>(value), frontMask.getResult());
subEmulatedWidthStore(rewriter, loc, memrefBase, currentDestIndex,
cast<VectorValue>(value), frontMask.getResult());
}

if (currentSourceIndex >= origElements) {
@@ -611,13 +636,31 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto backMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));

atomicStore(rewriter, loc, memrefBase, currentDestIndex,
cast<VectorValue>(subWidthStorePart), backMask.getResult());
subEmulatedWidthStore(rewriter, loc, memrefBase, currentDestIndex,
cast<VectorValue>(subWidthStorePart),
backMask.getResult());
}

rewriter.eraseOp(op);
return success();
}

/// Store a sub-byte-sized value to memory, with a mask. Depending on the
/// configuration, this emits either an atomic store or a non-atomic RMW
/// sequence.
template <typename... Args>
void subEmulatedWidthStore(Args &&...args) const {
static_assert(
std::is_same_v<decltype(atomicStore), decltype(rmwStore)>,
"`atomicStore` and `rmwStore` must have the same signature, so that "
"which one to call can be selected via the `useAtomicWrites` flag.");
std::function<decltype(atomicStore)> storeFunc =
useAtomicWrites_ ? atomicStore : rmwStore;
storeFunc(std::forward<Args>(args)...);
}

private:
const bool useAtomicWrites_;
};

//===----------------------------------------------------------------------===//
@@ -1930,12 +1973,18 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {

void vector::populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns) {
RewritePatternSet &patterns, bool useAtomicWrites) {

// Populate `vector.*` conversion patterns.
patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
// TODO: #119553 support atomicity
patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
ConvertVectorMaskedStore, ConvertVectorTransferRead>(
typeConverter, patterns.getContext());

// Populate `vector.*` store conversion patterns. The caller can choose
// to avoid emitting atomic operations and reduce them to a plain
// load-modify-write sequence if it is known there is no thread contention.
patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
}

void vector::populateVectorNarrowTypeRewritePatterns(
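As an implementation note, the dispatch in `subEmulatedWidthStore` relies on `atomicStore` and `rmwStore` sharing one signature, so a `std::function` over that signature can hold either. A self-contained sketch of the same idiom (the names and the simplified `void(int)` signature are made up for illustration):

#include <functional>
#include <iostream>
#include <type_traits>

static void atomicLikeStore(int value) { std::cout << "atomic " << value << "\n"; }
static void rmwLikeStore(int value) { std::cout << "rmw " << value << "\n"; }

static void store(bool useAtomicWrites, int value) {
  static_assert(std::is_same_v<decltype(atomicLikeStore), decltype(rmwLikeStore)>,
                "both candidates must share one signature");
  // decltype(atomicLikeStore) is void(int), so this std::function can wrap either.
  std::function<decltype(atomicLikeStore)> storeFunc =
      useAtomicWrites ? atomicLikeStore : rmwLikeStore;
  storeFunc(value);
}

int main() {
  store(/*useAtomicWrites=*/true, 42);  // prints "atomic 42"
  store(/*useAtomicWrites=*/false, 42); // prints "rmw 42"
}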
@@ -0,0 +1,123 @@
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8 atomic-store=false" --cse --split-input-file %s | FileCheck %s

// TODO: remove memref.alloc() in the tests to eliminate noise.
// memref.alloc exists here because sub-byte vector data types such as i2
// are currently not supported as input arguments.

func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
%0 = memref.alloc() : memref<3x3xi2>
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
vector.store %arg0, %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
return
}
// In this example, emit two RMW stores and no full-width store.
// The store covers bits [12, 18), i.e. bytes 1 and 2 of the 3-byte memref;
// both bytes are accessed partially.

// CHECK: func @vector_store_i2_const_index_two_partial_stores(
// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
// CHECK: %[[C1:.+]] = arith.constant 1 : index

// Part 1 RMW sequence
// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]>
// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
// CHECK: %[[LOAD:.+]] = vector.load
// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]

// Part 2 RMW sequence
// CHECK: %[[OFFSET:.+]] = arith.addi %[[C1]], %[[C1]] : index
// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
// CHECK: %[[LOAD2:.+]] = vector.load
// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]] : vector<1xi8> to vector<4xi2>
// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST2]]
// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]]
// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[OFFSET]]]


// -----

func.func @vector_store_i2_two_partial_one_full_stores(%arg0: vector<7xi2>) {
%0 = memref.alloc() : memref<3x7xi2>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
vector.store %arg0, %0[%c1, %c0] : memref<3x7xi2>, vector<7xi2>
return
}

// In this example, emit two RMW stores and one full-width store.

// CHECK: func @vector_store_i2_two_partial_one_full_stores(
// CHECK-SAME: %[[ARG0:.+]]:
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]>
// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]}
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
// CHECK-SAME: {offsets = [3], strides = [1]}
// First sub-width RMW:
// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C1]]]
// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]

// Full-width store:
// CHECK: %[[INDEX:.+]] = arith.addi %[[C1]], %[[C1]]
// CHECK: %[[EXTRACT1:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]}
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EXTRACT1]]
// CHECK: vector.store %[[BITCAST]], %[[ALLOC]][%[[INDEX]]]

// Second sub-width RMW:
// CHECK: %[[INDEX2:.+]] = arith.addi %[[INDEX]], %[[C1]]
// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]}
// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]]
// CHECK-SAME: {offsets = [0], strides = [1]}
// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]>
// CHECK: %[[LOAD1:.+]] = vector.load %[[ALLOC]][%[[INDEX2]]]
// CHECK: %[[UPCAST1:.+]] = vector.bitcast %[[LOAD1]]
// CHECK: %[[SELECT1:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST1]]
// CHECK: %[[DOWNCAST1:.+]] = vector.bitcast %[[SELECT1]]
// CHECK: vector.store %[[DOWNCAST1]], %[[ALLOC]][%[[INDEX2]]]

// -----

func.func @vector_store_i2_one_partial_store(%arg0: vector<1xi2>) {
%0 = memref.alloc() : memref<4x1xi2>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
vector.store %arg0, %0[%c1, %c0] : memref<4x1xi2>, vector<1xi2>
return
}

// In this test, only a partial RMW store is emitted, as the store fits within one byte.

// CHECK: func @vector_store_i2_one_partial_store(
// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]>
// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C0]]] : memref<1xi8>, vector<1xi8>
// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C0]]]
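The CHECK lines above verify a load / select / store sequence for each partially covered byte. Conceptually, that sequence performs the byte-level update sketched below; this is a rough illustration only (the actual lowering selects on a vector<4xi2> element mask via arith.select, and the bit positions are assumptions about element packing, not taken from the patch):

#include <cstdint>
#include <cstdio>

// 'mask' has 1-bits where the new i2 elements land inside the byte.
static uint8_t rmwByte(uint8_t oldByte, uint8_t newBits, uint8_t mask) {
  return static_cast<uint8_t>((oldByte & ~mask) | (newBits & mask));
}

int main() {
  uint8_t oldByte = 0xB6; // existing memory contents
  uint8_t newBits = 0x09; // new elements, already shifted into position
  uint8_t mask    = 0x0F; // overwrite the low four bits (two i2 elements)
  std::printf("0x%02X\n", rmwByte(oldByte, newBits, mask)); // prints 0xB9
}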
8 changes: 7 additions & 1 deletion mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -99,7 +99,8 @@ struct TestEmulateNarrowTypePass

arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
atomicStore);

if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
@@ -118,6 +119,11 @@ struct TestEmulateNarrowTypePass
*this, "skip-memref-type-conversion",
llvm::cl::desc("disable memref type conversion (to test failures)"),
llvm::cl::init(false)};

Option<bool> atomicStore{
*this, "atomic-store",
llvm::cl::desc("use atomic store instead of load-modify-write"),
llvm::cl::init(true)};
};
} // namespace
