|
15 | 15 | #include "flang/Optimizer/Dialect/CUF/CUFDialect.h" |
16 | 16 | #include "flang/Optimizer/Dialect/FIRAttr.h" |
17 | 17 | #include "flang/Optimizer/Dialect/FIRType.h" |
| 18 | +#include "mlir/Dialect/GPU/IR/GPUDialect.h" |
18 | 19 | #include "mlir/IR/Attributes.h" |
19 | 20 | #include "mlir/IR/BuiltinAttributes.h" |
20 | 21 | #include "mlir/IR/BuiltinOps.h" |
@@ -253,6 +254,42 @@ llvm::LogicalResult cuf::KernelOp::verify() { |
253 | 254 | return mlir::success(); |
254 | 255 | } |
255 | 256 |
|
| 257 | +//===----------------------------------------------------------------------===// |
| 258 | +// RegisterKernelOp |
| 259 | +//===----------------------------------------------------------------------===// |
| 260 | + |
| 261 | +mlir::StringAttr cuf::RegisterKernelOp::getKernelModuleName() { |
| 262 | + return getName().getRootReference(); |
| 263 | +} |
| 264 | + |
| 265 | +mlir::StringAttr cuf::RegisterKernelOp::getKernelName() { |
| 266 | + return getName().getLeafReference(); |
| 267 | +} |
| 268 | + |
| 269 | +mlir::LogicalResult cuf::RegisterKernelOp::verify() { |
| 270 | + if (getKernelName() == getKernelModuleName()) |
| 271 | + return emitOpError("expect a module and a kernel name"); |
| 272 | + |
| 273 | + auto mod = getOperation()->getParentOfType<mlir::ModuleOp>(); |
| 274 | + if (!mod) |
| 275 | + return emitOpError("expect to be in a module"); |
| 276 | + |
| 277 | + mlir::SymbolTable symTab(mod); |
| 278 | + auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(getKernelModuleName()); |
| 279 | + if (!gpuMod) |
| 280 | + return emitOpError("gpu module not found"); |
| 281 | + |
| 282 | + mlir::SymbolTable gpuSymTab(gpuMod); |
| 283 | + auto func = gpuSymTab.lookup<mlir::gpu::GPUFuncOp>(getKernelName()); |
| 284 | + if (!func) |
| 285 | + return emitOpError("device function not found"); |
| 286 | + |
| 287 | + if (!func.isKernel()) |
| 288 | + return emitOpError("only kernel gpu.func can be registered"); |
| 289 | + |
| 290 | + return mlir::success(); |
| 291 | +} |
| 292 | + |
256 | 293 | // Tablegen operators |
257 | 294 |
|
258 | 295 | #define GET_OP_CLASSES |
|
0 commit comments