Merged
125 changes: 93 additions & 32 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -37,16 +38,17 @@ using namespace mlir;

/// Returns a compressed mask. The mask value is set only if any mask is present
/// in the scale range. E.g., if `scale` equals to 2, and `intraDataOffset`
/// equals to 2, the following mask:
/// equals to 1 (intraDataOffset strictly smaller than scale), the following
/// mask:
///
/// %mask = [1, 1, 1, 0, 0, 0]
/// %mask = [1, 1, 0, 0, 0, 0]
///
/// will first be padded with number of `intraDataOffset` zeros:
/// %mask = [0, 0, 1, 1, 1, 0, 0, 0]
/// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
///
/// then it will return the following new compressed mask:
///
/// %mask = [0, 1, 1, 0]
/// %mask = [1, 1, 0, 0]
static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
Location loc, Value mask,
int origElements, int scale,
@@ -75,9 +77,6 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
shape.back() = numElements;
auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
if (createMaskOp) {
// TODO: handle the case with non-zero intraDataOffset for CreateMaskOp.
if (intraDataOffset != 0)
return failure();
OperandRange maskOperands = createMaskOp.getOperands();
size_t numMaskOperands = maskOperands.size();
AffineExpr s0;
@@ -129,9 +128,17 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
return newMask;
}
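
To make the updated comment concrete, here is a small hypothetical example (not one of the PR's test cases; the function name is made up) for the `vector.constant_mask` path, assuming scale = 4 (four i2 values per byte) and intraDataOffset = 2:

```mlir
func.func @compressed_mask_sketch() -> (vector<5xi1>, vector<2xi1>) {
  // i2-level mask [1, 1, 1, 0, 0]; padded with two leading zeros it becomes
  // [0, 0, 1, 1, 1, 0, 0, 0], and compressing groups of 4 gives [1, 1]:
  // both emulated i8 bytes are touched.
  %mask = vector.constant_mask [3] : vector<5xi1>
  %compressed = vector.constant_mask [2] : vector<2xi1>
  return %mask, %compressed : vector<5xi1>, vector<2xi1>
}
```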

/// A wrapper function for emitting `vector.extract_strided_slice`. The vector
/// has to be of 1-D shape.
Contributor:

[nit] Rather than saying that the vector has to be 1-D, why not say "Extracts 1-D subvector from a 1-D vector"? That way the intent becomes clearer and the requirements are implied ;-)

Just so that you don't have to guess what I had in mind:

/// A wrapper function to extract a 1-D subvector from the 1-D source vector.

Also, could you remind me why we use vector.extract_strided_slice rather than vector.extract?

Btw, none of this is a blocker for this PR. These are nice-to-have improvements, thanks!

@lialan (Member, Author), Nov 4, 2024:

Updated.

If I understand correctly, vector.extract can extract a single element or the whole innermost dimension, but if we want to operate on a certain part of the innermost dimension then we have to use vector.extract_strided_slice?

Too bad the docs are not super formal, so this is how I read it.
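
A minimal sketch (not part of this PR; the function name and constants are illustrative) contrasting the two ops discussed above: vector.extract yields a single element (or a whole inner dimension), whereas vector.extract_strided_slice can take an arbitrary contiguous slice of the innermost dimension, which is what the unaligned case needs.

```mlir
func.func @extract_vs_strided_slice(%v: vector<8xi2>) -> (i2, vector<5xi2>) {
  // A single element at a constant position.
  %elem = vector.extract %v[2] : i2 from vector<8xi2>
  // A contiguous 5-element slice starting at offset 2.
  %slice = vector.extract_strided_slice %v
      {offsets = [2], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
  return %elem, %slice : i2, vector<5xi2>
}
```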

static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
VectorType extractType, Value vector,
int64_t frontOffset, int64_t subvecSize) {
auto vectorType = dyn_cast<VectorType>(vector.getType());
assert(vectorType && "expected vector type");
assert(vectorType.getShape().size() == 1 && "expected 1-D vector type");
assert(extractType.getShape().size() == 1 &&
"extractType must be 1-D vector type");

auto offsets = rewriter.getI64ArrayAttr({frontOffset});
auto sizes = rewriter.getI64ArrayAttr({subvecSize});
auto strides = rewriter.getI64ArrayAttr({1});
@@ -141,14 +148,61 @@ static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
->getResult(0);
}

/// A wrapper function for emitting `vector.insert_strided_slice`. The source
/// and dest vectors must be of 1-D shape.
static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
Value src, Value dest, int64_t offset) {
auto srcType = dyn_cast<VectorType>(src.getType());
assert(srcType && "expected vector type");
assert(srcType.getShape().size() == 1 && "expected 1-D vector type");
auto destType = dyn_cast<VectorType>(dest.getType());
assert(destType && "expected vector type");
assert(destType.getShape().size() == 1 && "expected 1-D vector type");

auto offsets = rewriter.getI64ArrayAttr({offset});
auto strides = rewriter.getI64ArrayAttr({1});
return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
dest, offsets, strides);
}

/// Extracts `numElementsToExtract` elements from `source` into `dest`,
/// starting at the offset specified by `offset`. Use this function when
/// `offset` is not a constant, making it impossible to use
/// vector.extract_strided_slice, as it requires constant offsets.
static Value dynamicallyExtractSubVector(RewriterBase &rewriter, Location loc,
TypedValue<VectorType> source,
Value dest, OpFoldResult offset,
int64_t numElementsToExtract) {
for (int i = 0; i < numElementsToExtract; ++i) {
Value extractLoc =
(i == 0) ? offset.dyn_cast<Value>()
: rewriter.create<arith::AddIOp>(
loc, rewriter.getIndexType(), offset.dyn_cast<Value>(),
rewriter.create<arith::ConstantIndexOp>(loc, i));
auto extractOp =
rewriter.create<vector::ExtractOp>(loc, source, extractLoc);
dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
}
return dest;
}

/// Loads `numElementsToLoad` elements of `newElementType` from `base` at
/// `linearizedIndices`, then bitcasts the result into a vector of
/// `oldElementType`.
static TypedValue<VectorType>
emulatedVectorLoad(ConversionPatternRewriter &rewriter, Location loc,
Value base, OpFoldResult linearizedIndices,
int64_t numElementsToLoad, Type oldElememtType,
Type newElementType) {
auto scale = newElementType.getIntOrFloatBitWidth() /
oldElememtType.getIntOrFloatBitWidth();
auto newLoad = rewriter.create<vector::LoadOp>(
loc, VectorType::get(numElementsToLoad, newElementType), base,
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
return rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numElementsToLoad * scale, oldElememtType), newLoad);
};
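
To tie the two new helpers together, here is a hand-written sketch (assumed, not produced verbatim by the pass; names such as %byteIdx and %off are illustrative) of the IR that emulatedVectorLoad followed by dynamicallyExtractSubVector would emit for a vector<3xi2> load at a dynamic intra-byte offset. With scale = 4 (four i2 values per i8), the worst-case intraDataOffset is 3, so ceil((3 + 3) / 4) = 2 bytes are loaded:

```mlir
func.func @emulated_i2_load_sketch(%base: memref<3xi8>, %byteIdx: index,
                                   %off: index) -> vector<3xi2> {
  %zeros = arith.constant dense<0> : vector<3xi2>
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  // Load enough bytes to cover the original elements, then bitcast to i2.
  %load = vector.load %base[%byteIdx] : memref<3xi8>, vector<2xi8>
  %bits = vector.bitcast %load : vector<2xi8> to vector<8xi2>
  // One vector.extract / vector.insert pair per element, because %off is not
  // a compile-time constant and vector.extract_strided_slice cannot be used.
  %e0 = vector.extract %bits[%off] : i2 from vector<8xi2>
  %r0 = vector.insert %e0, %zeros [0] : i2 into vector<3xi2>
  %off1 = arith.addi %off, %c1 : index
  %e1 = vector.extract %bits[%off1] : i2 from vector<8xi2>
  %r1 = vector.insert %e1, %r0 [1] : i2 into vector<3xi2>
  %off2 = arith.addi %off, %c2 : index
  %e2 = vector.extract %bits[%off2] : i2 from vector<8xi2>
  %r2 = vector.insert %e2, %r1 [2] : i2 into vector<3xi2>
  return %r2 : vector<3xi2>
}
```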

namespace {

//===----------------------------------------------------------------------===//
@@ -380,25 +434,27 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;

if (!foldedIntraVectorOffset) {
// unimplemented case for dynamic intra vector offset
return failure();
}

// Always load enough elements to cover the original elements.
auto maxintraDataOffset =
foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
auto numElements =
llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
auto newLoad = rewriter.create<vector::LoadOp>(
loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));

Value result = rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numElements * scale, oldElementType), newLoad);
llvm::divideCeil(maxintraDataOffset + origElements, scale);
Value result =
emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
numElements, oldElementType, newElementType);

if (isUnalignedEmulation) {
result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
if (foldedIntraVectorOffset) {
if (isUnalignedEmulation) {
result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
}
} else {
auto resultVector = rewriter.create<arith::ConstantOp>(
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
result = dynamicallyExtractSubVector(
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
linearizedInfo.intraDataOffset, origElements);
}

rewriter.replaceOp(op, result);
return success();
}
@@ -604,13 +660,10 @@ struct ConvertVectorTransferRead final
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;

if (!foldedIntraVectorOffset) {
// unimplemented case for dynamic inra-vector offset
return failure();
}

auto maxIntraVectorOffset =
foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
auto numElements =
llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
llvm::divideCeil(maxIntraVectorOffset + origElements, scale);

auto newRead = rewriter.create<vector::TransferReadOp>(
loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
Expand All @@ -621,9 +674,17 @@ struct ConvertVectorTransferRead final
loc, VectorType::get(numElements * scale, oldElementType), newRead);

Value result = bitCast->getResult(0);
if (isUnalignedEmulation) {
result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
if (foldedIntraVectorOffset) {
if (isUnalignedEmulation) {
result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
}
} else {
auto zeros = rewriter.create<arith::ConstantOp>(
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
linearizedInfo.intraDataOffset,
origElements);
}
rewriter.replaceOp(op, result);

104 changes: 80 additions & 24 deletions mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -1,13 +1,20 @@
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s

func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
%0 = memref.alloc() : memref<3x3xi2>
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%cst = arith.constant dense<0> : vector<3x3xi2>
%1 = vector.load %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
%2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
return %2 : vector<3x3xi2>
// TODO: remove memref.alloc() in the tests to eliminate noise.
// memref.alloc exists here because sub-byte vector data types such as i2
// are currently not supported as input arguments.

// CHECK: #map = affine_map<()[s0, s1] -> ((s0 * 3 + s1) floordiv 4)>
// CHECK: #map1 = affine_map<()[s0, s1] -> ((s0 * 3 + s1) mod 4)>

func.func @vector_load_i2() -> vector<3x3xi2> {
%0 = memref.alloc() : memref<3x3xi2>
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%cst = arith.constant dense<0> : vector<3x3xi2>
%1 = vector.load %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
%2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
return %2 : vector<3x3xi2>
}

// CHECK: func @vector_load_i2
@@ -20,12 +27,12 @@ func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
//-----

func.func @vector_transfer_read_i2() -> vector<3xi2> {
%0 = memref.alloc() : memref<3x3xi2>
%c0i2 = arith.constant 0 : i2
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%1 = vector.transfer_read %0[%c2, %c0], %c0i2 {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
return %1 : vector<3xi2>
%0 = memref.alloc() : memref<3x3xi2>
%pad = arith.constant 0 : i2
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%1 = vector.transfer_read %0[%c2, %c0], %pad {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
return %1 : vector<3xi2>
}

// CHECK: func @vector_transfer_read_i2
@@ -38,15 +45,15 @@ func.func @vector_transfer_read_i2() -> vector<3xi2> {
//-----

func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
%0 = memref.alloc() : memref<3x5xi2>
%cst = arith.constant dense<0> : vector<3x5xi2>
%mask = vector.constant_mask [3] : vector<5xi1>
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%1 = vector.maskedload %0[%c2, %c0], %mask, %passthru :
memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
%2 = vector.insert %1, %cst [0] : vector<5xi2> into vector<3x5xi2>
return %2 : vector<3x5xi2>
%0 = memref.alloc() : memref<3x5xi2>
%cst = arith.constant dense<0> : vector<3x5xi2>
%mask = vector.constant_mask [3] : vector<5xi1>
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%1 = vector.maskedload %0[%c2, %c0], %mask, %passthru :
memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
%2 = vector.insert %1, %cst [0] : vector<5xi2> into vector<3x5xi2>
return %2 : vector<3x5xi2>
}

// CHECK: func @vector_cst_maskedload_i2
@@ -64,4 +71,53 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[ORIGINMASK]], %[[CST2]]
// CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi1> into vector<8xi1>
// CHECK: %[[SELECT:.+]] = arith.select %[[INSERT2]], %[[BITCAST2]], %[[INSERT1]] : vector<8xi1>, vector<8xi2>
// CHECK: vector.extract_strided_slice %[[SELECT]] {offsets = [2], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>

//-----

func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
%0 = memref.alloc() : memref<3x3xi2>
%cst = arith.constant dense<0> : vector<3x3xi2>
%1 = vector.load %0[%idx1, %idx2] : memref<3x3xi2>, vector<3xi2>
return %1 : vector<3xi2>
}

// CHECK: func @vector_load_i2_dynamic_indexing
// CHECK: %[[ALLOC:.+]]= memref.alloc() : memref<3xi8>
// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
// CHECK: %[[EMULATED_LOAD:.+]] = vector.load %alloc[%[[LOADADDR1]]] : memref<3xi8>, vector<2xi8>
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMULATED_LOAD]] : vector<2xi8> to vector<8xi2>
// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<3xi2>
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[OFFSET:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]]] : i2 from vector<8xi2>
// CHECK: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[OFFSET2:.+]] = arith.addi %1, %c2 : index
// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]]] : i2 from vector<8xi2>

//-----

func.func @vector_transfer_read_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
%0 = memref.alloc() : memref<3x3xi2>
%pad = arith.constant 0 : i2
%1 = vector.transfer_read %0[%idx1, %idx2], %pad {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
return %1 : vector<3xi2>
}

// CHECK: func @vector_transfer_read_i2_dynamic_indexing
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
// CHECK: %[[C0:.+]] = arith.extui %c0_i2 : i2 to i8
// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[LOADADDR1]]], %[[C0]] : memref<3xi8>, vector<2xi8>
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
// CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<3xi2>
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[ADDI:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]]] : i2 from vector<8xi2>
// CHECK: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>