You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
We now tell LLVM and PTXAS the number of CTAs we are about to use.
This allows them to generate better code in most cases.
We also clean-up all the remaining uses of `cluster_dims` and we replace
them by queries to `num_ctas`.
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', '+arg_declsiflen(arg_decls) >0else''}) {{
443
+
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', '+arg_declsiflen(arg_decls) >0else''}) {{
444
444
hipDeviceptr_t global_scratch = 0;
445
445
void *params[] = {{ {', '.join(params)} }};
446
446
if (gridX*gridY*gridZ > 0 && launch_cooperative_grid) {{
@@ -548,8 +548,8 @@ def format_of(ty):
548
548
{' '.join(float_storage_decls)}
549
549
550
550
// extract kernel metadata
551
-
int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
552
-
if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
551
+
int num_warps, num_ctas, shared_memory;
552
+
if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{
Copy file name to clipboardExpand all lines: third_party/nvidia/backend/driver.py
+10-16Lines changed: 10 additions & 16 deletions
Original file line number
Diff line number
Diff line change
@@ -1,5 +1,4 @@
1
1
importfunctools
2
-
importoperator
3
2
importos
4
3
importsubprocess
5
4
importtriton
@@ -314,7 +313,7 @@ def format_of(ty):
314
313
return cuLaunchKernelExHandle;
315
314
}}
316
315
317
-
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', '+arg_declsiflen(arg_decls) >0else''}) {{
316
+
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', '+arg_declsiflen(arg_decls) >0else''}) {{
318
317
void *params[] = {{ {', '.join(params)} }};
319
318
if (gridX*gridY*gridZ > 0) {{
320
319
// 4 attributes that we can currently pass maximum
0 commit comments