 #include "gc/Utils/Log.h"
 
 using namespace mlir;
-// using namespace mlir::gc::gpu;
+using namespace mlir::gc;
 
 namespace mlir::gc {
 #define GEN_PASS_DECL_GPUTILINGANDFUSION
@@ -39,33 +39,39 @@ struct GpuTilingAndFusion final
       gc::impl::GpuTilingAndFusionBase<GpuTilingAndFusion> {
   friend struct GpuPass;
   explicit GpuTilingAndFusion()
-      : GpuTilingAndFusion(gc::GpuTilingAndFusionOptions{}) {}
-  explicit GpuTilingAndFusion(const gc::GpuTilingAndFusionOptions &opts)
+      : GpuTilingAndFusion(GpuTilingAndFusionOptions{}) {}
+  explicit GpuTilingAndFusion(const GpuTilingAndFusionOptions &opts)
       : GpuPass(), GpuTilingAndFusionBase(opts) {}
 
   void runOnOperation() override {
     IRRewriter rewriter(&getContext());
     scf::SCFTileAndFuseOptions opts;
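+    // Restrict fusion: a linalg producer is fused only if all of its indexing
+    // maps are projected permutations; non-linalg producers are fused as-is.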
+    opts.setFusionControlFn(
+        [&](tensor::ExtractSliceOp, OpResult originalProducer, bool)
+            -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
+          Operation *op = originalProducer.getOwner();
+          if (!op) {
+            return std::nullopt;
+          }
+          if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+            if (!linalgOp.hasOnlyProjectedPermutations()) {
+              return std::nullopt;
+            }
+          }
+          return scf::SCFTileAndFuseOptions::ControlFnResult{};
+        });
     opts.tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
-    // The outer loop is converted to a GPU kernel and the tile sizes are mapped
-    // to the grid sizes.
+    auto numEus = getNumEus(rewriter);
+    auto numEusPerSlice = getNumEusPerSlice(rewriter);
+    auto numThreadsPerEu = getNumThreadsPerEu(rewriter);
+    auto cacheSize = getCacheSize(rewriter);
+    auto vectorWidth = getVectorWidth(rewriter);
+    auto cachePerThread =
+        std::max(cacheSize / numEusPerSlice / numThreadsPerEu, vectorWidth);
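+    // Rough per-thread share of the cache. With hypothetical values of a
+    // 196608-byte cache, 16 EUs per slice and 8 threads per EU this gives
+    // 196608 / 16 / 8 = 1536 bytes, never less than the vector width.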
+    // The inner loop is converted to a GPU kernel and the tile sizes are mapped
+    // to the block sizes.
     opts.tilingOptions.setTileSizeComputationFunction(
-        // The tile sizes calculation is based on the following equation:
-        // n * TS0 * TS1 * ... * TSn = euMem
-        // where:
-        // n - an average number of bytes, processed by each iteration
-        // TS0, TS1, ... TSn - the tile sizes for each loop correspondingly
-        // euMem - the physical memory (cache) size of the GPU execution unit
-        //
-        // To calculate the tile size TS, we need to divide the total loop size
-        // S by the ratio r:
-        //
-        // n * (S0/r0) * (S1/r1) * ... * (Sn/rn) = euMem
-        // r0 * r1 * ... * rn = (n * S0 * S1 * ... * Sn) / euMem
-        // If all sizes are equal, then S0 = ... = Sn = S, r0 = ... = rn = r:
-        // r^n = (n * S^n) / euMem
-        // r = (n * S^n / euMem)^(1/n)
-        [euMem = getEuMem(rewriter), euThreads = getEuThreads(rewriter)](
+        [cachePerThread, vectorWidth, numThreads = numEus * numThreadsPerEu](
             OpBuilder &builder, Operation *op) -> SmallVector<OpFoldResult> {
           auto ti = dyn_cast<TilingInterface>(op);
           if (!ti) {
@@ -76,44 +82,49 @@ struct GpuTilingAndFusion final
           auto itDomains = ti.getIterationDomain(builder);
           assert(itTypes.size() == itDomains.size());
 
-          // TODO: Add a parameter to the options?
-          size_t totalSize = calcOperandsSize(op) * euThreads;
-          unsigned loopCount = 0;
-
+          SmallVector<int64_t> tiles;
+          int64_t numIterations = 1;
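+          // Collect the static trip counts of the parallel loops; fall back
+          // to the dynamic-sizes path if any trip count is unknown.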
           for (auto [t, r] : zip(itTypes, itDomains)) {
             if (t == utils::IteratorType::parallel) {
               if (auto v = getConstantIntValue(r.size)) {
-                loopCount++;
-                totalSize *= *v;
+                numIterations *= *v;
+                tiles.emplace_back(*v);
               } else {
-                return calcDynamicSizes(builder, ti, euMem, euThreads);
+                return calcDynamicSizes(builder, ti, cachePerThread);
               }
             }
           }
 
-          if (loopCount == 0) {
+          if (tiles.empty()) {
             return {};
           }
 
-          // TODO: In case of different sizes, calculate the ratio for each loop
-          double ratio = std::pow(static_cast<double>(totalSize) /
-                                      static_cast<double>(euMem),
-                                  1.0 / loopCount);
-          ratio = std::max(1.0, ratio);
-          SmallVector<OpFoldResult> tiles;
-          tiles.reserve(itDomains.size());
+          auto elementSize = getElementSize(op);
+          auto sizePerThread = numIterations / numThreads * elementSize;
+          auto tilesSize = std::max(sizePerThread, cachePerThread);
+          tilesSize = std::max(tilesSize / elementSize, 64L);
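+          // The per-thread byte budget is the larger of an even split of the
+          // iteration space and cachePerThread, converted back to elements
+          // with a floor of 64. E.g. (hypothetical) 1048576 iterations, 512
+          // threads and 4-byte elements give 8192 bytes, i.e. 2048 elements.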
+
+          // If the operation could be lowered to XeGPU, make the tiles
+          // proportional to the vector width.
+          if (canLowerToXeGPU(op)) {
+            tilesSize = std::max(tilesSize / vectorWidth, 1L) * vectorWidth;
+          }
+
+          adjustTiles(tilesSize, tiles);
+
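+          // Non-parallel dimensions keep a tile size of 1; parallel
+          // dimensions take the adjusted sizes in order.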
+          unsigned counter = 0;
+          SmallVector<OpFoldResult> result;
+          result.reserve(itDomains.size());
 
           for (auto [t, r] : zip(itTypes, itDomains)) {
             if (t != utils::IteratorType::parallel) {
-              tiles.emplace_back(builder.getIndexAttr(1));
-            } else if (auto v = getConstantIntValue(r.size)) {
-              tiles.emplace_back(ceil(builder, *v, ratio));
+              result.emplace_back(builder.getIndexAttr(1));
             } else {
-              abort(); // Must never get here
+              result.emplace_back(builder.getIndexAttr(tiles[counter++]));
             }
           }
 
-          return tiles;
+          return result;
         });
 
     auto fn = getOperation();
@@ -174,7 +185,8 @@ struct GpuTilingAndFusion final
   static std::optional<TilingInterface> findTi(Operation *op) {
     std::optional<TilingInterface> last;
     op->walk<WalkOrder::PreOrder>([&](linalg::LinalgOp linalgOp) {
-      if (!linalgOp->getParentOfType<scf::ForallOp>()) {
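+      // Skip ops whose indexing maps are not projected permutations.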
+      if (linalgOp.hasOnlyProjectedPermutations() &&
+          !linalgOp->getParentOfType<scf::ForallOp>()) {
         if (auto ti = dyn_cast<TilingInterface>(linalgOp.getOperation())) {
           last = ti;
         }
@@ -184,17 +196,16 @@ struct GpuTilingAndFusion final
     return last;
   }
 
-  static SmallVector<OpFoldResult> calcDynamicSizes(OpBuilder &builder,
-                                                    TilingInterface ti,
-                                                    size_t euMem,
-                                                    size_t euThreads) {
+  // TODO: Use the adjustTiles() function from MLIR.
+  static SmallVector<OpFoldResult>
+  calcDynamicSizes(OpBuilder &builder, TilingInterface ti, int64_t cacheSize) {
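+    // Dynamic-shape fallback: builds IR that evaluates the cache-based split
+    // ratio (elementSize * S0 * ... * Sn / cacheSize)^(1 / loopCount) at
+    // runtime.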
     auto itTypes = ti.getLoopIteratorTypes();
     auto itDomains = ti.getIterationDomain(builder);
     assert(itTypes.size() == itDomains.size());
 
     auto loc = ti.getLoc();
     Value dynamicSize;
-    size_t staticSize = calcOperandsSize(ti.getOperation()) * euThreads;
+    int64_t staticSize = getElementSize(ti.getOperation());
     unsigned loopCount = 0;
 
     for (auto [t, r] : zip(itTypes, itDomains)) {
@@ -225,7 +236,7 @@ struct GpuTilingAndFusion final
                                            dynamicSize));
 
     auto memSize = builder.create<arith::ConstantFloatOp>(
-        loc, APFloat(static_cast<double>(euMem)), builder.getF64Type());
+        loc, APFloat(static_cast<double>(cacheSize)), builder.getF64Type());
     auto pow = builder.create<arith::ConstantFloatOp>(
         loc, APFloat(1.0 / loopCount), builder.getF64Type());
     Value ratio = builder.create<math::PowFOp>(
@@ -265,29 +276,63 @@ struct GpuTilingAndFusion final
     return tiles;
   }
 
-  static size_t calcOperandsSize(Operation *op) {
-    size_t size = 0;
-    auto typeSize = [](Type t) -> size_t {
-      Type et;
-      if (auto mt = dyn_cast<MemRefType>(t)) {
-        et = mt.getElementType();
-      } else if (auto tt = dyn_cast<TensorType>(t)) {
-        et = tt.getElementType();
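+  // Returns the byte size of the element type of the op's first DPS init,
+  // or 1 if it cannot be determined.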
+  static int64_t getElementSize(Operation *op) {
+    int64_t elementSize = 1;
+    if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+      if (auto inits = linalgOp.getDpsInits(); !inits.empty()) {
+        if (auto t = getElementTypeOrSelf(inits[0].getType());
+            t.isIntOrFloat()) {
+          elementSize = t.getIntOrFloatBitWidth() / 8;
+        }
+      }
+    }
+    return elementSize;
+  }
+
+  // TODO: Add more checks
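+  // Heuristic: the op must be a statically shaped linalg op whose output is
+  // a 2-D shape with at least 16 elements and whose inputs are at most 2-D.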
+  static bool canLowerToXeGPU(Operation *operation) {
+    auto op = dyn_cast<linalg::LinalgOp>(operation);
+    if (!op) {
+      return false;
+    }
+    if (op.hasDynamicShape()) {
+      return false;
+    }
+
+    auto checkOperand = [&](Value operand, bool isOutput = false) {
+      ShapedType type;
+      if (auto memref = dyn_cast<MemRefType>(operand.getType())) {
+        type = memref;
+      } else if (auto tensor = dyn_cast<RankedTensorType>(operand.getType())) {
+        type = tensor;
       } else {
-        return 0;
+        return false;
       }
-      return et.isIntOrFloat() ? et.getIntOrFloatBitWidth() / 8 : 1;
-    };
-    for (auto operand : op->getOperands()) {
-      if (auto defOp = operand.getDefiningOp()) {
-        for (auto t : defOp->getResultTypes()) {
-          size += typeSize(t);
+
+      auto shape = type.getShape();
+      if (isOutput) {
+        if (shape.size() != 2 || shape[0] * shape[1] < 16) {
+          return false;
         }
-      } else {
-        size += typeSize(operand.getType());
+      } else if (shape.size() > 2) {
+        return false;
       }
+
+      return true;
+    };
+
+    if (auto inits = op.getDpsInits();
+        !inits.empty() && !checkOperand(inits[0], true)) {
+      return false;
     }
-    return size == 0 ? 1 : size;
+
+    if (auto inputs = op.getDpsInputs();
+        !std::all_of(inputs.begin(), inputs.end(),
+                     [&](Value v) { return checkOperand(v); })) {
+      return false;
+    }
+
+    return true;
   }
 };
 } // namespace