#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/GPU/Transforms/Passes.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/Dialect/XeGPU/IR/XeGPU.h>
#include <mlir/Pass/Pass.h>
@@ -158,6 +159,9 @@ class InsertGPUAllocsPass final
|
158 | 159 | } else if (auto init_tile =
|
159 | 160 | mlir::dyn_cast<imex::xetile::InitTileOp>(op)) {
|
160 | 161 | return {{init_tile.getSource()}};
|
| 162 | + } else if (auto init_xedesc = |
| 163 | + mlir::dyn_cast<mlir::xegpu::CreateNdDescOp>(op)) { |
| 164 | + return {{init_xedesc.getSource()}}; |
161 | 165 | } else {
|
162 | 166 | op->emitError("Uhhandled mem op in gpu region");
|
163 | 167 | return std::nullopt;
|
@@ -187,6 +191,9 @@ class InsertGPUAllocsPass final
|
187 | 191 | // Only handle the case where the tile source is a memref
|
188 | 192 | return init_tile.isSourceMemRef();
|
189 | 193 | }
|
| 194 | + if (auto init_xedesc = mlir::dyn_cast<mlir::xegpu::CreateNdDescOp>(op)) { |
| 195 | + return true; |
| 196 | + } |
190 | 197 | return false;
|
191 | 198 | };
|
192 | 199 |
|
@@ -259,6 +266,36 @@ class InsertGPUAllocsPass final
|
259 | 266 | return;
|
260 | 267 | }
|
261 | 268 |
|
| 269 | + // walk over the users and find xegpu.load/store ops |
| 270 | + std::function<void(mlir::Operation*, bool, AccessType&)> findXeGPULoadStore; |
| 271 | + findXeGPULoadStore = [&](mlir::Operation *use, bool onDevice, AccessType& ret) { |
| 272 | + if (auto tile_update = mlir::dyn_cast<mlir::xegpu::UpdateNdOffsetOp>(use)) { |
| 273 | + auto res = tile_update->getResult(0); |
| 274 | + for (auto u : res.getUsers()) { |
| 275 | + findXeGPULoadStore(u, onDevice, ret); |
| 276 | + } |
| 277 | + } |
| 278 | + if (auto tile_for = mlir::dyn_cast<::mlir::scf::ForOp>(use)) { |
| 279 | + for (size_t idx=0; idx<tile_for.getInits().size(); idx++) { |
| 280 | + auto a = tile_for.getRegionIterArg(idx); |
| 281 | + for (auto u : a.getUsers()) { |
| 282 | + findXeGPULoadStore(u, onDevice, ret); |
| 283 | + } |
| 284 | + } |
| 285 | + } |
| 286 | + if (auto tile_load = |
| 287 | + mlir::dyn_cast<mlir::xegpu::LoadNdOp>(use)) { |
| 288 | + (onDevice ? ret.deviceRead : ret.hostRead) = true; |
| 289 | + } |
| 290 | + else if (auto tile_prefetch = |
| 291 | + mlir::dyn_cast<mlir::xegpu::PrefetchNdOp>(use)) { |
| 292 | + (onDevice ? ret.deviceRead : ret.hostRead) = true; |
| 293 | + } else if (auto tile_store = |
| 294 | + mlir::dyn_cast<mlir::xegpu::StoreNdOp>(use)) { |
| 295 | + (onDevice ? ret.deviceWrite : ret.hostWrite) = true; |
| 296 | + } |
| 297 | + }; |
| 298 | + |
262 | 299 | // Checks the access type of the OP under consideration.
|
263 | 300 | auto getAccessType = [&](mlir::Value memref) {
|
264 | 301 | AccessType ret;
|
@@ -298,6 +335,15 @@ class InsertGPUAllocsPass final
|
298 | 335 | continue;
|
299 | 336 | }
|
300 | 337 |
|
| 338 | + if (auto init_xedesc = mlir::dyn_cast<mlir::xegpu::CreateNdDescOp>(user)) { |
| 339 | + bool onDevice = user->getParentOfType<mlir::gpu::LaunchOp>(); |
| 340 | + auto res = init_xedesc->getResult(0); |
| 341 | + for (auto use : res.getUsers()) { |
| 342 | + findXeGPULoadStore(use, onDevice, ret); |
| 343 | + } |
| 344 | + continue; |
| 345 | + } |
| 346 | + |
301 | 347 | if (mlir::isa<mlir::func::ReturnOp>(user)) {
|
302 | 348 | ret.hostRead = true;
|
303 | 349 | ret.hostWrite = true;
|
|
0 commit comments