Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
a99e055030f0da61651e808cbb208bb39594cdc0
faf5d747f174cc9d714839f0d3bce1a783eac2ac
2 changes: 1 addition & 1 deletion lib/TPP/Runner/MLIRBench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ LogicalResult MLIRBench::printResult(Operation *kernelCall) {

if (isa<TensorType>(result.getType())) {
result =
builder.create<bufferization::ToMemrefOp>(unkLoc, memrefType, result);
builder.create<bufferization::ToBufferOp>(unkLoc, memrefType, result);
}

auto outBuf = builder.create<memref::AllocOp>(unkLoc, memrefType);
Expand Down
2 changes: 1 addition & 1 deletion lib/TPP/Transforms/BrgemmLinalgTiling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ struct BrgemmLinalgTiling
RewritePatternSet patterns(&getContext());
populateBrgemmLinalgTilingPatterns(patterns, options);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
config.setStrictness(GreedyRewriteStrictness::ExistingOps);

(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
}
Expand Down
2 changes: 1 addition & 1 deletion lib/TPP/Transforms/HoistVectorTransfers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ struct HoistVectorTransfers
RewritePatternSet patterns(&getContext());
populateHoistVectorTransferPatterns(patterns);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
config.setStrictness(GreedyRewriteStrictness::ExistingOps);
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
}
};
Expand Down
2 changes: 1 addition & 1 deletion lib/TPP/Transforms/SplitReductionDim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ struct SplitReductionDim
RewritePatternSet patterns(ctx);
patterns.add<SplitContractionReduction>(ctx, options);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
config.setStrictness(GreedyRewriteStrictness::ExistingOps);
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
}
};
Expand Down
14 changes: 7 additions & 7 deletions lib/TPP/Transforms/VectorContractToAMX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ struct VectorContractToAMXPattern
int64_t M = accType.getDimSize(0);
int64_t N = accType.getDimSize(1);

auto accSubview = accDefiningOp.getSource();
auto accSubview = accDefiningOp.getBase();
Location loc = op.getLoc();
scf::ForOp insertAt =
getOutermostLoopWithIterargAccumulator(ctx.innerForOp, acc);
Expand Down Expand Up @@ -472,14 +472,14 @@ struct VectorContractToAMXPattern
// Update index of LHS matrix subview for batch dimension if corresponding
// loop is needed.
if (iv)
mapping.map(lhsDefiningOp.getSource().getDefiningOp()->getOperand(1),
mapping.map(lhsDefiningOp.getBase().getDefiningOp()->getOperand(1),
iv);
// Update index of LHS matrix subview for K dimension.
mapping.map(
lhsDefiningOp.getSource().getDefiningOp()->getOperand(iv ? 3 : 1),
lhsDefiningOp.getBase().getDefiningOp()->getOperand(iv ? 3 : 1),
innerIv);
auto lhsClone = innerBuilder.clone(
*lhsDefiningOp.getSource().getDefiningOp(), mapping);
*lhsDefiningOp.getBase().getDefiningOp(), mapping);
// Load matrix A tile
SmallVector<Value, 4> aLoadTiles =
createTileLoads(innerBuilder, loc, amxInputTilesOf16x32xBf16Ty,
Expand All @@ -489,14 +489,14 @@ struct VectorContractToAMXPattern
// Update index of LHS matrix subview for batch dimension if corresponding
// loop is needed.
if (iv)
rhsMapping.map(rhsDefiningOp.getSource().getDefiningOp()->getOperand(1),
rhsMapping.map(rhsDefiningOp.getBase().getDefiningOp()->getOperand(1),
iv);
// Update index of LHS matrix subview for K dimension.
rhsMapping.map(
rhsDefiningOp.getSource().getDefiningOp()->getOperand(iv ? 2 : 1),
rhsDefiningOp.getBase().getDefiningOp()->getOperand(iv ? 2 : 1),
innerIv);
auto rhsClone = innerBuilder.clone(
*rhsDefiningOp.getSource().getDefiningOp(), rhsMapping);
*rhsDefiningOp.getBase().getDefiningOp(), rhsMapping);
// Load matrix B tile, vnni factor and N tile size will be collapsed as
// effective tilse size.
SmallVector<Value, 4> bLoadTiles =
Expand Down
16 changes: 8 additions & 8 deletions lib/TPP/Transforms/VectorContractToBF16DotProduct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,13 +402,13 @@ struct BF16DotProductOp : OpRewritePattern<vector::ContractionOp> {
Value ivNewKForOp, ValueRange iterArgsNewKForOp) {
IRMapping mapping;
mapping.map(
vectorReadOpLhs.getSource().getDefiningOp()->getOperand(1),
vectorReadOpLhs.getBase().getDefiningOp()->getOperand(1),
ivNewReductionForOp);
mapping.map(
vectorReadOpLhs.getSource().getDefiningOp()->getOperand(3),
vectorReadOpLhs.getBase().getDefiningOp()->getOperand(3),
ivNewKForOp);
auto lhsClone = rewriterNewKForOp.clone(
*vectorReadOpLhs.getSource().getDefiningOp(), mapping);
*vectorReadOpLhs.getBase().getDefiningOp(), mapping);

// Memory access for A Matrix into <32xbf16>
llvm::SmallVector<Value, 8> vectorA;
Expand Down Expand Up @@ -440,13 +440,13 @@ struct BF16DotProductOp : OpRewritePattern<vector::ContractionOp> {

IRMapping rhsMapping;
rhsMapping.map(
vectorReadOpRhs.getSource().getDefiningOp()->getOperand(1),
ivNewReductionForOp);
vectorReadOpRhs.getBase().getDefiningOp()->getOperand(1),
ivNewReductionForOp);
rhsMapping.map(
vectorReadOpRhs.getSource().getDefiningOp()->getOperand(2),
vectorReadOpRhs.getBase().getDefiningOp()->getOperand(2),
ivNewKForOp);
auto rhsClone = rewriterNewKForOp.clone(
*vectorReadOpRhs.getSource().getDefiningOp(), rhsMapping);
*vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping);

// Memory access for B Matrix into <32xbf16>
llvm::SmallVector<Value, 8> vectorB;
Expand Down Expand Up @@ -516,7 +516,7 @@ struct BF16DotProduct : public impl::BF16DotProductBase<BF16DotProduct> {
RewritePatternSet patterns(&getContext());
populateBF16DotProductPatterns(patterns);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
config.setStrictness(GreedyRewriteStrictness::ExistingOps);
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
}
};
Expand Down
14 changes: 7 additions & 7 deletions lib/TPP/Transforms/VectorContractToFMA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ struct VectorContractToFMAPattern
if (K != 1)
return failure();

auto accSubview = accDefiningOp.getSource();
auto accSubview = accDefiningOp.getBase();
Location loc = op.getLoc();

// Create M different <1xN> subviews.
Expand Down Expand Up @@ -295,13 +295,13 @@ struct VectorContractToFMAPattern
ValueRange innerIterArgs) {
IRMapping mapping;
mapping.map(
lhsDefiningOp.getSource().getDefiningOp()->getOperand(1),
lhsDefiningOp.getBase().getDefiningOp()->getOperand(1),
iv);
mapping.map(
lhsDefiningOp.getSource().getDefiningOp()->getOperand(3),
lhsDefiningOp.getBase().getDefiningOp()->getOperand(3),
innerIv);
auto lhsClone = innerBuilder.clone(
*lhsDefiningOp.getSource().getDefiningOp(), mapping);
*lhsDefiningOp.getBase().getDefiningOp(), mapping);

// Load and broadcast individual elements
SmallVector<Value, 4> broadcasts;
Expand All @@ -319,13 +319,13 @@ struct VectorContractToFMAPattern

IRMapping rhsMapping;
rhsMapping.map(
rhsDefiningOp.getSource().getDefiningOp()->getOperand(1),
rhsDefiningOp.getBase().getDefiningOp()->getOperand(1),
iv);
rhsMapping.map(
rhsDefiningOp.getSource().getDefiningOp()->getOperand(2),
rhsDefiningOp.getBase().getDefiningOp()->getOperand(2),
innerIv);
auto rhsClone = innerBuilder.clone(
*rhsDefiningOp.getSource().getDefiningOp(), rhsMapping);
*rhsDefiningOp.getBase().getDefiningOp(), rhsMapping);
auto rowVec = innerBuilder.create<vector::LoadOp>(
loc, VectorType::get({N}, elementType),
rhsClone->getResult(0), ValueRange{c0, c0, c0});
Expand Down
4 changes: 2 additions & 2 deletions lib/TPP/Transforms/VectorContractToOuterproduct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,8 @@ struct VectorContractToOuterproductPattern
nestedBuilder.getContext());
}

Value lhsTensor = lhsDefiningOp.getSource();
Value rhsTensor = rhsDefiningOp.getSource();
Value lhsTensor = lhsDefiningOp.getBase();
Value rhsTensor = rhsDefiningOp.getBase();
// Read vector slices using TransferReadOp
auto lhsSlice = nestedBuilder.create<vector::TransferReadOp>(
nestedLoc, VectorType::get({M}, lhsType.getElementType()),
Expand Down