Skip to content

Commit 60226eb

Browse files
authored
add aligning transform into blocking pass (#716)
add aligning transform into blocking pass
1 parent e56462c commit 60226eb

29 files changed

+1313
-377
lines changed

include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ class XeGPUTypeConverter;
3838

3939
/// Populate the given list with patterns that rewrite XeTile Ops
4040
void populateXeTileToXeGPUConversionPatterns(XeGPUTypeConverter &converter,
41-
mlir::RewritePatternSet &patterns);
41+
mlir::RewritePatternSet &patterns,
42+
imex::TileUsageAnalysis &analysis);
4243

4344
/// Create a pass to convert the XeTile dialect to the XeGPU dialect.
4445
std::unique_ptr<mlir::OperationPass<mlir::gpu::GPUModuleOp>>

include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ namespace imex {
3838

3939
class XeGPUTypeConverter : public imex::XeTypeConverter {
4040
public:
41-
XeGPUTypeConverter(mlir::MLIRContext &context,
42-
TileUsageAnalysis *analysis = nullptr);
41+
XeGPUTypeConverter(mlir::MLIRContext &context);
4342

4443
std::optional<mlir::LogicalResult>
4544
convertTileType(xetile::TileType tileTy,
@@ -112,13 +111,15 @@ class XeGPUOneToNPatterRewriter : public mlir::PatternRewriter,
112111
};
113112

114113
template <typename SourceOp>
115-
class SgXeTileToXeGPUConversion : public XeConversionPattern {
114+
class SgXeTileToXeGPUConversion
115+
: public XeConversionPattern<TileUsageAnalysis> {
116116
public:
117117
SgXeTileToXeGPUConversion(mlir::MLIRContext *context,
118118
XeGPUTypeConverter &typeConverter,
119+
TileUsageAnalysis &analysis,
119120
mlir::PatternBenefit benefit = 1)
120-
: XeConversionPattern(typeConverter, SourceOp::getOperationName(),
121-
benefit, context) {}
121+
: XeConversionPattern(typeConverter, analysis,
122+
SourceOp::getOperationName(), benefit, context) {}
122123

123124
using RangeT = llvm::ArrayRef<mlir::ValueRange>;
124125
using OpAdaptor = typename SourceOp::template GenericAdaptor<RangeT>;

include/imex/Dialect/XeTile/Transforms/Blocking.h

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- Blocking.h ----- Blocking Pass -------*- C++ -*-===//
1+
//===- Blocking.h ----------- Blocking Pass ---------*- C++ -*------------===//
22
//
33
// Copyright 2024 Intel Corporation
44
// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -32,16 +32,17 @@
3232

3333
namespace imex {
3434

35-
template <typename SourceOp>
36-
class XeTileConversion : public imex::XeConversionPattern {
35+
template <typename SourceOp, typename AnalysisT>
36+
class XeTileConversion : public imex::XeConversionPattern<AnalysisT> {
3737
public:
3838
using OpAdaptor = typename SourceOp::Adaptor;
3939
using OpPatternRewriter = typename mlir::PatternRewriter;
4040

4141
XeTileConversion(mlir::MLIRContext *context, XeTypeConverter &typeConverter,
42-
mlir::PatternBenefit benefit = 1)
43-
: XeConversionPattern(typeConverter, SourceOp::getOperationName(),
44-
benefit, context) {}
42+
AnalysisT &analysis, mlir::PatternBenefit benefit = 1)
43+
: XeConversionPattern<AnalysisT>(typeConverter, analysis,
44+
SourceOp::getOperationName(), benefit,
45+
context) {}
4546

4647
mlir::LogicalResult
4748
matchAndRewrite(mlir::Operation *op,
@@ -70,7 +71,7 @@ class XeTileConversion : public imex::XeConversionPattern {
7071
auto unpackTy = mlir::VectorType::get(unpackShape, srcTy.getElementType());
7172
return rewriter.create<xetile::TileUnpackOp>(
7273
src.getLoc(), unpackTy, src,
73-
mlir::DenseI64ArrayAttr::get(getContext(), innerBlocks));
74+
mlir::DenseI64ArrayAttr::get(src.getContext(), innerBlocks));
7475
}
7576

7677
xetile::TilePackOp addPackOp(mlir::Value src,
@@ -86,7 +87,7 @@ class XeTileConversion : public imex::XeConversionPattern {
8687
auto packTy = mlir::VectorType::get(packShape, srcTy.getElementType());
8788
auto packOp = rewriter.create<xetile::TilePackOp>(
8889
src.getLoc(), packTy, src,
89-
mlir::DenseI64ArrayAttr::get(getContext(), targetBlkSizes));
90+
mlir::DenseI64ArrayAttr::get(src.getContext(), targetBlkSizes));
9091
return packOp;
9192
}
9293

@@ -98,16 +99,16 @@ class XeTileConversion : public imex::XeConversionPattern {
9899
}
99100
};
100101

101-
template <template <typename> class TraitType>
102-
class XeTileTraitConversion : public imex::XeConversionPattern {
102+
template <template <typename> class TraitType, typename AnalysisT>
103+
class XeTileTraitConversion : public imex::XeConversionPattern<AnalysisT> {
103104
public:
104105
XeTileTraitConversion(mlir::MLIRContext *context,
105-
XeTypeConverter &typeConverter,
106+
XeTypeConverter &typeConverter, AnalysisT &analysis,
106107
mlir::PatternBenefit benefit = 1)
107-
: XeConversionPattern(typeConverter, mlir::Pattern::MatchTraitOpTypeTag(),
108-
mlir::TypeID::get<TraitType>(), benefit, context) {}
108+
: XeConversionPattern<AnalysisT>(
109+
typeConverter, analysis, mlir::Pattern::MatchTraitOpTypeTag(),
110+
mlir::TypeID::get<TraitType>(), benefit, context) {}
109111
};
110-
111112
} // namespace imex
112113

113114
#endif

include/imex/Dialect/XeTile/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ std::unique_ptr<mlir::Pass> createXeTileInitDuplicatePass();
3939

4040
std::unique_ptr<mlir::Pass>
4141
createXeTileBlockingPass(const std::string &device = "pvc");
42+
std::unique_ptr<mlir::Pass> createXeTileBlockAligningPass();
4243

4344
///
4445
void populateXeTileInitDuplicatePatterns(imex::XeTypeConverter &converter,

include/imex/Dialect/XeTile/Transforms/Passes.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,25 @@ def XeTileBlocking : Pass<"xetile-blocking", "::mlir::gpu::GPUModuleOp">{
6060
];
6161
}
6262

63+
64+
// TODO: [block-aligning] remove the following code when upstreaming the pass.
65+
// The pass is not supposed to be exposed to users. Temporarily kept in case we
66+
// need to debug it for downstream development.
67+
def XeTileBlockAligning: Pass <"xetile-block-aligning", "::mlir::gpu::GPUModuleOp"> {
68+
let summary = "optimize the performance for mma.";
69+
70+
let description = [{
71+
This transform is to optimize performance by aligning the block size among operators
72+
to reduce in-register data movements. Currently, it mainly focuses on the alignment
73+
between load and MMA operators.
74+
}];
75+
76+
let constructor = "imex::createXeTileBlockAligningPass()";
77+
let dependentDialects = ["imex::xetile::XeTileDialect",
78+
"mlir::arith::ArithDialect",
79+
"mlir::gpu::GPUDialect",
80+
"mlir::memref::MemRefDialect",
81+
"mlir::vector::VectorDialect"];
82+
}
83+
6384
#endif // _XeTile_PASSES_TD_INCLUDED_

include/imex/Utils/XeCommon.h

Lines changed: 132 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -101,18 +101,6 @@ class TileUsageAnalysis {
101101
}); // walk on LoadTileOp
102102
};
103103

104-
uint getUsage(imex::xetile::InitTileOp op) {
105-
if (Usage.count(op))
106-
return Usage[op];
107-
return UsageType::None;
108-
}
109-
110-
uint getUsage(imex::xetile::LoadTileOp op) {
111-
if (Usage.count(op))
112-
return Usage[op];
113-
return UsageType::None;
114-
}
115-
116104
bool isForDPASA(imex::xetile::LoadTileOp op) {
117105
if (Usage.count(op)) {
118106
return Usage[op] & UsageType::DPAS_A;
@@ -200,14 +188,93 @@ class TileUsageAnalysis {
200188
llvm::DenseMap<mlir::Operation *, uint> Usage;
201189
};
202190

191+
// This analysis is used to propagate the inner block size of an operator
192+
// to its uses or users. Current implementation is to propagate the MMA
193+
// size used by an MMA operator to the definition (InitTileOp) for its operands.
194+
// TODO: This analysis can be extended to propagate the block size for other ops
195+
// such that it can be used as a general analysis for other block size
196+
// optimizations.
197+
class PropagateAnalysis {
198+
private:
199+
llvm::DenseMap<mlir::Operation *, mlir::DenseI64ArrayAttr> OpAttrMap;
200+
201+
public:
202+
PropagateAnalysis(mlir::Operation *op) {
203+
op->walk<mlir::WalkOrder::PostOrder>([&](xetile::TileMMAOp op) {
204+
mlir::Operation *operation = op.getOperation();
205+
for (auto value : operation->getOperands()) {
206+
auto packOp = value.getDefiningOp<xetile::TilePackOp>();
207+
if (packOp) {
208+
auto blkSZ = packOp.getInnerBlocksAttr();
209+
propagate(value, blkSZ);
210+
}
211+
}
212+
});
213+
}
214+
215+
bool maybeUpdated(mlir::Operation *op) { return OpAttrMap.count(op); }
216+
217+
mlir::DenseI64ArrayAttr getValue(mlir::Operation *op) {
218+
if (OpAttrMap.count(op))
219+
return OpAttrMap[op];
220+
return {};
221+
}
222+
223+
private:
224+
mlir::Operation *getDefineOrParentOp(mlir::Value value) {
225+
if (llvm::isa<mlir::OpResult>(value))
226+
return value.getDefiningOp();
227+
if (auto arg = llvm::dyn_cast_or_null<mlir::BlockArgument>(value))
228+
return arg.getOwner()->getParentOp();
229+
return nullptr;
230+
};
231+
232+
mlir::Value getOperandForArg(mlir::scf::ForOp &forOp, mlir::Value &value) {
233+
auto arg = llvm::dyn_cast<mlir::BlockArgument>(value);
234+
if (arg && arg.getArgNumber() >= forOp.getNumInductionVars()) {
235+
auto &iterOperand = *forOp.getTiedLoopInit(arg);
236+
auto numCtrlOperands = forOp.getNumControlOperands();
237+
auto operandIdx = iterOperand.getOperandNumber();
238+
return forOp.getInitArgs()[operandIdx - numCtrlOperands];
239+
}
240+
return mlir::Value();
241+
};
242+
243+
void propagate(mlir::Value start, mlir::DenseI64ArrayAttr attr) {
244+
llvm::SmallVector<mlir::Value> queue;
245+
if (bool(start))
246+
queue.push_back(start);
247+
248+
while (queue.size()) {
249+
auto value = queue.pop_back_val();
250+
if (!bool(value))
251+
continue;
252+
253+
auto *op = getDefineOrParentOp(value);
254+
255+
// Stop when a function is reached.
256+
if (!op || llvm::isa<mlir::FunctionOpInterface>(op))
257+
return;
258+
259+
OpAttrMap[op] = attr;
260+
261+
if (auto forOp = llvm::dyn_cast<mlir::scf::ForOp>(op)) {
262+
auto opr = getOperandForArg(forOp, value);
263+
if (bool(opr))
264+
queue.push_back(opr);
265+
} else if (op->getNumOperands() == 1) {
266+
queue.push_back(op->getOperand(0));
267+
}
268+
}
269+
}
270+
};
271+
203272
class XeTypeConverter : public mlir::OneToNTypeConverter {
204273
public:
205-
friend class XeConversionPattern;
274+
// friend class XeConversionPattern;
206275
using mlir::OneToNTypeConverter::convertType;
207276

208-
XeTypeConverter(mlir::MLIRContext &context,
209-
TileUsageAnalysis *analysis = nullptr)
210-
: context(context), usageAnalysis(analysis) {
277+
XeTypeConverter(mlir::MLIRContext &context) : context(context) {
211278
addConversion([&](xetile::TileType tileTy,
212279
llvm::SmallVectorImpl<mlir::Type> &resultTypes)
213280
-> std::optional<mlir::LogicalResult> {
@@ -235,72 +302,95 @@ class XeTypeConverter : public mlir::OneToNTypeConverter {
235302

236303
private:
237304
mlir::MLIRContext &context;
238-
239-
protected:
240-
TileUsageAnalysis *usageAnalysis;
241305
};
242306

243307
// A simple mlir::RewritePattern wrapper with methods for accessing UsageType
308+
template <typename AnalysisT>
244309
class XeConversionPattern : public mlir::RewritePattern {
245310
public:
246311
using mlir::RewritePattern::RewritePattern;
247312

248313
template <typename... Args>
249-
XeConversionPattern(imex::XeTypeConverter &typeConverter, Args &&...args)
314+
XeConversionPattern(imex::XeTypeConverter &typeConverter, AnalysisT &analysis,
315+
Args &&...args)
250316
: mlir::RewritePattern(std::forward<Args>(args)...),
251-
typeConverter(typeConverter) {}
317+
typeConverter(typeConverter), analysis(analysis) {}
252318

253319
virtual mlir::LogicalResult
254320
matchAndRewrite(mlir::Operation *op,
255321
mlir::PatternRewriter &rewriter) const override {
256322
llvm_unreachable("must override matchAndRewrite or a rewrite method");
257323
};
258324

325+
imex::XeTypeConverter &getTypeConverter() const { return typeConverter; }
326+
327+
template <typename ConverterTy>
328+
std::enable_if_t<std::is_base_of<mlir::TypeConverter, ConverterTy>::value,
329+
ConverterTy &>
330+
getTypeConverter() const {
331+
return static_cast<ConverterTy &>(typeConverter);
332+
}
333+
334+
protected:
335+
imex::XeTypeConverter &typeConverter;
336+
AnalysisT &analysis;
337+
338+
template <typename = typename std::enable_if<
339+
std::is_same_v<AnalysisT, PropagateAnalysis>>>
340+
mlir::DenseI64ArrayAttr getValue(mlir::Operation *op) const {
341+
if (op)
342+
return llvm::cast<PropagateAnalysis>(analysis).getValue(op);
343+
return {};
344+
}
345+
346+
template <typename = typename std::enable_if<
347+
std::is_same_v<AnalysisT, TileUsageAnalysis>>>
259348
bool isForDPASA(imex::xetile::LoadTileOp op) const {
260-
return typeConverter.usageAnalysis->isForDPASA(op);
349+
return llvm::cast<TileUsageAnalysis>(analysis).isForDPASA(op);
261350
}
262351

352+
template <typename = typename std::enable_if<
353+
std::is_same_v<AnalysisT, TileUsageAnalysis>>>
263354
bool isForDPASB(imex::xetile::LoadTileOp op) const {
264-
return typeConverter.usageAnalysis->isForDPASB(op);
355+
return llvm::cast<TileUsageAnalysis>(analysis).isForDPASB(op);
265356
}
266357

358+
template <typename = typename std::enable_if<
359+
std::is_same_v<AnalysisT, TileUsageAnalysis>>>
267360
bool isForDPASC(imex::xetile::LoadTileOp op) const {
268-
return typeConverter.usageAnalysis->isForDPASC(op);
361+
return llvm::cast<TileUsageAnalysis>(analysis).isForDPASC(op);
269362
}
270363

364+
template <typename = typename std::enable_if<
365+
std::is_same_v<AnalysisT, TileUsageAnalysis>>>
271366
bool isForLoad(imex::xetile::InitTileOp op) const {
272-
return typeConverter.usageAnalysis->isForLoad(op);
367+
return llvm::cast<TileUsageAnalysis>(analysis).isForLoad(op);
273368
}
274369

370+
template <typename = typename std::enable_if<
371+
std::is_same_v<AnalysisT, TileUsageAnalysis>>>
275372
bool isForStore(imex::xetile::InitTileOp op) const {
276-
return typeConverter.usageAnalysis->isForStore(op);
373+
return llvm::cast<TileUsageAnalysis>(analysis).isForStore(op);
277374
}
278375

376+
template <typename = typename std::enable_if<
377+
std::is_same_v<AnalysisT, TileUsageAnalysis>>>
279378
bool isForPrefetch(imex::xetile::InitTileOp op) const {
280-
return typeConverter.usageAnalysis->isForPrefetch(op);
379+
return llvm::cast<TileUsageAnalysis>(analysis).isForPrefetch(op);
281380
}
282381

382+
template <typename = typename std::enable_if<
383+
std::is_same_v<AnalysisT, TileUsageAnalysis>>>
283384
bool isForLoadAndPrefetch(imex::xetile::InitTileOp op) const {
284-
return typeConverter.usageAnalysis->isForLoadAndPrefetch(op);
385+
return llvm::cast<TileUsageAnalysis>(analysis).isForLoadAndPrefetch(op);
285386
}
286387

388+
template <typename = typename std::enable_if<
389+
std::is_same_v<AnalysisT, TileUsageAnalysis>>>
287390
bool isForLoadAndStore(imex::xetile::InitTileOp op) const {
288-
return typeConverter.usageAnalysis->isForLoadAndStore(op);
289-
}
290-
291-
imex::XeTypeConverter &getTypeConverter() const { return typeConverter; }
292-
293-
template <typename ConverterTy>
294-
std::enable_if_t<std::is_base_of<mlir::TypeConverter, ConverterTy>::value,
295-
ConverterTy &>
296-
getTypeConverter() const {
297-
return static_cast<ConverterTy &>(typeConverter);
391+
return llvm::cast<TileUsageAnalysis>(analysis).isForLoadAndStore(op);
298392
}
299-
300-
protected:
301-
imex::XeTypeConverter &typeConverter;
302393
};
303-
304394
} // namespace imex
305395

306396
#endif

lib/Conversion/XeTileToXeGPU/ArithOpConversion.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,10 @@ bool isLegalArithOp(mlir::Operation *op) {
8181
}
8282

8383
void populateArithOpConversionPatterns(imex::XeGPUTypeConverter &converter,
84-
mlir::RewritePatternSet &patterns) {
85-
patterns.add<SgArithConstantOpPattern>(patterns.getContext(), converter);
84+
mlir::RewritePatternSet &patterns,
85+
TileUsageAnalysis &analysis) {
86+
patterns.add<SgArithConstantOpPattern>(patterns.getContext(), converter,
87+
analysis);
8688
}
8789

8890
} // namespace imex

0 commit comments

Comments
 (0)