@@ -364,10 +364,12 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Appends patterns for emulating vector operations over narrow types with ops
/// over wider types.
/// over wider types. The `useAtomicWrites` flag indicates whether to use
/// `memref.generic_atomic_rmw` ops to perform atomic sub-byte stores, or a
/// plain read-modify-write sequence otherwise.
void populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns);
RewritePatternSet &patterns, bool useAtomicWrites = true);
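
A minimal caller-side sketch of the new parameter, assuming a lowering pipeline that is known to be single-threaded; the wrapper name is illustrative, and only `populateVectorNarrowTypeEmulationPatterns` and `arith::NarrowTypeEmulationConverter` come from this header:

void populateForSingleThreadedEmulation(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {
  // No concurrent writers touch the emulated memrefs here, so partial
  // sub-byte stores may be lowered to plain read-modify-write sequences.
  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
                                                    /*useAtomicWrites=*/false);
}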

/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
56 changes: 47 additions & 9 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -334,9 +334,9 @@ static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
///
/// Result:
/// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
static void atomicStore(OpBuilder &builder, Location loc,
MemRefValue linearizedMemref, Value storeIdx,
VectorValue valueToStore, Value mask) {
static void atomicRMWStore(OpBuilder &builder, Location loc,
MemRefValue linearizedMemref, Value storeIdx,
VectorValue valueToStore, Value mask) {
assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

// Create an atomic load-modify-write region using
@@ -363,6 +363,28 @@ static void atomicStore(OpBuilder &builder, Location loc,
builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
}

/// Generate a non-atomic read-modify-write sequence for sub-byte storing.
/// It has the same logic as `atomicRMWStore`, but without atomicity.
Contributor

I don't have an answer yet, but we may need to replace sub-byte with emulatedType - not all "emulated" types within this file are sub-bytes. We also use e.g. i32 to emulate i8:

func.func @vector_load_i8(%arg1: index, %arg2: index) -> vector<4xi8> {
%0 = memref.alloc() : memref<3x4xi8>
%1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8>
return %1 : vector<4xi8>
}
// Expect no conversions, i8 is supported.
// CHECK: func @vector_load_i8(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi8>
// CHECK-NEXT: [[L:%.+]] = vector.load %[[ALLOC]][%[[ARG0]], %[[ARG1]]] : memref<3x4xi8>, vector<4xi8>
// CHECK-NEXT: return
// CHECK32: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
// CHECK32: func @vector_load_i8(
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
// CHECK32: %[[VECLOAD:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi32>, vector<1xi32>
// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VECLOAD]] : vector<1xi32> to vector<4xi8>
// CHECK32: return %[[VEC_I4]]

Either we refrain from using sub-byte or force this logic to only be available for sub-bytes.

Member Author

got it. I can see there are many places still using sub-byte descriptions. Let's update those to be consistent later.

static void nonAtomicRMWStore(OpBuilder &builder, Location loc,
MemRefValue linearizedMemref,
Value linearizedIndex, VectorValue valueToStore,
Value mask) {
assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

auto oneElemVecType =
VectorType::get({1}, linearizedMemref.getType().getElementType());
Value origVecValue = builder.create<vector::LoadOp>(
loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
origVecValue);

Value maskedValue =
downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
oneElemVecType, mask, valueToStore, origVecValue);
builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
linearizedIndex);
}
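
For intuition, the non-atomic path above amounts to the scalar read-modify-write sketched below; this standalone snippet is only an analogy, and the 2-bit lane packing it assumes is illustrative rather than taken from the emulation code:

#include <cstdint>

// Replace 2-bit lane `lane` of `byte` with `val`, keeping the other lanes:
// the scalar equivalent of load, masked select, and store-back.
uint8_t storeI2Lane(uint8_t byte, unsigned lane, uint8_t val) {
  uint8_t laneMask = uint8_t(0b11u << (lane * 2));          // lanes to overwrite
  uint8_t kept = uint8_t(byte & ~laneMask);                 // unmasked lanes survive
  uint8_t inserted = uint8_t((val & 0b11u) << (lane * 2));  // new lane bits
  return uint8_t(kept | inserted);                          // value written back
}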

/// Extract `sliceNumElements` from source `vector` at `extractOffset`,
/// and insert it into an empty vector at `insertOffset`.
/// Inputs:
@@ -405,6 +427,10 @@ namespace {
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;

ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
: OpConversionPattern<vector::StoreOp>(context),
useAtomicWrites_(useAtomicWrites) {}

LogicalResult
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -555,8 +581,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
extractSliceIntoByte(rewriter, loc, valueToStore, 0,
frontSubWidthStoreElem, *foldedNumFrontPadElems);

atomicStore(rewriter, loc, memrefBase, currentDestIndex,
cast<VectorValue>(value), frontMask.getResult());
auto storeFunc = useAtomicWrites_ ? atomicRMWStore : nonAtomicRMWStore;

storeFunc(rewriter, loc, memrefBase, currentDestIndex,
cast<VectorValue>(value), frontMask.getResult());
}

if (currentSourceIndex >= origElements) {
@@ -611,13 +639,17 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto backMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));

atomicStore(rewriter, loc, memrefBase, currentDestIndex,
cast<VectorValue>(subWidthStorePart), backMask.getResult());
auto storeFunc = useAtomicWrites_ ? atomicRMWStore : nonAtomicRMWStore;

storeFunc(rewriter, loc, memrefBase, currentDestIndex,
cast<VectorValue>(subWidthStorePart), backMask.getResult());
}

rewriter.eraseOp(op);
return success();
}

private:
const bool useAtomicWrites_;
};

//===----------------------------------------------------------------------===//
@@ -1930,12 +1962,18 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {

void vector::populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns) {
RewritePatternSet &patterns, bool useAtomicWrites) {

// Populate `vector.*` conversion patterns.
patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
// TODO: #119553 support atomicity
patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
ConvertVectorMaskedStore, ConvertVectorTransferRead>(
typeConverter, patterns.getContext());

// Populate `vector.*` store conversion patterns. The caller can choose to
// avoid emitting atomic operations and lower stores to a plain
// load-modify-write sequence when it is known there is no thread contention.
patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
}

void vector::populateVectorNarrowTypeRewritePatterns(
@@ -0,0 +1,123 @@
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8 atomic-store=false" --cse --split-input-file %s | FileCheck %s

// TODO: remove memref.alloc() in the tests to eliminate noise.
// memref.alloc exists here because sub-byte vector data types such as i2
// are currently not supported as input arguments.

func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
%0 = memref.alloc() : memref<3x3xi2>
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
vector.store %arg0, %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
return
}
// In this example, two RMW stores are emitted with no full-width store.
// Bits [12, 18) are stored, spanning bytes [1, 2] of the 3-byte memref;
// both bytes are accessed only partially.

// CHECK: func @vector_store_i2_const_index_two_partial_stores(
// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
// CHECK: %[[C1:.+]] = arith.constant 1 : index

// Part 1 RMW sequence
// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]>
// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
// CHECK: %[[LOAD:.+]] = vector.load
// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]

// Part 2 RMW sequence
// CHECK: %[[OFFSET:.+]] = arith.addi %[[C1]], %[[C1]] : index
// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
// CHECK: %[[LOAD2:.+]] = vector.load
// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]] : vector<1xi8> to vector<4xi2>
// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST2]]
// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]]
// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[OFFSET]]]


// -----

func.func @vector_store_i2_two_partial_one_full_stores(%arg0: vector<7xi2>) {
%0 = memref.alloc() : memref<3x7xi2>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
vector.store %arg0, %0[%c1, %c0] : memref<3x7xi2>, vector<7xi2>
return
}

// In this example, two RMW stores and one full-width store are emitted.

// CHECK: func @vector_store_i2_two_partial_one_full_stores(
// CHECK-SAME: %[[ARG0:.+]]:
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]>
// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]}
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
// CHECK-SAME: {offsets = [3], strides = [1]}
// First sub-width RMW:
// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C1]]]
// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]

// Full-width store:
// CHECK: %[[INDEX:.+]] = arith.addi %[[C1]], %[[C1]]
// CHECK: %[[EXTRACT1:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]}
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EXTRACT1]]
// CHECK: vector.store %[[BITCAST]], %[[ALLOC]][%[[INDEX]]]

// Second sub-width RMW:
// CHECK: %[[INDEX2:.+]] = arith.addi %[[INDEX]], %[[C1]]
// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]}
// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]]
// CHECK-SAME: {offsets = [0], strides = [1]}
// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]>
// CHECK: %[[LOAD1:.+]] = vector.load %[[ALLOC]][%[[INDEX2]]]
// CHECK: %[[UPCAST1:.+]] = vector.bitcast %[[LOAD1]]
// CHECK: %[[SELECT1:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST1]]
// CHECK: %[[DOWNCAST1:.+]] = vector.bitcast %[[SELECT1]]
// CHECK: vector.store %[[DOWNCAST1]], %[[ALLOC]][%[[INDEX2]]]

// -----

func.func @vector_store_i2_one_partial_store(%arg0: vector<1xi2>) {
%0 = memref.alloc() : memref<4x1xi2>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
vector.store %arg0, %0[%c1, %c0] : memref<4x1xi2>, vector<1xi2>
return
}

// In this test, only a partial RMW store is emitted, as the store fits within one byte.

// CHECK: func @vector_store_i2_one_partial_store(
// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]>
// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C0]]] : memref<1xi8>, vector<1xi8>
// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C0]]]
8 changes: 7 additions & 1 deletion mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -99,7 +99,8 @@ struct TestEmulateNarrowTypePass

arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
atomicStore);

if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
@@ -118,6 +119,11 @@ struct TestEmulateNarrowTypePass
*this, "skip-memref-type-conversion",
llvm::cl::desc("disable memref type conversion (to test failures)"),
llvm::cl::init(false)};

Option<bool> atomicStore{
*this, "atomic-store",
llvm::cl::desc("use atomic store instead of load-modify-write"),
llvm::cl::init(true)};
};
} // namespace
