
Commit 9b6e6c8

Author: Xu, Xiaohui1 (committed)
fix reduce loop indice
1 parent 4c901c4 commit 9b6e6c8

File tree

5 files changed, +85 -27 lines changed


include/gc/Analysis/VectorBasedFusionAnalysis.h

Lines changed: 5 additions & 1 deletion

@@ -48,7 +48,11 @@ class TypeHelper {
   int getDataTypeValidSteps(VectorType type);
   /// get vector \param type an even for loop step
   int generateValidSteps(int steps, VectorType type);
-  /// get vector \param type max simd length according to hardware information
+  /// get vector \param type an even for loop step when shape dimension is
+  /// shapeDim
+  int generateValidSteps(int steps, VectorType type, int shapeDim);
+  /// get vector \param type max simd length according to hardware
+  /// information
   int getDataTypeMAXSIMDLength(VectorType type);
   /// get operation's vector type
   VectorType getVectorzedType(Operation *op, uint32_t loopStep = 0);

lib/gc/Analysis/VectorBasedFusionAnalysis.cpp

Lines changed: 12 additions & 2 deletions

@@ -374,14 +374,24 @@ VectorType TypeHelper::getVectorzedType(Operation *op, uint32_t loopStep) {
   return VectorType::get({loopStep}, vectorizedType.getElementType());
 }
 
+int TypeHelper::generateValidSteps(int steps, VectorType type, int shapeDim) {
+  if (shapeDim & 1)
+    return 1;
+  auto typebits = type.getElementTypeBitWidth();
+  if (shapeDim >= steps)
+    return steps * typebits >= 128 ? steps : 1;
+  int evenStep = getNearestVectorStep(shapeDim);
+  return evenStep * typebits >= 128 ? evenStep : 1;
+}
+
 int TypeHelper::generateValidSteps(int steps, VectorType type) {
   // TODO: support odd shape using mask load store
   if (type.getShape().back() & 1)
     return 1;
+  auto typebits = type.getElementTypeBitWidth();
   if (type.getShape().back() >= steps)
-    return steps;
+    return steps * typebits >= 128 ? steps : 1;
   int evenStep = getNearestVectorStep(type.getShape().back());
-  auto typebits = type.getElementTypeBitWidth();
   return evenStep * typebits >= 128 ? evenStep : 1;
 }
 
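The new three-argument overload only keeps a vector step when that step fills at least a 128-bit register for the given element width; otherwise it falls back to a scalar step of 1, and odd dimensions always fall back to 1. As a reading aid, here is a minimal standalone C++ sketch of that selection logic, with the element bit width passed as a plain integer and nearestVectorStep as an assumed stand-in for the project's getNearestVectorStep helper, not the actual implementation:

#include <cstdio>

// Assumed stand-in for getNearestVectorStep: round an even dimension down to
// a power-of-two step. The real helper lives in the analysis sources.
static int nearestVectorStep(int dim) {
  int step = 1;
  while (step * 2 <= dim)
    step *= 2;
  return step;
}

// Mirrors the logic of the new generateValidSteps(steps, type, shapeDim)
// overload, with the element bit width passed directly instead of a
// VectorType.
static int generateValidSteps(int steps, int elemBits, int shapeDim) {
  if (shapeDim & 1)
    return 1;                                   // odd dims: scalar loop
  if (shapeDim >= steps)
    return steps * elemBits >= 128 ? steps : 1; // keep step only if >= 128 bits
  int evenStep = nearestVectorStep(shapeDim);
  return evenStep * elemBits >= 128 ? evenStep : 1;
}

int main() {
  // A 64-wide f32 dimension with a requested step of 16 keeps the full step
  // (16 * 32 = 512 bits), matching the step-16 inner loop in the updated test.
  printf("%d\n", generateValidSteps(16, 32, 64)); // 16
  // A small even dimension of 8 keeps step 8 (8 * 32 = 256 bits).
  printf("%d\n", generateValidSteps(16, 32, 8)); // 8
  // An odd dimension falls back to a scalar loop.
  printf("%d\n", generateValidSteps(16, 32, 63)); // 1
  return 0;
}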

lib/gc/Transforms/CPUPhysicalRegisterPass.cpp

Lines changed: 61 additions & 22 deletions

@@ -545,12 +545,6 @@ void updateCurrentArgsStatus(ValueRange loopState, const size_t loopStateIdx,
                              DenseMap<Value, Value> &nextOriginalOperandMap,
                              DenseMap<Value, Value> &nextOperandOriginalMap) {
   Value currentArgs = loopState[loopStateIdx];
-  if (currentArgs.getType() != originalValue.getType()) {
-    llvm::outs() << loopStateIdx << ","
-                 << "\n";
-    currentArgs.dump();
-    llvm::llvm_unreachable_internal("Type not equal.");
-  }
   nextAnchorArgs.emplace_back(currentArgs);
   nextAnchorArgsIdxMap[currentArgs] = nextAnchorArgs.size() - 1;
   nextOriginalOperandMap[originalValue] = currentArgs;
@@ -740,6 +734,36 @@ void updateLoopArgsData(Value val, Value originalVal,
   originalOperandLoopArgsMap[originalVal] = val;
 }
 
+void LoopGeneratorImpl::rectifyParallelIndice(
+    GenerateLoopHelper &loopHelperParam, OpBuilder &b, Location loc) {
+  MultiReductionCanonicalizer rdCanonicalizer =
+      getMultiRdCanonicalizers()[loopHelperParam.groupIdx];
+  auto &multireductionOp = rdCanonicalizer.getCandidateOps()[0];
+  SmallVector<int64_t, 4> &reductionAxis = rdCanonicalizer.getReductionAxis();
+
+  // rectify indice of read from source operand
+  auto sourceReadOp =
+      multireductionOp.getSource().getDefiningOp<vector::TransferReadOp>();
+  if (!sourceReadOp)
+    return;
+
+  AffineExpr outterParallel, innerParallel;
+  bindDims(multireductionOp->getContext(), outterParallel, innerParallel);
+
+  Value op =
+      loopHelperParam.inductionVars[loopHelperParam.inductionVars.size() -
+                                    reductionAxis.size() - 2];
+  Value ip =
+      loopHelperParam.inductionVars[loopHelperParam.inductionVars.size() -
+                                    reductionAxis.size() - 1];
+  Value newIndice = b.createOrFold<affine::AffineApplyOp>(
+      loc, (outterParallel + innerParallel), ValueRange{op, ip});
+  int parallelSize = rdCanonicalizer.getParallelAxis().size();
+  int readIndiceOffset =
+      1 + rdCanonicalizer.getParallelAxis()[parallelSize - 1];
+  sourceReadOp->setOperand(readIndiceOffset, newIndice);
+}
+
 scf::ForOp LoopGeneratorImpl::reductionAxisGenerateForLoop(
     OpBuilder &opBuilder, const size_t reductionIdx,
     GenerateLoopHelper &loopHelperParam) {
@@ -755,18 +779,22 @@ scf::ForOp LoopGeneratorImpl::reductionAxisGenerateForLoop(
 
   const auto loc = multireductionOp->getLoc();
   SmallVector<int64_t, 4> &reductionAxis = rdCanonicalizer.getReductionAxis();
-  bool lastDimReduction = rdCanonicalizer.hasLastDimReduction();
   VectorType vectorType = rdCanonicalizer.getSourceType();
-  const int loopStep =
-      getVectorBasedFusion().getGroupMaxSteps()[loopHelperParam.groupIdx];
+  auto tpHelper = fusionStrategy.getTypeHelper();
+
+  int loopStep = tpHelper.generateValidSteps(
+      fusionStrategy.getTypeHelper().getDataTypeMAXSIMDLength(vectorType),
+      vectorType, vectorType.getShape()[reductionAxis[reductionIdx]]);
+  bool isLastDimReduction = rdCanonicalizer.getHasLastDimReduction();
+  loopStep = (reductionIdx == reductionAxis.size() - 1 && isLastDimReduction)
+                 ? loopStep
+                 : 1;
+
   func::FuncOp func = fusionStrategy.getFunction();
   IRRewriter rewriterOfFunc(func);
 
   Value zero = makeIndexArithConstantOp(opBuilder, loc, 0);
-  Value forSteps = makeIndexArithConstantOp(
-      opBuilder, loc,
-      (reductionIdx == reductionAxis.size() - 1 && lastDimReduction) ? loopStep
-                                                                     : 1);
+  Value forSteps = makeIndexArithConstantOp(opBuilder, loc, loopStep);
   Value numIter = makeIndexArithConstantOp(
       opBuilder, loc, vectorType.getShape()[reductionAxis[reductionIdx]]);
   scf::ForOp forOp = opBuilder.create<scf::ForOp>(
@@ -868,9 +896,12 @@ scf::ForOp LoopGeneratorImpl::reductionAxisGenerateForLoop(
         }
 
         rewriteOperationAsVectorize(b, loopHelperParam.groupIdx,
-                                    &movingOperation);
+                                    &movingOperation,
+                                    isLastDimReduction ? loopStep : 0);
         loopHelperParam.loopIterArgs = loopState;
         moveOperationsToCurrentForBody(b, movingOperation, loopHelperParam);
+        if (isLastDimReduction)
+          rectifyParallelIndice(loopHelperParam, b, loc);
         loopHelperParam.movedOps = &movingOperation;
         loopHelperParam.candidateOps = &opQueue;
 
@@ -1058,11 +1089,16 @@ scf::ForOp LoopGeneratorImpl::parallelAxisGenerateForLoop(
           // get accumualte value
           Attribute initValueAttr;
           getReductionInitAttr(multiReductionOp, initValueAttr);
-
+          SmallVector<int64_t, 4> &reductionAxis =
+              rdCanonicalizer.getReductionAxis();
+          TypeHelper tpHelper = fusionStrategy.getTypeHelper();
+          int loopStep = tpHelper.generateValidSteps(
+              tpHelper.getDataTypeMAXSIMDLength(vectorType), vectorType,
+              vectorType.getShape()[reductionAxis[reductionAxis.size() - 1]]);
           auto accVal = b.create<arith::ConstantOp>(
               loc, DenseElementsAttr::get(
                        fusionStrategy.getTypeHelper().getVectorzedType(
-                           multiReductionOp, dimSize),
+                           multiReductionOp, loopStep),
                        {initValueAttr}));
 
           // put accumulte val at first for loop args
@@ -1247,14 +1283,14 @@ void LoopGeneratorImpl::rearrageMultiReductionIR(
   DenseMap<size_t, size_t> varLoopIdxMap;
   VectorType groupVector =
       getVectorBasedFusion().getGroupBiggestRankVectorType()[grpIdx];
-  for (size_t i = 0; i < parallelAxis.size(); i++) {
+  for (size_t i = 0; i < parallelAxis.size(); i++)
     varLoopIdxMap[parallelAxis[i]] = i;
-  }
+
   size_t offset = rdCanonicalizer.hasLastDimReduction() ? 1 : 0;
   for (size_t i = parallelAxis.size() + offset;
-       i < groupVector.getRank() + offset; i++) {
+       i < groupVector.getRank() + offset; i++)
     varLoopIdxMap[reductionAxis[i - parallelAxis.size() - offset]] = i;
-  }
+
   while (!tmpSourceQ.empty()) {
     auto *curOp = tmpSourceQ.front();
     tmpSourceQ.pop();
@@ -2313,7 +2349,8 @@ void ForLoopGenerator::createNewConstantOp(
 
 /// Rewrite the operations in the group to vectorized form.
 void ForLoopGenerator::rewriteOperationAsVectorize(
-    OpBuilder &rewriter, size_t groupId, const std::queue<Operation *> *queue) {
+    OpBuilder &rewriter, size_t groupId, const std::queue<Operation *> *queue,
+    const size_t vectorizeStep) {
   const std::queue<Operation *> groupOps =
       !queue ? getVectorBasedFusion().getOpGroups()[groupId] : *queue;
 
@@ -2322,7 +2359,9 @@ void ForLoopGenerator::rewriteOperationAsVectorize(
   DenseMap<Operation *, AffineMap> &opPermuationMap =
       getVectorBasedFusion().getOpPermuationMap();
   std::queue<Operation *> transformQueue(groupOps);
-  size_t groupSteps = getVectorBasedFusion().getGroupMaxSteps()[groupId];
+  size_t groupSteps = vectorizeStep == 0
+                          ? getVectorBasedFusion().getGroupMaxSteps()[groupId]
+                          : vectorizeStep;
 
   while (!transformQueue.empty()) {
     Operation *op = transformQueue.front();
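The new rectifyParallelIndice hook runs after the moved operations land in the innermost reduction body: it rewrites the multi-reduction's source vector.transfer_read so that its leading parallel index becomes an affine.apply of (d0, d1) -> (d0 + d1) over the two parallel induction variables that sit just before the reduction induction variables (the #map9 check added to the test below). Here is a small standalone C++ sketch of only that index bookkeeping, with strings standing in for the mlir::Value induction variables; the variable names and axis counts are illustrative, taken from the reduce_fuse_test12 shape rather than from the pass itself:

#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>

int main() {
  // Illustrative loop nest for the 16x32x64 reduce-fuse test: two parallel
  // induction variables (outer tile iv, inner iv) followed by two reduction
  // induction variables. Names are placeholders, not the pass's SSA values.
  std::vector<std::string> inductionVars = {"arg2", "arg4", "arg6", "arg8"};
  const std::size_t numReductionAxes = 2; // reduction over the last two dims

  // Same arithmetic as rectifyParallelIndice: the two parallel ivs sit
  // immediately before the reduction ivs in inductionVars.
  const std::size_t n = inductionVars.size();
  const std::string &outerParallel = inductionVars[n - numReductionAxes - 2];
  const std::string &innerParallel = inductionVars[n - numReductionAxes - 1];

  // Their sum replaces the transfer_read index at offset
  // 1 + <last parallel axis>, expressed as affine.apply (d0 + d1).
  printf("affine.apply (d0 + d1)(%s, %s)\n", outerParallel.c_str(),
         innerParallel.c_str());
  return 0;
}

For this loop nest the sketch selects arg2 and arg4, the same pair the updated FileCheck lines feed to affine.apply #map9.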

lib/gc/Transforms/TilingVector.hpp

Lines changed: 4 additions & 1 deletion

@@ -354,7 +354,8 @@ class ForLoopGenerator {
   /// rewrite operation as vectorize IR in current operation group
   void
   rewriteOperationAsVectorize(OpBuilder &rewriter, size_t groupId,
-                              const std::queue<Operation *> *queue = nullptr);
+                              const std::queue<Operation *> *queue = nullptr,
+                              const size_t vectorizeStep = 0);
   /// Reimplementation of writing a tensor from a constant of denseElementattr.
   void createNewConstantOp(Operation *srcOp,
                            vector::TransferWriteOp *transferWriteOp,
@@ -489,6 +490,8 @@ class LoopGeneratorImpl : public ForLoopGenerator {
   scf::ForOp reductionAxisGenerateForLoop(OpBuilder &opBuilder,
                                           const size_t reductionIdx,
                                           GenerateLoopHelper &loopHelperParam);
+  void rectifyParallelIndice(GenerateLoopHelper &loopHelperParam, OpBuilder &b,
+                             Location loc);
   /// reduction operation parallel axis for loop
   scf::ForOp parallelAxisGenerateForLoop(OpBuilder &opBuilder,
                                          GenerateLoopHelper &loopHelperParam);

test/mlir/test/gc/Transforms/cpu-phyaical-register.mlir

Lines changed: 3 additions & 1 deletion

@@ -10,6 +10,7 @@
 // CHECK-DAG: #[[map6:.*]] = affine_map<(d0, d1) -> (d0 floordiv 16 + d1 floordiv 16)>
 // CHECK-DAG: #[[map7:.*]] = affine_map<()[s0, s1] -> (s0 * 32 + s1)>
 // CHECK-DAG: #[[map8:.*]] = affine_map<()[s0, s1] -> (s0 * 16 + s1)>
+// CHECK-DAG: #[[map9:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
 
 
 
@@ -619,7 +620,8 @@ func.func @reduce_fuse_test12(%input: tensor<16x32x64xf32>,
 // CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[READ1]]) -> (vector<16xf32>)
 // CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[CST]]) -> (vector<16xf32>)
 // CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (vector<16xf32>)
-// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}[%[[arg2]], %[[arg6]], %[[arg8]]], {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32>
+// CHECK: %[[APPLY0:.*]] = affine.apply #[[map9]](%[[arg2]], %[[arg4]])
+// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}[%[[APPLY0]], %[[arg6]], %[[arg8]]], {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32>
 // CHECK: %[[ADD0:.*]] = arith.addf %[[READ2]], %[[arg9]] : vector<16xf32>
 // CHECK: %[[REDUCTION:.*]] = vector.reduction <add>, {{.*}} : vector<16xf32> into f32
 // CHECK: %[[INSERT:.*]] = vector.insert %[[REDUCTION]], %[[arg5]] [%[[arg4]]] : f32 into vector<16xf32>
