|
14 | 14 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
15 | 15 | #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" |
16 | 16 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 17 | +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
17 | 18 | #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" |
18 | 19 | #include "mlir/Dialect/SCF/IR/SCF.h" |
19 | 20 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
@@ -237,25 +238,17 @@ DiagnosedSilenceableFailure checkGpuLimits(TransformOpInterface transformOp, |
237 | 238 | std::optional<int64_t> blockDimZ) { |
238 | 239 |
|
239 | 240 | // TODO: pass a configuration object to set the limits properly. |
240 | | - static constexpr int maxTotalBlockdim = 1024; |
241 | | - static constexpr int maxBlockdimx = 1024; |
242 | | - static constexpr int maxBlockdimy = 1024; |
243 | | - static constexpr int maxBlockdimz = 64; |
244 | | - static constexpr int maxTotalGriddim = 2147483647; |
245 | | - static constexpr int maxGriddimx = 2147483647; |
246 | | - static constexpr int maxGriddimy = 65535; |
247 | | - static constexpr int maxGriddimz = 65535; |
248 | 241 |
|
249 | 242 | if ((blockDimX.value_or(1) * blockDimY.value_or(1) * blockDimZ.value_or(1)) > |
250 | | - maxTotalBlockdim || |
| 243 | + kMaxTotalBlockdim || |
251 | 244 | (gridDimX.value_or(1) * gridDimY.value_or(1) * gridDimZ.value_or(1)) > |
252 | | - maxTotalGriddim || |
253 | | - blockDimX.value_or(1) > maxBlockdimx || |
254 | | - blockDimY.value_or(1) > maxBlockdimy || |
255 | | - blockDimZ.value_or(1) > maxBlockdimz || |
256 | | - gridDimY.value_or(1) > maxGriddimy || |
257 | | - gridDimZ.value_or(1) > maxGriddimz || |
258 | | - gridDimX.value_or(1) > maxGriddimx) { |
| 245 | + kMaxTotalGriddim || |
| 246 | + blockDimX.value_or(1) > kMaxBlockdimx || |
| 247 | + blockDimY.value_or(1) > kMaxBlockdimy || |
| 248 | + blockDimZ.value_or(1) > kMaxBlockdimz || |
| 249 | + gridDimY.value_or(1) > kMaxGriddimy || |
| 250 | + gridDimZ.value_or(1) > kMaxGriddimz || |
| 251 | + gridDimX.value_or(1) > kMaxGriddimx) { |
259 | 252 | return transformOp.emitSilenceableError() |
260 | 253 | << "Trying to launch a GPU kernel with grid_dims = (" |
261 | 254 | << gridDimX.value_or(1) << ", " << gridDimY.value_or(1) << ", " |
|
0 commit comments