Skip to content

Commit 919dd11

Browse files
author
Xu, Xiaohui1
committed
update reduction rectify indice code
1 parent 9a62f0b commit 919dd11

File tree

4 files changed

+61
-26
lines changed

4 files changed

+61
-26
lines changed

include/gc/Transforms/Utils/VectorUtils.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/ADT/TypeSwitch.h"
1919
#include "llvm/Support/Debug.h"
2020
#include <limits>
21+
#include <queue>
2122
#include <stdint.h>
#include <unordered_set>
2223
#include <variant>
2324

@@ -151,6 +152,36 @@ T getInitValForReduce(vector::CombiningKind kind, Type t) {
151152
return result;
152153
}
153154

155+
template <typename TARGETOP>
156+
void getSameBlockTargetOp(Operation *op,
157+
std::queue<Operation *> &candidateOps) {
158+
if (isa<TARGETOP>(op)) {
159+
candidateOps.push(op);
160+
return;
161+
}
162+
auto getSameBlockSrcOp = [](Operation *trackSrcOp,
163+
std::queue<Operation *> &trackOps,
164+
std::queue<Operation *> &candidateOps) {
165+
for (Value opd : trackSrcOp->getOperands()) {
166+
if (isa<BlockArgument>(opd) or
167+
opd.getDefiningOp()->getBlock() != trackSrcOp->getBlock())
168+
continue;
169+
if (isa<TARGETOP>(opd.getDefiningOp()))
170+
candidateOps.push(opd.getDefiningOp());
171+
else
172+
trackOps.push(opd.getDefiningOp());
173+
}
174+
};
175+
176+
std::queue<Operation *> trackOps;
177+
getSameBlockSrcOp(op, trackOps, candidateOps);
178+
while (not trackOps.empty()) {
179+
Operation *cadidateOp = trackOps.front();
180+
trackOps.pop();
181+
getSameBlockSrcOp(cadidateOp, trackOps, candidateOps);
182+
}
183+
}
184+
154185
} // namespace gc
155186
} // namespace mlir
156187

lib/gc/Transforms/CPUPhysicalRegisterPass.cpp

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -735,33 +735,37 @@ void updateLoopArgsData(Value val, Value originalVal,
735735
}
736736

737737
void LoopGeneratorImpl::rectifyParallelIndice(
738-
GenerateLoopHelper &loopHelperParam, OpBuilder &b, Location loc) {
738+
GenerateLoopHelper &loopHelperParam, Location loc) {
739739
MultiReductionCanonicalizer rdCanonicalizer =
740740
getMultiRdCanonicalizers()[loopHelperParam.groupIdx];
741741
auto &multireductionOp = rdCanonicalizer.getCandidateOps()[0];
742742
SmallVector<int64_t, 4> &reductionAxis = rdCanonicalizer.getReductionAxis();
743743

744744
// rectify indice of read from source operand
745-
auto sourceReadOp =
746-
multireductionOp.getSource().getDefiningOp<vector::TransferReadOp>();
747-
if (!sourceReadOp)
748-
return;
749-
750-
AffineExpr outterParallel, innerParallel;
751-
bindDims(multireductionOp->getContext(), outterParallel, innerParallel);
752-
753-
Value op =
754-
loopHelperParam.inductionVars[loopHelperParam.inductionVars.size() -
755-
reductionAxis.size() - 2];
756-
Value ip =
757-
loopHelperParam.inductionVars[loopHelperParam.inductionVars.size() -
758-
reductionAxis.size() - 1];
759-
Value newIndice = b.createOrFold<affine::AffineApplyOp>(
760-
loc, (outterParallel + innerParallel), ValueRange{op, ip});
761-
int parallelSize = rdCanonicalizer.getParallelAxis().size();
762-
int readIndiceOffset =
763-
1 + rdCanonicalizer.getParallelAxis()[parallelSize - 1];
764-
sourceReadOp->setOperand(readIndiceOffset, newIndice);
745+
std::queue<Operation *> candidateOps;
746+
getSameBlockTargetOp<vector::TransferReadOp>(
747+
multireductionOp.getSource().getDefiningOp(), candidateOps);
748+
while (not candidateOps.empty()) {
749+
auto sourceReadOp = candidateOps.front();
750+
candidateOps.pop();
751+
IRRewriter rewriter(sourceReadOp);
752+
rewriter.setInsertionPoint(sourceReadOp);
753+
AffineExpr outterParallel, innerParallel;
754+
bindDims(multireductionOp->getContext(), outterParallel, innerParallel);
755+
756+
Value op =
757+
loopHelperParam.inductionVars[loopHelperParam.inductionVars.size() -
758+
reductionAxis.size() - 2];
759+
Value ip =
760+
loopHelperParam.inductionVars[loopHelperParam.inductionVars.size() -
761+
reductionAxis.size() - 1];
762+
Value newIndice = rewriter.createOrFold<affine::AffineApplyOp>(
763+
loc, (outterParallel + innerParallel), ValueRange{op, ip});
764+
int parallelSize = rdCanonicalizer.getParallelAxis().size();
765+
int readIndiceOffset =
766+
1 + rdCanonicalizer.getParallelAxis()[parallelSize - 1];
767+
sourceReadOp->setOperand(readIndiceOffset, newIndice);
768+
}
765769
}
766770

767771
scf::ForOp LoopGeneratorImpl::reductionAxisGenerateForLoop(
@@ -901,7 +905,7 @@ scf::ForOp LoopGeneratorImpl::reductionAxisGenerateForLoop(
901905
loopHelperParam.loopIterArgs = loopState;
902906
moveOperationsToCurrentForBody(b, movingOperation, loopHelperParam);
903907
if (isLastDimReduction)
904-
rectifyParallelIndice(loopHelperParam, b, loc);
908+
rectifyParallelIndice(loopHelperParam, loc);
905909
loopHelperParam.movedOps = &movingOperation;
906910
loopHelperParam.candidateOps = &opQueue;
907911

@@ -2768,7 +2772,7 @@ void GroupOperationFusionImpl::broadcastFromElements(Operation *op,
27682772
DenseElementsAttr::get(dataType, constantOp.getValue()),
27692773
newOperandType);
27702774
if (failed(res))
2771-
llvm::llvm_unreachable_internal("Wrong to create constant op.");
2775+
llvm_unreachable("Wrong to create constant op.");
27722776
removeOpInCurrentGroups(grpIdx, op, res.value().getDefiningOp());
27732777

27742778
} else {

lib/gc/Transforms/TilingVector.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -490,8 +490,7 @@ class LoopGeneratorImpl : public ForLoopGenerator {
490490
scf::ForOp reductionAxisGenerateForLoop(OpBuilder &opBuilder,
491491
const size_t reductionIdx,
492492
GenerateLoopHelper &loopHelperParam);
493-
void rectifyParallelIndice(GenerateLoopHelper &loopHelperParam, OpBuilder &b,
494-
Location loc);
493+
void rectifyParallelIndice(GenerateLoopHelper &loopHelperParam, Location loc);
495494
/// reduction operation parallel axis for loop
496495
scf::ForOp parallelAxisGenerateForLoop(OpBuilder &opBuilder,
497496
GenerateLoopHelper &loopHelperParam);

test/mlir/test/gc/Transforms/cpu-phyaical-register.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,8 @@ func.func @reduce_fusePostOp_test11(%input: tensor<16x32x64xf32>,
577577
// CHECK: %[[READ0:.*]] = vector.transfer_read %[[arg5]][%[[arg2]], %[[arg4]]], %[[CST_0]] {in_bounds = [true]} : tensor<16x32xf32>, vector<16xf32>
578578
// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[READ0]]) -> (vector<16xf32>)
579579
// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg9:.*]] = %[[CST]]) -> (vector<16xf32>)
580-
// CHECK: %[[READ1:.*]] = vector.transfer_read %arg0[%[[arg2]], %[[arg4]], %[[arg8]]], {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32>
580+
// CHECK: %[[APPLY0:.*]] = affine.apply #[[map9]](%[[arg4]], %[[arg6]])
581+
// CHECK: %[[READ1:.*]] = vector.transfer_read %arg0[%[[arg2]], %[[APPLY0]], %[[arg8]]], {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32>
581582
// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ1]] : vector<16xf32>
582583
// CHECK: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[arg9]] : vector<16xf32>
583584
// CHECK: %[[REDUCTION:.*]] = vector.reduction <add>, {{.*}} : vector<16xf32> into f32

0 commit comments

Comments
 (0)