You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', '+arg_declsiflen(arg_decls) >0else''}) {{
443
+
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', '+arg_declsiflen(arg_decls) >0else''}) {{
444
444
hipDeviceptr_t global_scratch = 0;
445
445
void *params[] = {{ {', '.join(params)} }};
446
446
if (gridX*gridY*gridZ > 0 && launch_cooperative_grid) {{
@@ -548,8 +548,8 @@ def format_of(ty):
548
548
{' '.join(float_storage_decls)}
549
549
550
550
// extract kernel metadata
551
-
int num_warps, num_ctas, shared_memory;
552
-
if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{
551
+
int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
552
+
if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
Copy file name to clipboard. Expand all lines: third_party/nvidia/backend/driver.py
+16 -10. Lines changed: 16 additions & 10 deletions
Original file line number
Diff line number
Diff line change
@@ -1,4 +1,5 @@
1
1
importfunctools
2
+
importoperator
2
3
importos
3
4
importsubprocess
4
5
importtriton
@@ -338,7 +339,7 @@ def format_of(ty):
338
339
}}
339
340
#endif
340
341
341
-
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', '+arg_declsiflen(arg_decls) >0else''}) {{
342
+
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', '+arg_declsiflen(arg_decls) >0else''}) {{
342
343
void *params[] = {{ {', '.join(params)} }};
343
344
if (gridX*gridY*gridZ > 0) {{
344
345
// 4 attributes that we can currently pass maximum
0 commit comments