Skip to content

Commit 47687f1

Browse files
author
Xu, Xiaohui1
committed
add lower to vector pass in pipeline
1 parent a1f9988 commit 47687f1

File tree

7 files changed

+186
-188
lines changed

7 files changed

+186
-188
lines changed

include/gc/Analysis/VectorBasedFusionAnalysis.h

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "gc/Dialect/Linalgx/LinalgxOps.h"
1313
#include "gc/Dialect/Microkernel/MicrokernelOps.h"
1414
#include "gc/Transforms/Passes.h"
15+
#include "gc/Transforms/Utils/VectorUtils.h"
1516
#include "mlir/Dialect/Func/IR/FuncOps.h"
1617
#include "mlir/Dialect/SCF/IR/SCF.h"
1718
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -25,11 +26,6 @@
2526
namespace mlir {
2627
namespace gc {
2728

28-
mlir::FailureOr<VectorType> getOperationVectorType(Operation *op,
29-
bool isPrevOp = true);
30-
int getNearestVectorStep(const int step);
31-
mlir::FailureOr<VectorType> getOperationMaxVectorType(Operation *op);
32-
3329
/// record hardware information
3430
struct HardWareInfo {
3531
bool favx512f = true;
@@ -136,7 +132,7 @@ class GroupOperationFusion : public VectorFusionBase {
136132
: VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo()),
137133
opGroups(strategy.opGroups), groupMaxSteps(strategy.groupMaxSteps),
138134
opGroupIndexMap(strategy.opGroupIndexMap),
139-
opAnchorPos(strategy.opAnchorPos) {};
135+
opAnchorPos(strategy.opAnchorPos){};
140136

141137
GroupOperationFusion(GroupOperationFusion &&strategy)
142138
: VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo()),
@@ -145,7 +141,7 @@ class GroupOperationFusion : public VectorFusionBase {
145141
groupBigestRankVectorType(
146142
std::move(strategy.getGroupBiggestRankVectorType())),
147143
opGroupIndexMap(std::move(strategy.opGroupIndexMap)),
148-
opAnchorPos(std::move(strategy.opAnchorPos)) {};
144+
opAnchorPos(std::move(strategy.opAnchorPos)){};
149145

150146
GroupOperationFusion &operator=(GroupOperationFusion &fusion) {
151147
this->getOpGroups() = fusion.getOpGroups();

include/gc/Transforms/Utils/VectorUtils.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,44 @@
1111
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1212
#include "mlir/IR/BuiltinTypes.h"
1313
#include "mlir/IR/TypeUtilities.h"
14+
#include "llvm/ADT/TypeSwitch.h"
1415
#include <limits>
1516
#include <stdint.h>
1617
#include <variant>
1718

1819
namespace mlir {
1920
namespace gc {
21+
/// build a constant operation of index type
22+
Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc,
23+
int64_t x);
24+
25+
/// find the original tensor
26+
Value findOriginalTensor(Value writeTensor, Block *block);
27+
/// get operation read or write tensor
28+
mlir::FailureOr<Value> getOperationOperateTensor(Operation *op);
29+
30+
/// set correct operand for the operation
31+
void setOperationCorrectOperand(
32+
Operation *op, ValueRange iterArgs, DenseMap<Value, int> &operandIdxMap,
33+
DenseMap<Value, Value> &originalOperandLoopArgsMap,
34+
ArrayRef<Value> inductionVars,
35+
DenseMap<Operation *, AffineMap> &opPermuationMap);
36+
37+
/// Get vector type of the operation \param op
38+
/// \param isPrevOp whether the operation is a previous operation, if it is not
39+
/// prev-op, may need to use result vectortype
40+
/// default will return the opeation result type
41+
mlir::FailureOr<VectorType> getOperationVectorType(Operation *op,
42+
bool isPrevOp = true);
43+
44+
/// select nearest even step
45+
int getNearestVectorStep(const int step);
46+
47+
/// get operation vector type
48+
/// \param isPrevOp whether the operation is a previous operation, if it is not
49+
/// prev-op, may need to use result vectortype
50+
/// default will return the opeation result type
51+
mlir::FailureOr<VectorType> getOperationMaxVectorType(Operation *op);
2052
union Float32Bits {
2153
uint32_t u;
2254
float f;

lib/gc/Analysis/VectorBasedFusionAnalysis.cpp

Lines changed: 0 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -277,130 +277,6 @@ bool hasDataDependency(Operation *op1, Operation *op2) {
277277
return res;
278278
}
279279

280-
/// Get vector type of the operation \param op
281-
/// \param isPrevOp whether the operation is a previous operation, if it is not
282-
/// prev-op, may need to use result vectortype
283-
/// default will return the opeation result type
284-
mlir::FailureOr<VectorType> getOperationVectorType(Operation *op,
285-
bool isPrevOp) {
286-
if (not op)
287-
return failure();
288-
289-
auto isDynamicType = [](VectorType &type) { return !type.hasStaticShape(); };
290-
auto ret =
291-
TypeSwitch<Operation *, mlir::FailureOr<VectorType>>(op)
292-
.Case<vector::TransferWriteOp>(
293-
[&](vector::TransferWriteOp transferWriteOp)
294-
-> mlir::FailureOr<VectorType> {
295-
if (auto retType = dyn_cast<VectorType>(
296-
transferWriteOp.getOperandTypes()[0]))
297-
return retType;
298-
299-
return failure();
300-
})
301-
.Case<vector::TransferReadOp>(
302-
[&](vector::TransferReadOp transferReadOp)
303-
-> mlir::FailureOr<VectorType> {
304-
return transferReadOp.getVectorType();
305-
})
306-
.Case<vector::MultiDimReductionOp>(
307-
[&](vector::MultiDimReductionOp multiReductionOp) {
308-
if (isPrevOp)
309-
return cast<VectorType>(
310-
multiReductionOp->getResultTypes()[0]);
311-
312-
// TODO: may need to add accumulate value vectortype
313-
return cast<VectorType>(multiReductionOp.getSourceVectorType());
314-
})
315-
.Default([&](Operation *op) -> mlir::FailureOr<VectorType> {
316-
if (isPrevOp) {
317-
if (op->getResultTypes().empty())
318-
return failure();
319-
320-
if (auto shapedType =
321-
dyn_cast<VectorType>(op->getResultTypes()[0]))
322-
return shapedType;
323-
324-
return failure();
325-
}
326-
if (op->getOperandTypes().empty())
327-
return failure();
328-
329-
if (auto shapedType =
330-
dyn_cast<VectorType>(op->getOperandTypes()[0]))
331-
return shapedType;
332-
333-
return failure();
334-
});
335-
if (!failed(ret) and isDynamicType(ret.value()))
336-
return failure();
337-
338-
return ret;
339-
}
340-
341-
/// get operation vector type
342-
/// \param isPrevOp whether the operation is a previous operation, if it is not
343-
/// prev-op, may need to use result vectortype
344-
/// default will return the opeation result type
345-
mlir::FailureOr<VectorType> getOperationMaxVectorType(Operation *op) {
346-
if (not op)
347-
return failure();
348-
349-
auto isDynamicType = [](VectorType &type) { return !type.hasStaticShape(); };
350-
auto ret =
351-
TypeSwitch<Operation *, mlir::FailureOr<VectorType>>(op)
352-
.Case<vector::TransferWriteOp>(
353-
[&](vector::TransferWriteOp transferWriteOp)
354-
-> mlir::FailureOr<VectorType> {
355-
if (auto retType =
356-
cast<VectorType>(transferWriteOp.getOperandTypes()[0]))
357-
return retType;
358-
return failure();
359-
})
360-
.Case<vector::TransferReadOp>(
361-
[&](vector::TransferReadOp transferReadOp)
362-
-> mlir::FailureOr<VectorType> {
363-
return transferReadOp.getVectorType();
364-
})
365-
.Case<vector::MultiDimReductionOp>(
366-
[&](vector::MultiDimReductionOp multiReductionOp) {
367-
return cast<VectorType>(multiReductionOp.getSourceVectorType());
368-
})
369-
.Default([&](Operation *op) -> mlir::FailureOr<VectorType> {
370-
if (op->getResultTypes().empty() and op->getOperandTypes().empty())
371-
return failure();
372-
373-
if (op->getResultTypes().empty())
374-
return cast<VectorType>(op->getOperandTypes()[0]);
375-
376-
if (op->getOperandTypes().empty())
377-
return cast<VectorType>(op->getResultTypes()[0]);
378-
379-
auto opdType = cast<VectorType>(op->getOperandTypes()[0]);
380-
auto retType = cast<VectorType>(op->getResultTypes()[0]);
381-
return opdType.getRank() > retType.getRank() ? opdType : retType;
382-
});
383-
if (!failed(ret) and isDynamicType(ret.value()))
384-
return failure();
385-
386-
return ret;
387-
}
388-
389-
/// select nearest even step
390-
int getNearestVectorStep(const int step) {
391-
if (step <= 0)
392-
llvm_unreachable("Wrong step.");
393-
394-
int nbits = 0, n = step;
395-
while (n) {
396-
n = n >> 1;
397-
nbits++;
398-
}
399-
if (nbits > 6 and (nbits != 7 or step != 64))
400-
llvm_unreachable("wrong nbits appear");
401-
return (1 << (nbits - 1)) == step ? step : (1 << nbits);
402-
}
403-
404280
/// Get the operation which is not a read-write in current queue
405281
/// \param [in, out] op
406282
Operation *getNotReadWriteOperaiton(std::queue<Operation *> &tmpQ) {

lib/gc/Transforms/CPUPhysicalRegisterPass.cpp

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -85,18 +85,6 @@ static inline void moveOpBeginingOfBlock(Operation *op) {
8585
op->moveBefore(&block->front());
8686
}
8787

88-
/// find the original tensor
89-
Value findOriginalTensor(Value writeTensor, Block *block) {
90-
while (auto wtOp = dyn_cast_or_null<vector::TransferWriteOp>(
91-
writeTensor.getDefiningOp())) {
92-
if (block != writeTensor.getDefiningOp()->getBlock())
93-
break;
94-
95-
writeTensor = wtOp->getOperand(1);
96-
}
97-
return writeTensor;
98-
}
99-
10088
/// whether operation is a not support operation
10189
bool isNotSupportOperation(Operation *op) {
10290
return isa<vector::MaskOp, vector::ConstantMaskOp, vector::MaskedLoadOp,
@@ -517,13 +505,6 @@ void classifyAccRelatedOps(std::queue<Operation *> &accRelatedOps,
517505
}
518506
}
519507

520-
Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc,
521-
int64_t x) {
522-
return opBuilder.create<arith::ConstantOp>(
523-
loc, opBuilder.getIndexType(),
524-
opBuilder.getIntegerAttr(opBuilder.getIndexType(), x));
525-
}
526-
527508
void ForLoopGenerator::moveOperationsToCurrentForBody(
528509
const OpBuilder &b, std::queue<Operation *> &opQueue,
529510
GenerateLoopHelper &loopHelperParam) {
@@ -2424,25 +2405,6 @@ void ForLoopGenerator::rewriteOperationAsVectorize(
24242405
}
24252406
}
24262407

2427-
mlir::FailureOr<Value> getOperationOperateTensor(Operation *op) {
2428-
return TypeSwitch<Operation *, mlir::FailureOr<Value>>(op)
2429-
.Case<vector::TransferWriteOp>(
2430-
[&](vector::TransferWriteOp transferWriteOp) {
2431-
// find original tensor.empty operation
2432-
auto writeTensor = transferWriteOp->getOperand(1);
2433-
writeTensor =
2434-
findOriginalTensor(writeTensor, transferWriteOp->getBlock());
2435-
return writeTensor;
2436-
})
2437-
.Case<vector::TransferReadOp>([&](vector::TransferReadOp transferReadOp) {
2438-
return transferReadOp->getOperand(0);
2439-
})
2440-
.Default([&](Operation *op) {
2441-
LDBG("Try to get not DPS operation inits: " << *op << "\n");
2442-
return failure();
2443-
});
2444-
}
2445-
24462408
void GroupOperationFusionImpl::removeOpInCurrentGroups(size_t grpIdx,
24472409
Operation *op,
24482410
Operation *replacedOp) {

lib/gc/Transforms/Pipeline.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ void populateCPURuntimePasses(mlir::OpPassManager &pm) {
150150
}
151151

152152
void populateLoweringToLLVMPasses(mlir::OpPassManager &pm) {
153+
pm.addPass(createConvertVectorToSCFPass());
154+
pm.addPass(createConvertVectorToLLVMPass());
153155
pm.addPass(createLowerAffinePass());
154156
pm.addPass(createFinalizeMemRefToLLVMConversionPass());
155157
pm.addPass(createConvertSCFToCFPass());

lib/gc/Transforms/TilingVector.hpp

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
#include "gc/Analysis//VectorBasedFusionAnalysis.h"
1212
#include "gc/Analysis/TargetDescriptionAnalysis.h"
13-
#include "gc/Transforms/Utils/VectorUtils.h"
1413
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1514
#include "mlir/Dialect/Arith/IR/Arith.h"
1615
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -21,7 +20,6 @@
2120
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
2221
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
2322
#include "mlir/Dialect/Vector/Transforms/Passes.h"
24-
#include "mlir/ExecutionEngine/Float16bits.h"
2523
#include "mlir/IR/Visitors.h"
2624
#include "mlir/Transforms/CSE.h"
2725
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -30,23 +28,6 @@
3028
namespace mlir {
3129
namespace gc {
3230

33-
//===----------------------------------------------------------------------===//
34-
// helper function
35-
//===----------------------------------------------------------------------===//
36-
37-
/// build a constant operation of index type
38-
Value makeIndexArithConstantOp(OpBuilder &opBuilder, Location &loc, int64_t x);
39-
40-
/// get operation read or write tensor
41-
mlir::FailureOr<Value> getOperationOperateTensor(Operation *op);
42-
43-
/// set correct operand for the operation
44-
void setOperationCorrectOperand(
45-
Operation *op, ValueRange iterArgs, DenseMap<Value, int> &operandIdxMap,
46-
DenseMap<Value, Value> &originalOperandLoopArgsMap,
47-
ArrayRef<Value> inductionVars,
48-
DenseMap<Operation *, AffineMap> &opPermuationMap);
49-
5031
/// get fusion kind
5132
/// Has two kind:
5233
/// 1. OperationGroup:

0 commit comments

Comments
 (0)