Skip to content

Commit cc7b19b

Browse files
committed
Implement VectorLoadOp
1 parent 2c31325 commit cc7b19b

File tree

2 files changed

+143
-26
lines changed

2 files changed

+143
-26
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 90 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
1919
#include "mlir/IR/BuiltinAttributes.h"
2020
#include "mlir/IR/BuiltinTypes.h"
21+
#include "mlir/IR/OpDefinition.h"
2122
#include "mlir/IR/TypeUtilities.h"
2223
#include "mlir/IR/Value.h"
2324
#include "mlir/Transforms/DialectConversion.h"
@@ -149,6 +150,61 @@ static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
149150
dest, offsets, strides);
150151
}
151152

153+
static void dynamicallyExtractElementsToVector(
154+
RewriterBase &rewriter, Location loc, TypedValue<VectorType> srcVec,
155+
Value destVec, OpFoldResult srcOffsetVar, int64_t loopSize) {
156+
/*
157+
// Create affine maps for the lower and upper bounds
158+
AffineMap lowerBoundMap = AffineMap::getConstantMap(0, rewriter.getContext());
159+
AffineMap upperBoundMap =
160+
AffineMap::getConstantMap(loopSize, rewriter.getContext());
161+
162+
auto forLoop = rewriter.create<affine::AffineForOp>(
163+
loc, ValueRange{}, lowerBoundMap, ValueRange{}, upperBoundMap, 1,
164+
ArrayRef<Value>(destVec));
165+
166+
OpBuilder builder =
167+
OpBuilder::atBlockEnd(forLoop.getBody(), rewriter.getListener());
168+
169+
auto iv = forLoop.getInductionVar();
170+
171+
auto loopDestVec = forLoop.getRegionIterArgs()[0];
172+
auto extractLoc = builder.create<arith::AddIOp>(
173+
loc, rewriter.getIndexType(), srcOffsetVar.dyn_cast<Value>(), iv);
174+
auto extractElemOp = builder.create<vector::ExtractElementOp>(
175+
loc, elemType, srcVec, extractLoc);
176+
auto insertElemOp = builder.create<vector::InsertElementOp>(
177+
loc, extractElemOp, loopDestVec, iv);
178+
builder.create<affine::AffineYieldOp>(loc,
179+
ValueRange{insertElemOp->getResult(0)});
180+
return forLoop->getResult(0);
181+
*/
182+
for (int i = 0; i < loopSize; ++i) {
183+
Value extractLoc;
184+
if (i == 0) {
185+
extractLoc = srcOffsetVar.dyn_cast<Value>();
186+
} else {
187+
extractLoc = rewriter.create<arith::AddIOp>(
188+
loc, rewriter.getIndexType(), srcOffsetVar.dyn_cast<Value>(),
189+
rewriter.create<arith::ConstantIndexOp>(loc, i));
190+
}
191+
auto extractOp =
192+
rewriter.create<vector::ExtractOp>(loc, srcVec, extractLoc);
193+
rewriter.create<vector::InsertOp>(loc, extractOp, destVec, i);
194+
}
195+
}
196+
197+
static TypedValue<VectorType>
198+
emulatedVectorLoad(ConversionPatternRewriter &rewriter, Location loc,
199+
Value base, OpFoldResult linearizedIndices, int64_t numBytes,
200+
int64_t scale, Type oldElementType, Type newElementType) {
201+
auto newLoad = rewriter.create<vector::LoadOp>(
202+
loc, VectorType::get(numBytes, newElementType), base,
203+
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
204+
return rewriter.create<vector::BitCastOp>(
205+
loc, VectorType::get(numBytes * scale, oldElementType), newLoad);
206+
};
207+
152208
namespace {
153209

154210
//===----------------------------------------------------------------------===//
@@ -380,26 +436,29 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
380436
? getConstantIntValue(linearizedInfo.intraDataOffset)
381437
: 0;
382438

383-
if (!foldedIntraVectorOffset) {
384-
// unimplemented case for dynamic intra vector offset
385-
return failure();
386-
}
387-
439+
// always load enough elements which can cover the original elements
440+
auto maxintraVectorOffset =
441+
foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
388442
auto numElements =
389-
llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
390-
auto newLoad = rewriter.create<vector::LoadOp>(
391-
loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
392-
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
393-
394-
Value result = rewriter.create<vector::BitCastOp>(
395-
loc, VectorType::get(numElements * scale, oldElementType), newLoad);
443+
llvm::divideCeil(maxintraVectorOffset + origElements, scale);
444+
Value result =
445+
emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
446+
numElements, scale, oldElementType, newElementType);
396447

397-
if (isUnalignedEmulation) {
398-
result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
399-
*foldedIntraVectorOffset, origElements);
448+
if (foldedIntraVectorOffset) {
449+
if (isUnalignedEmulation) {
450+
result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
451+
*foldedIntraVectorOffset, origElements);
452+
}
453+
rewriter.replaceOp(op, result);
454+
} else {
455+
auto resultVector = rewriter.create<arith::ConstantOp>(
456+
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
457+
dynamicallyExtractElementsToVector(
458+
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
459+
linearizedInfo.intraVectorOffset, origElements);
460+
rewriter.replaceOp(op, resultVector);
400461
}
401-
402-
rewriter.replaceOp(op, result);
403462
return success();
404463
}
405464
};
@@ -604,13 +663,10 @@ struct ConvertVectorTransferRead final
604663
? getConstantIntValue(linearizedInfo.intraDataOffset)
605664
: 0;
606665

607-
if (!foldedIntraVectorOffset) {
608-
// unimplemented case for dynamic intra-vector offset
609-
return failure();
610-
}
611-
666+
auto maxIntraVectorOffset =
667+
foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
612668
auto numElements =
613-
llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
669+
llvm::divideCeil(maxIntraVectorOffset + origElements, scale);
614670

615671
auto newRead = rewriter.create<vector::TransferReadOp>(
616672
loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
@@ -621,9 +677,17 @@ struct ConvertVectorTransferRead final
621677
loc, VectorType::get(numElements * scale, oldElementType), newRead);
622678

623679
Value result = bitCast->getResult(0);
624-
if (isUnalignedEmulation) {
625-
result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
626-
*foldedIntraVectorOffset, origElements);
680+
if (foldedIntraVectorOffset) {
681+
if (isUnalignedEmulation) {
682+
result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
683+
*foldedIntraVectorOffset, origElements);
684+
}
685+
} else {
686+
result = rewriter.create<arith::ConstantOp>(
687+
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
688+
dynamicallyExtractElementsToVector(rewriter, loc, bitCast, result,
689+
linearizedInfo.intraVectorOffset,
690+
origElements);
627691
}
628692
rewriter.replaceOp(op, result);
629693

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
2+
3+
// CHECK: #map = affine_map<()[s0, s1] -> ((s0 * 3 + s1) floordiv 4)>
4+
// CHECK: #map1 = affine_map<()[s0, s1] -> ((s0 * 3 + s1) mod 4)>
5+
func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
6+
%0 = memref.alloc() : memref<3x3xi2>
7+
%c0 = arith.constant 0 : index
8+
%c2 = arith.constant 2 : index
9+
%cst = arith.constant dense<0> : vector<3x3xi2>
10+
%1 = vector.load %0[%arg1, %arg2] : memref<3x3xi2>, vector<3xi2>
11+
%2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
12+
return %2 : vector<3x3xi2>
13+
}
14+
15+
// CHECK: func @vector_load_i2
16+
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
17+
// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
18+
// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
19+
// CHECK: %[[EMULATED_LOAD:.+]] = vector.load %alloc[%[[LOADADDR1]]] : memref<3xi8>, vector<2xi8>
20+
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMULATED_LOAD]] : vector<2xi8> to vector<8xi2>
21+
// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<3xi2>
22+
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
23+
// CHECK: %[[C1:.+]] = arith.constant 1 : index
24+
// CHECK: %[[OFFSET:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
25+
// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]]] : i2 from vector<8xi2>
26+
// CHECK: %[[C2:.+]] = arith.constant 2 : index
27+
// CHECK: %[[OFFSET2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
28+
// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]]] : i2 from vector<8xi2>
29+
30+
// -----
31+
32+
func.func @vector_transfer_read_i2(%arg1: index, %arg2: index) -> vector<3xi2> {
33+
%0 = memref.alloc() : memref<3x3xi2>
34+
%c0i2 = arith.constant 0 : i2
35+
%1 = vector.transfer_read %0[%arg1, %arg2], %c0i2 {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
36+
return %1 : vector<3xi2>
37+
}
38+
39+
// CHECK: func @vector_transfer_read_i2
40+
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
41+
// CHECK: %[[C0:.+]] = arith.extui %c0_i2 : i2 to i8
42+
// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
43+
// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
44+
// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[LOADADDR1]]], %[[C0]] : memref<3xi8>, vector<2xi8>
45+
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
46+
// CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<3xi2>
47+
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
48+
// CHECK: %[[C1:.+]] = arith.constant 1 : index
49+
// CHECK: %[[ADDI:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
50+
// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]]] : i2 from vector<8xi2>
51+
// CHECK: %[[C2:.+]] = arith.constant 2 : index
52+
// CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
53+
// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>

0 commit comments

Comments
 (0)