@@ -384,61 +384,45 @@ def format_of(ty):
384384 // valid nullptr
385385 return ptr_info;
386386 }}
387- PyObject* ptr = PyObject_GetAttr(obj, data_ptr_str);
388- if(ptr){{
389- PyObject *empty_tuple = PyTuple_New(0);
390- PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
391- Py_DECREF(empty_tuple);
392- Py_DECREF(ptr);
393- if (!PyLong_Check(ret)) {{
394- PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
395- ptr_info.valid = false;
396- return ptr_info;
397- }}
398- ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
399- if(!ptr_info.dev_ptr)
400- return ptr_info;
401- uint64_t dev_ptr;
402- int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
403- if (status == CUDA_ERROR_INVALID_VALUE) {{
404- PyErr_Format(PyExc_ValueError,
405- "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
406- ptr_info.valid = false;
407- }} else if (status != CUDA_SUCCESS) {{
408- CUDA_CHECK(status); // Catch any other cuda API errors
409- ptr_info.valid = false;
410- }}
411- ptr_info.dev_ptr = dev_ptr;
412- Py_DECREF(ret); // Thanks ChatGPT!
387+ PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
388+ if (!ret) {{
389+ PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
390+ ptr_info.valid = false;
391+ goto cleanup;
392+ }}
393+ if (!PyLong_Check(ret)) {{
394+ PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
395+ ptr_info.valid = false;
396+ goto cleanup;
397+ }}
398+ ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
399+ if(!ptr_info.dev_ptr)
413400 return ptr_info;
401+ uint64_t dev_ptr;
402+ int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
403+ if (status == CUDA_ERROR_INVALID_VALUE) {{
404+ PyErr_Format(PyExc_ValueError,
405+ "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
406+ ptr_info.valid = false;
407+ }} else if (status != CUDA_SUCCESS) {{
408+ CUDA_CHECK(status); // Catch any other cuda API errors
409+ ptr_info.valid = false;
414410 }}
415- PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
416- ptr_info.valid = false;
411+ ptr_info.dev_ptr = dev_ptr;
412+ cleanup:
413+ Py_XDECREF(ret);
417414 return ptr_info;
415+
418416}}
419417
420418static inline CUtensorMap* getTmaDesc(PyObject *obj) {{
421419 if (sizeof(CUtensorMap*) != 8) {{
422420 PyErr_SetString(PyExc_SystemError, "getTmaDesc() requires 64-bit compilation");
423421 return NULL;
424422 }}
425- PyObject *method_handle = PyObject_GetAttr(obj, tma_desc_cpu_ptr_str);
426- if (!method_handle) {{
427- PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() method does not exist");
428- return NULL;
429- }}
430423
431- PyObject *empty_tuple = PyTuple_New(0);
432- if (!empty_tuple) {{
433- Py_DECREF(method_handle);
434- PyErr_SetString(PyExc_SystemError, "Internal Python error!");
435- return NULL;
436- }}
437- PyObject *method_ret = PyObject_Call(method_handle, empty_tuple, NULL);
438- Py_DECREF(empty_tuple);
439- Py_DECREF(method_handle);
424+ PyObject *method_ret = PyObject_CallMethodNoArgs(obj, tma_desc_cpu_ptr_str);
440425 if (!method_ret) {{
441- PyErr_SetString(PyExc_SystemError, "Internal Python error!");
442426 return NULL;
443427 }}
444428
@@ -531,9 +515,7 @@ def format_of(ty):
531515
532516 // extract launch metadata
533517 if (launch_enter_hook != Py_None){{
534- PyObject* args = Py_BuildValue("(O)", launch_metadata);
535- PyObject* ret = PyObject_CallObject(launch_enter_hook, args);
536- Py_DECREF(args);
518+ PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
537519 if (!ret)
538520 return NULL;
539521 Py_DECREF(ret);
@@ -569,8 +551,7 @@ def format_of(ty):
569551 }}
570552
571553 if(launch_exit_hook != Py_None){{
572- PyObject* args = Py_BuildValue("(O)", launch_metadata);
573- PyObject* ret = PyObject_CallObject(launch_exit_hook, args);
554+ PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
574555 Py_DECREF(args);
575556 if (!ret)
576557 return NULL;
0 commit comments