Skip to content

Commit 198fd9b

Browse files
authored
[Runtime] CPython cleanup and PyObject_Call overhead reduction (#7791)
This does a few things: - Changes generic `PyObject_Call` to optimized versions: `PyObject_CallOneArg`, `PyObject_CallMethodNoArgs`. Result 17.85 us -> 17.39 us launch overhead. - Use `goto cleanup` pattern in `getPointer`, which avoids leaking the `ret` value in error handling code paths. - Change `Py_INCREF(Py_None); return Py_None` to `Py_RETURN_NONE` which does the same thing. - Port `"data_ptr"` string interning to AMD launcher. I recommend viewing the diff [ignoring whitespace changes](https://github.com/triton-lang/triton/pull/7791/files?diff=unified&w=1) since I de-indented part of `getPointer`.
1 parent 2453088 commit 198fd9b

File tree

4 files changed

+62
-86
lines changed

4 files changed

+62
-86
lines changed

python/test/backend/test_device_backend.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,7 @@ def _generate_launcher(self, constants, signature):
176176
return NULL;
177177
}
178178
launch_counter(self, args);
179-
// return None
180-
Py_INCREF(Py_None);
181-
return Py_None;
179+
Py_RETURN_NONE;
182180
}
183181
184182
static PyMethodDef ModuleMethods[] = {

third_party/amd/backend/driver.py

Lines changed: 31 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,8 @@ def format_of(ty):
427427
bool valid;
428428
}} DevicePtrInfo;
429429
430+
static PyObject* data_ptr_str = NULL;
431+
430432
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
431433
DevicePtrInfo ptr_info;
432434
ptr_info.dev_ptr = 0;
@@ -439,32 +441,30 @@ def format_of(ty):
439441
// valid nullptr
440442
return ptr_info;
441443
}}
442-
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
443-
if(ptr){{
444-
PyObject *empty_tuple = PyTuple_New(0);
445-
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
446-
Py_DECREF(empty_tuple);
447-
Py_DECREF(ptr);
448-
if (!PyLong_Check(ret)) {{
449-
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
444+
PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
445+
if (!ret) {{
446+
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
447+
ptr_info.valid = false;
448+
goto cleanup;
449+
}}
450+
if (!PyLong_Check(ret)) {{
451+
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
452+
ptr_info.valid = false;
453+
goto cleanup;
454+
}}
455+
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
456+
if (!ptr_info.dev_ptr)
457+
goto cleanup;
458+
uint64_t dev_ptr;
459+
hipError_t status = hipSymbolTable.hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
460+
if (status == hipErrorInvalidValue) {{
461+
PyErr_Format(PyExc_ValueError,
462+
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
450463
ptr_info.valid = false;
451-
return ptr_info;
452-
}}
453-
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
454-
if(!ptr_info.dev_ptr)
455-
return ptr_info;
456-
uint64_t dev_ptr;
457-
hipError_t status = hipSymbolTable.hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
458-
if (status == hipErrorInvalidValue) {{
459-
PyErr_Format(PyExc_ValueError,
460-
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
461-
ptr_info.valid = false;
462-
}}
463-
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
464-
Py_DECREF(ret);
465-
return ptr_info;
466464
}}
467-
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
465+
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
466+
cleanup:
467+
Py_DECREF(ret);
468468
return ptr_info;
469469
}}
470470
@@ -521,9 +521,7 @@ def format_of(ty):
521521
}}
522522
// extract launch metadata
523523
if (launch_enter_hook != Py_None){{
524-
PyObject* args = Py_BuildValue("(O)", launch_metadata);
525-
PyObject* ret = PyObject_CallObject(launch_enter_hook, args);
526-
Py_DECREF(args);
524+
PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
527525
if (!ret)
528526
return NULL;
529527
Py_DECREF(ret);
@@ -543,9 +541,7 @@ def format_of(ty):
543541
_launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
544542
545543
if(launch_exit_hook != Py_None){{
546-
PyObject* args = Py_BuildValue("(O)", launch_metadata);
547-
PyObject* ret = PyObject_CallObject(launch_exit_hook, args);
548-
Py_DECREF(args);
544+
PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
549545
if (!ret)
550546
return NULL;
551547
Py_DECREF(ret);
@@ -554,9 +550,7 @@ def format_of(ty):
554550
if(PyErr_Occurred()) {{
555551
return NULL;
556552
}}
557-
// return None
558-
Py_INCREF(Py_None);
559-
return Py_None;
553+
Py_RETURN_NONE;
560554
}}
561555
562556
static PyMethodDef ModuleMethods[] = {{
@@ -580,6 +574,10 @@ def format_of(ty):
580574
if(m == NULL) {{
581575
return NULL;
582576
}}
577+
data_ptr_str = PyUnicode_InternFromString("data_ptr");
578+
if(data_ptr_str == NULL) {{
579+
return NULL;
580+
}}
583581
PyModule_AddFunctions(m, ModuleMethods);
584582
return m;
585583
}}

third_party/nvidia/backend/driver.c

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,7 @@ static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) {
276276
}
277277

278278
Py_END_ALLOW_THREADS;
279-
Py_INCREF(Py_None);
280-
return Py_None;
279+
Py_RETURN_NONE;
281280
}
282281

283282
static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {

third_party/nvidia/backend/driver.py

Lines changed: 29 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -384,61 +384,45 @@ def format_of(ty):
384384
// valid nullptr
385385
return ptr_info;
386386
}}
387-
PyObject* ptr = PyObject_GetAttr(obj, data_ptr_str);
388-
if(ptr){{
389-
PyObject *empty_tuple = PyTuple_New(0);
390-
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
391-
Py_DECREF(empty_tuple);
392-
Py_DECREF(ptr);
393-
if (!PyLong_Check(ret)) {{
394-
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
395-
ptr_info.valid = false;
396-
return ptr_info;
397-
}}
398-
ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
399-
if(!ptr_info.dev_ptr)
400-
return ptr_info;
401-
uint64_t dev_ptr;
402-
int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
403-
if (status == CUDA_ERROR_INVALID_VALUE) {{
404-
PyErr_Format(PyExc_ValueError,
405-
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
406-
ptr_info.valid = false;
407-
}} else if (status != CUDA_SUCCESS) {{
408-
CUDA_CHECK(status); // Catch any other cuda API errors
409-
ptr_info.valid = false;
410-
}}
411-
ptr_info.dev_ptr = dev_ptr;
412-
Py_DECREF(ret); // Thanks ChatGPT!
387+
PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
388+
if (!ret) {{
389+
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
390+
ptr_info.valid = false;
391+
goto cleanup;
392+
}}
393+
if (!PyLong_Check(ret)) {{
394+
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
395+
ptr_info.valid = false;
396+
goto cleanup;
397+
}}
398+
ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
399+
if(!ptr_info.dev_ptr)
413400
return ptr_info;
401+
uint64_t dev_ptr;
402+
int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
403+
if (status == CUDA_ERROR_INVALID_VALUE) {{
404+
PyErr_Format(PyExc_ValueError,
405+
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
406+
ptr_info.valid = false;
407+
}} else if (status != CUDA_SUCCESS) {{
408+
CUDA_CHECK(status); // Catch any other cuda API errors
409+
ptr_info.valid = false;
414410
}}
415-
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
416-
ptr_info.valid = false;
411+
ptr_info.dev_ptr = dev_ptr;
412+
cleanup:
413+
Py_XDECREF(ret);
417414
return ptr_info;
415+
418416
}}
419417
420418
static inline CUtensorMap* getTmaDesc(PyObject *obj) {{
421419
if (sizeof(CUtensorMap*) != 8) {{
422420
PyErr_SetString(PyExc_SystemError, "getTmaDesc() requires 64-bit compilation");
423421
return NULL;
424422
}}
425-
PyObject *method_handle = PyObject_GetAttr(obj, tma_desc_cpu_ptr_str);
426-
if (!method_handle) {{
427-
PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() method does not exist");
428-
return NULL;
429-
}}
430423
431-
PyObject *empty_tuple = PyTuple_New(0);
432-
if (!empty_tuple) {{
433-
Py_DECREF(method_handle);
434-
PyErr_SetString(PyExc_SystemError, "Internal Python error!");
435-
return NULL;
436-
}}
437-
PyObject *method_ret = PyObject_Call(method_handle, empty_tuple, NULL);
438-
Py_DECREF(empty_tuple);
439-
Py_DECREF(method_handle);
424+
PyObject *method_ret = PyObject_CallMethodNoArgs(obj, tma_desc_cpu_ptr_str);
440425
if (!method_ret) {{
441-
PyErr_SetString(PyExc_SystemError, "Internal Python error!");
442426
return NULL;
443427
}}
444428
@@ -531,9 +515,7 @@ def format_of(ty):
531515
532516
// extract launch metadata
533517
if (launch_enter_hook != Py_None){{
534-
PyObject* args = Py_BuildValue("(O)", launch_metadata);
535-
PyObject* ret = PyObject_CallObject(launch_enter_hook, args);
536-
Py_DECREF(args);
518+
PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
537519
if (!ret)
538520
return NULL;
539521
Py_DECREF(ret);
@@ -569,8 +551,7 @@ def format_of(ty):
569551
}}
570552
571553
if(launch_exit_hook != Py_None){{
572-
PyObject* args = Py_BuildValue("(O)", launch_metadata);
573-
PyObject* ret = PyObject_CallObject(launch_exit_hook, args);
554+
PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
574555
Py_DECREF(args);
575556
if (!ret)
576557
return NULL;

0 commit comments

Comments
 (0)