You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
We now tell LLVM and PTXAS the number of CTAs we are about to use.
This allows them to generate better code in most cases.
We also clean-up all the remaining uses of `cluster_dims` and we replace
them by queries to `num_ctas`.
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', '+arg_declsiflen(arg_decls) >0else''}) {{
490
+
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', '+arg_declsiflen(arg_decls) >0else''}) {{
491
491
if (gridX * gridY * gridZ == 0)
492
492
return;
493
493
hipDeviceptr_t global_scratch = 0;
@@ -632,8 +632,8 @@ def format_of(ty):
632
632
}}
633
633
634
634
// extract kernel metadata
635
-
int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
636
-
if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
635
+
int num_warps, num_ctas, shared_memory;
636
+
if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{
Copy file name to clipboardExpand all lines: third_party/nvidia/backend/driver.py
+10-16Lines changed: 10 additions & 16 deletions
Original file line number
Diff line number
Diff line change
@@ -1,5 +1,4 @@
1
1
importfunctools
2
-
importoperator
3
2
importos
4
3
importsubprocess
5
4
importtriton
@@ -339,7 +338,7 @@ def format_of(ty):
339
338
}}
340
339
#endif
341
340
342
-
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', '+arg_declsiflen(arg_decls) >0else''}) {{
341
+
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', '+arg_declsiflen(arg_decls) >0else''}) {{
343
342
void *params[] = {{ {', '.join(params)} }};
344
343
if (gridX*gridY*gridZ > 0) {{
345
344
// 4 attributes that we can currently pass maximum
0 commit comments