Skip to content

Commit 94f5522

Browse files
Reworked
1 parent 7d157e7 commit 94f5522

File tree

7 files changed

+322
-104
lines changed

7 files changed

+322
-104
lines changed

include/gc/Transforms/Passes.td

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,15 @@ def GpuTilingAndFusion : Pass<"gpu-tiling", "func::FuncOp"> {
120120
The tiles calculation is based on the Execution Unit cache size and the number of threads per EU.
121121
}];
122122
let options = [
123-
Option<"euMem", "eu-mem", "size_t",
123+
Option<"numThreads", "num-threads", "size_t",
124+
/*default=*/"8",
125+
"Number of threads per Execution Unit.">,
126+
Option<"cacheSize", "cache-size", "size_t",
124127
/*default=*/"131072",
125128
"Execution Unit cache size.">,
126-
Option<"euThreads", "eu-threads", "size_t",
127-
/*default=*/"8",
128-
"Number of threads per EU.">
129+
Option<"vectorWidth", "vector-width", "size_t",
130+
/*default=*/"512",
131+
"The maximum width of EU's vector registers.">
129132
];
130133
}
131134

@@ -136,7 +139,7 @@ def GpuLoopTiling : Pass<"gpu-loop-tiling", "func::FuncOp"> {
136139
Each tile of the outer loop is divided by the number of threads per EU.
137140
}];
138141
let options = [
139-
Option<"euThreads", "eu-threads", "size_t",
142+
Option<"numThreads", "num-threads", "size_t",
140143
/*default=*/"8",
141144
"Number of threads per Execution Unit.">
142145
];

lib/gc/Transforms/GPU/GpuLoopTiling.cpp

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
#include "gc/Utils/Log.h"
2323

2424
using namespace mlir;
25-
// using namespace mlir::gc::gpu;
25+
using namespace mlir::gc;
2626

2727
namespace mlir::gc {
2828
#define GEN_PASS_DECL_GPULOOPTILING
@@ -41,34 +41,29 @@ struct GpuLoopTiling final : GpuPass<GpuLoopTiling>,
4141

4242
void runOnOperation() override {
4343
IRRewriter rewriter(&getContext());
44-
auto euThreads = static_cast<double>(getEuThreads(rewriter));
44+
auto numThreads = getNumThreads(rewriter);
4545
getOperation().walk<WalkOrder::PreOrder>([&](scf::ParallelOp loop) {
4646
if (!loop->getParentOfType<scf::ParallelOp>()) {
47-
tile(loop, euThreads);
47+
SmallVector<int64_t> tiles;
48+
auto steps = loop.getStep();
49+
tiles.reserve(steps.size());
50+
51+
for (auto step : steps) {
52+
if (auto v = getConstIdxValue(step)) {
53+
tiles.push_back(v);
54+
} else {
55+
tiles.push_back(32);
56+
}
57+
}
58+
59+
adjustTiles(numThreads, tiles);
60+
tileParallelLoop(loop, tiles, false);
4861
}
4962
return WalkResult::skip();
5063
});
5164
if (failed(simplifyRegions(rewriter, getOperation()->getRegions()))) {
5265
gcLogD("Failed to simplify regions");
5366
}
5467
}
55-
56-
private:
57-
static void tile(scf::ParallelOp loop, double euThreads) {
58-
SmallVector<int64_t> tileSizes;
59-
auto steps = loop.getStep();
60-
tileSizes.reserve(steps.size());
61-
62-
for (auto step : steps) {
63-
if (auto v = getConstIdxValue(step)) {
64-
tileSizes.push_back(static_cast<int64_t>(
65-
std::ceil(static_cast<double>(v) / euThreads)));
66-
} else {
67-
tileSizes.push_back(32);
68-
}
69-
}
70-
71-
tileParallelLoop(loop, tileSizes, false);
72-
}
7368
};
7469
} // namespace

lib/gc/Transforms/GPU/GpuTilingAndFusion.cpp

Lines changed: 102 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
#include "gc/Utils/Log.h"
2525

2626
using namespace mlir;
27-
// using namespace mlir::gc::gpu;
27+
using namespace mlir::gc;
2828

2929
namespace mlir::gc {
3030
#define GEN_PASS_DECL_GPUTILINGANDFUSION
@@ -39,33 +39,34 @@ struct GpuTilingAndFusion final
3939
gc::impl::GpuTilingAndFusionBase<GpuTilingAndFusion> {
4040
friend struct GpuPass;
4141
explicit GpuTilingAndFusion()
42-
: GpuTilingAndFusion(gc::GpuTilingAndFusionOptions{}) {}
43-
explicit GpuTilingAndFusion(const gc::GpuTilingAndFusionOptions &opts)
42+
: GpuTilingAndFusion(GpuTilingAndFusionOptions{}) {}
43+
explicit GpuTilingAndFusion(const GpuTilingAndFusionOptions &opts)
4444
: GpuPass(), GpuTilingAndFusionBase(opts) {}
4545

4646
void runOnOperation() override {
4747
IRRewriter rewriter(&getContext());
4848
scf::SCFTileAndFuseOptions opts;
49+
opts.setFusionControlFn(
50+
[&](tensor::ExtractSliceOp, OpResult originalProducer, bool)
51+
-> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
52+
Operation *op = originalProducer.getOwner();
53+
if (!op) {
54+
return std::nullopt;
55+
}
56+
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
57+
if (!linalgOp.hasOnlyProjectedPermutations()) {
58+
return std::nullopt;
59+
}
60+
}
61+
return scf::SCFTileAndFuseOptions::ControlFnResult{};
62+
});
4963
opts.tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
5064
// The outer loop is converted to a GPU kernel and the tile sizes are mapped
5165
// to the grid sizes.
5266
opts.tilingOptions.setTileSizeComputationFunction(
53-
// The tile sizes calculation is based on the following equation:
54-
// n * TS0 * TS1 * ... * TSn = euMem
55-
// where:
56-
// n - an average number of bytes, processed by each iteration
57-
// TS0, TS1, ... TSn - the tile sizes for each loop correspondingly
58-
// euMem - the physical memory (cache) size of the GPU execution unit
59-
//
60-
// To calculate the tile size TS, we need to divide the total loop size
61-
// S by the ratio r:
62-
//
63-
// n * (S0/r0) * (S1/r1) * ... * (Sn/rn) = euMem
64-
// r0 * r1 * ... * rn = (n * S0 * S1 * ... * Sn) / euMem
65-
// If all sizes are equal, then S0 = ... = Sn = S, r0 = ... = rn = r:
66-
// r^n = (n * S^n) / euMem
67-
// r = (n * S^n / euMem)^(1/n)
68-
[euMem = getEuMem(rewriter), euThreads = getEuThreads(rewriter)](
67+
[numThreads = getNumThreads(rewriter),
68+
cacheSize = getCacheSize(rewriter),
69+
vectorWidth = getVectorWidth(rewriter)](
6970
OpBuilder &builder, Operation *op) -> SmallVector<OpFoldResult> {
7071
auto ti = dyn_cast<TilingInterface>(op);
7172
if (!ti) {
@@ -76,44 +77,47 @@ struct GpuTilingAndFusion final
7677
auto itDomains = ti.getIterationDomain(builder);
7778
assert(itTypes.size() == itDomains.size());
7879

79-
// TODO: Add a parameter to the options?
80-
size_t totalSize = calcOperandsSize(op) * euThreads;
81-
unsigned loopCount = 0;
82-
80+
SmallVector<int64_t> tiles;
8381
for (auto [t, r] : zip(itTypes, itDomains)) {
8482
if (t == utils::IteratorType::parallel) {
8583
if (auto v = getConstantIntValue(r.size)) {
86-
loopCount++;
87-
totalSize *= *v;
84+
tiles.emplace_back(*v);
8885
} else {
89-
return calcDynamicSizes(builder, ti, euMem, euThreads);
86+
return calcDynamicSizes(builder, ti, cacheSize);
9087
}
9188
}
9289
}
9390

94-
if (loopCount == 0) {
91+
if (tiles.empty()) {
9592
return {};
9693
}
9794

98-
// TODO: In case of different sizes, calculate the ratio for each loop
99-
double ratio = std::pow(static_cast<double>(totalSize) /
100-
static_cast<double>(euMem),
101-
1.0 / loopCount);
102-
ratio = std::max(1.0, ratio);
103-
SmallVector<OpFoldResult> tiles;
104-
tiles.reserve(itDomains.size());
95+
int64_t elementSize = getElementSize(op);
96+
int64_t totalSize;
97+
98+
// If the operation could be lowered to XeGPU, make the tiles
99+
// proportional to the vector width. Otherwise, use the cache size.
100+
if (canLowerToXeGPU(op)) {
101+
totalSize = vectorWidth * numThreads / elementSize;
102+
} else {
103+
totalSize = cacheSize / elementSize;
104+
}
105+
106+
adjustTiles(totalSize, tiles);
107+
108+
unsigned counter = 0;
109+
SmallVector<OpFoldResult> result;
110+
result.reserve(itDomains.size());
105111

106112
for (auto [t, r] : zip(itTypes, itDomains)) {
107113
if (t != utils::IteratorType::parallel) {
108-
tiles.emplace_back(builder.getIndexAttr(1));
109-
} else if (auto v = getConstantIntValue(r.size)) {
110-
tiles.emplace_back(ceil(builder, *v, ratio));
114+
result.emplace_back(builder.getIndexAttr(1));
111115
} else {
112-
abort(); // Must never get here
116+
result.emplace_back(builder.getIndexAttr(tiles[counter++]));
113117
}
114118
}
115119

116-
return tiles;
120+
return result;
117121
});
118122

119123
auto fn = getOperation();
@@ -174,7 +178,8 @@ struct GpuTilingAndFusion final
174178
static std::optional<TilingInterface> findTi(Operation *op) {
175179
std::optional<TilingInterface> last;
176180
op->walk<WalkOrder::PreOrder>([&](linalg::LinalgOp linalgOp) {
177-
if (!linalgOp->getParentOfType<scf::ForallOp>()) {
181+
if (linalgOp.hasOnlyProjectedPermutations() &&
182+
!linalgOp->getParentOfType<scf::ForallOp>()) {
178183
if (auto ti = dyn_cast<TilingInterface>(linalgOp.getOperation())) {
179184
last = ti;
180185
}
@@ -184,17 +189,16 @@ struct GpuTilingAndFusion final
184189
return last;
185190
}
186191

187-
static SmallVector<OpFoldResult> calcDynamicSizes(OpBuilder &builder,
188-
TilingInterface ti,
189-
size_t euMem,
190-
size_t euThreads) {
192+
// TODO: Use the adjustTiles() function from MLIR.
193+
static SmallVector<OpFoldResult>
194+
calcDynamicSizes(OpBuilder &builder, TilingInterface ti, int64_t cacheSize) {
191195
auto itTypes = ti.getLoopIteratorTypes();
192196
auto itDomains = ti.getIterationDomain(builder);
193197
assert(itTypes.size() == itDomains.size());
194198

195199
auto loc = ti.getLoc();
196200
Value dynamicSize;
197-
size_t staticSize = calcOperandsSize(ti.getOperation()) * euThreads;
201+
int64_t staticSize = getElementSize(ti.getOperation());
198202
unsigned loopCount = 0;
199203

200204
for (auto [t, r] : zip(itTypes, itDomains)) {
@@ -225,7 +229,7 @@ struct GpuTilingAndFusion final
225229
dynamicSize));
226230

227231
auto memSize = builder.create<arith::ConstantFloatOp>(
228-
loc, APFloat(static_cast<double>(euMem)), builder.getF64Type());
232+
loc, APFloat(static_cast<double>(cacheSize)), builder.getF64Type());
229233
auto pow = builder.create<arith::ConstantFloatOp>(
230234
loc, APFloat(1.0 / loopCount), builder.getF64Type());
231235
Value ratio = builder.create<math::PowFOp>(
@@ -265,29 +269,63 @@ struct GpuTilingAndFusion final
265269
return tiles;
266270
}
267271

268-
static size_t calcOperandsSize(Operation *op) {
269-
size_t size = 0;
270-
auto typeSize = [](Type t) -> size_t {
271-
Type et;
272-
if (auto mt = dyn_cast<MemRefType>(t)) {
273-
et = mt.getElementType();
274-
} else if (auto tt = dyn_cast<TensorType>(t)) {
275-
et = tt.getElementType();
272+
/// Returns the size, in bytes, of one element of the first DPS init operand
/// of `op`.
///
/// Falls back to 1 when `op` is not a linalg::LinalgOp, when it has no DPS
/// inits, or when the element type is neither an integer nor a float.
///
/// The bit width is rounded up to whole bytes, so the result is always >= 1.
/// The tile size computation in this file divides by the returned value, and
/// a plain `bitWidth / 8` truncates to 0 for sub-byte element types such as
/// i1 or i4, which would lead to an integer division by zero.
static int64_t getElementSize(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp) {
    return 1;
  }

  auto inits = linalgOp.getDpsInits();
  if (inits.empty()) {
    return 1;
  }

  Type elementType = getElementTypeOrSelf(inits[0].getType());
  if (!elementType.isIntOrFloat()) {
    return 1;
  }

  // Ceiling division by 8: identical to `bitWidth / 8` for byte-multiple
  // widths, but never yields 0 for narrower types.
  const auto bitWidth =
      static_cast<int64_t>(elementType.getIntOrFloatBitWidth());
  return (bitWidth + 7) / 8;
}
284+
285+
// TODO: Add more checks
286+
static bool canLowerToXeGPU(Operation *operation) {
287+
auto op = dyn_cast<linalg::LinalgOp>(operation);
288+
if (!op) {
289+
return false;
290+
}
291+
if (op.hasDynamicShape()) {
292+
return false;
293+
}
294+
295+
auto checkOperand = [&](Value operand, bool isOutput = false) {
296+
ShapedType type;
297+
if (auto memref = dyn_cast<MemRefType>(operand.getType())) {
298+
type = memref;
299+
} else if (auto tensor = dyn_cast<RankedTensorType>(operand.getType())) {
300+
type = tensor;
276301
} else {
277-
return 0;
302+
return false;
278303
}
279-
return et.isIntOrFloat() ? et.getIntOrFloatBitWidth() / 8 : 1;
280-
};
281-
for (auto operand : op->getOperands()) {
282-
if (auto defOp = operand.getDefiningOp()) {
283-
for (auto t : defOp->getResultTypes()) {
284-
size += typeSize(t);
304+
305+
auto shape = type.getShape();
306+
if (isOutput) {
307+
if (shape.size() != 2 || shape[0] * shape[1] < 16) {
308+
return false;
285309
}
286-
} else {
287-
size += typeSize(operand.getType());
310+
} else if (shape.size() > 2) {
311+
return false;
288312
}
313+
314+
return true;
315+
};
316+
317+
if (auto inits = op.getDpsInits();
318+
!inits.empty() && !checkOperand(inits[0], true)) {
319+
return false;
289320
}
290-
return size == 0 ? 1 : size;
321+
322+
if (auto inputs = op.getDpsInputs();
323+
!std::all_of(inputs.begin(), inputs.end(),
324+
[&](Value v) { return checkOperand(v); })) {
325+
return false;
326+
}
327+
328+
return true;
291329
}
292330
};
293331
} // namespace

0 commit comments

Comments
 (0)