Skip to content

Commit 7bafc7f

Browse files
agron911 authored and meta-codesync[bot] committed
[triton][PR] [triton][Fix] Back out "Revert D86123088", Fix old TMA API not supported issue
Summary: Original commit changeset: 438f340d41ee Original Phabricator Diff: D86533448 Import the fix from github release/3.5.x branch [release 3.5.x] Fix old TMA API not supported issue #653 Reviewed By: htyu, dshi7 Differential Revision: D86783598 fbshipit-source-id: e66895007c08e45b9ef8047949aa2f3bc8eee6ff
1 parent 341b78b commit 7bafc7f

File tree

2 files changed

+194
-79
lines changed

2 files changed

+194
-79
lines changed

third_party/nvidia/backend/driver.c

Lines changed: 79 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,15 @@
11
#include "cuda.h"
22
#include <dlfcn.h>
33
#include <stdbool.h>
4+
#include <stdlib.h>
45
#define PY_SSIZE_T_CLEAN
56
#include <Python.h>
67

8+
typedef struct {
9+
PyObject_HEAD;
10+
_Alignas(128) CUtensorMap tensorMap;
11+
} PyCUtensorMapObject;
12+
713
// Raises a Python exception and returns false if code is not CUDA_SUCCESS.
814
static bool gpuAssert(CUresult code, const char *file, int line) {
915
if (code == CUDA_SUCCESS)
@@ -26,7 +32,7 @@ static bool gpuAssert(CUresult code, const char *file, int line) {
2632
#define CUDA_CHECK_AND_RETURN_NULL(ans) \
2733
do { \
2834
if (!gpuAssert((ans), __FILE__, __LINE__)) \
29-
return NULL; \
35+
goto cleanup; \
3036
} while (0)
3137

3238
// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block.
@@ -44,7 +50,7 @@ static bool gpuAssert(CUresult code, const char *file, int line) {
4450
if ((funcPointer) == NULL) { \
4551
(funcPointer) = (initializerFunction)(); \
4652
if ((funcPointer) == NULL) { \
47-
return NULL; \
53+
goto cleanup; \
4854
} \
4955
} \
5056
} while (0)
@@ -87,6 +93,9 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
8793
warp_size, "sm_clock_rate", sm_clock_rate,
8894
"mem_clock_rate", mem_clock_rate, "mem_bus_width",
8995
mem_bus_width);
96+
97+
cleanup:
98+
return NULL;
9099
}
91100

92101
static PyObject *loadBinary(PyObject *self, PyObject *args) {
@@ -238,6 +247,9 @@ static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
238247
cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &config));
239248
Py_END_ALLOW_THREADS;
240249
return PyLong_FromLong(maxActiveClusters);
250+
251+
cleanup:
252+
return NULL;
241253
}
242254

243255
static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) {
@@ -279,8 +291,43 @@ static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) {
279291
Py_RETURN_NONE;
280292
}
281293

294+
static PyObject *PyCUtensorMap_alloc(PyTypeObject *type, Py_ssize_t n_items) {
295+
PyCUtensorMapObject *self = NULL;
296+
void *mem = NULL;
297+
size_t size = type->tp_basicsize;
298+
299+
if (posix_memalign(&mem, 128, size) != 0) {
300+
PyErr_NoMemory();
301+
return NULL;
302+
}
303+
304+
self = (PyCUtensorMapObject *)mem;
305+
PyObject_INIT(self, type);
306+
return (PyObject *)self;
307+
}
308+
309+
static void PyCUtensorMap_dealloc(PyObject *self) {
310+
Py_TYPE(self)->tp_free(self);
311+
}
312+
313+
static void PyCUtensorMap_free(void *ptr) { free(ptr); }
314+
315+
// clang-format off
316+
static PyTypeObject PyCUtensorMapType = {
317+
PyVarObject_HEAD_INIT(NULL, 0)
318+
.tp_name = "triton.backends.nvidia.PyCUtensorMap",
319+
.tp_basicsize = sizeof(PyCUtensorMapObject),
320+
.tp_itemsize = 0,
321+
.tp_flags = Py_TPFLAGS_DEFAULT,
322+
.tp_doc = "<PyCUtensorMap object>",
323+
.tp_new = PyType_GenericNew,
324+
.tp_alloc = PyCUtensorMap_alloc,
325+
.tp_dealloc = (destructor)PyCUtensorMap_dealloc,
326+
.tp_free = PyCUtensorMap_free,
327+
};
328+
// clang-format on
329+
282330
static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
283-
unsigned long long desc_address;
284331
unsigned long long global_address;
285332
int swizzle;
286333
int elemSize;
@@ -290,16 +337,20 @@ static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
290337
PyObject *strides;
291338
int padding;
292339

293-
if (!PyArg_ParseTuple(args, "KKiiiOOOi", &desc_address, &global_address,
294-
&swizzle, &elemSize, &elemType, &blockSize, &shape,
295-
&strides, &padding)) {
340+
if (!PyArg_ParseTuple(args, "KiiiOOOi", &global_address, &swizzle, &elemSize,
341+
&elemType, &blockSize, &shape, &strides, &padding)) {
342+
return NULL;
343+
}
344+
345+
PyCUtensorMapObject *desc = (PyCUtensorMapObject *)PyObject_CallObject(
346+
(PyObject *)&PyCUtensorMapType, NULL);
347+
if (!desc) {
296348
return NULL;
297349
}
298350

299351
PyObject *blockSizeFast = NULL;
300352
PyObject *shapeFast = NULL;
301353
PyObject *stridesFast = NULL;
302-
PyObject *result = NULL;
303354

304355
uint32_t blockSizeInt[5];
305356
uint64_t shapeInt[5];
@@ -370,17 +421,18 @@ static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
370421
INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled,
371422
getCuTensorMapEncodeTiledHandle);
372423
CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled(
373-
(CUtensorMap *)desc_address, elemType, rank, (void *)global_address,
374-
shapeInt, stridesLL, blockSizeInt, elementStrides,
375-
CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle,
376-
CU_TENSOR_MAP_L2_PROMOTION_L2_128B, fill));
377-
Py_RETURN_NONE;
424+
&desc->tensorMap, elemType, rank, (void *)global_address, shapeInt,
425+
stridesLL, blockSizeInt, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
426+
swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, fill));
427+
428+
return (PyObject *)desc;
378429

379430
cleanup:
380431
Py_XDECREF(blockSizeFast);
381432
Py_XDECREF(shapeFast);
382433
Py_XDECREF(stridesFast);
383-
return result;
434+
Py_XDECREF(desc);
435+
return NULL;
384436
}
385437

386438
// Simple helper to experiment creating TMA descriptors on the host.
@@ -426,6 +478,8 @@ static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) {
426478
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
427479
Py_INCREF(Py_None);
428480
return Py_None;
481+
cleanup:
482+
return NULL;
429483
}
430484

431485
// Simple helper to experiment creating TMA descriptors on the host.
@@ -490,6 +544,8 @@ static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) {
490544
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
491545
Py_INCREF(Py_None);
492546
return Py_None;
547+
cleanup:
548+
return NULL;
493549
}
494550

495551
// Simple helper to experiment creating TMA descriptors on the host.
@@ -545,6 +601,8 @@ static PyObject *fill1DTMADescriptorType(PyObject *self, PyObject *args) {
545601
Py_INCREF(Py_None);
546602
#endif
547603
return Py_None;
604+
cleanup:
605+
return NULL;
548606
}
549607

550608
// Simple helper to experiment creating TMA descriptors on the host.
@@ -619,6 +677,8 @@ static PyObject *fill2DTMADescriptorType(PyObject *self, PyObject *args) {
619677
Py_INCREF(Py_None);
620678
#endif
621679
return Py_None;
680+
cleanup:
681+
return NULL;
622682
}
623683

624684
static PyMethodDef ModuleMethods[] = {
@@ -651,12 +711,18 @@ static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils",
651711
ModuleMethods};
652712

653713
PyMODINIT_FUNC PyInit_cuda_utils(void) {
714+
if (PyType_Ready(&PyCUtensorMapType) < 0) {
715+
return NULL;
716+
}
717+
654718
PyObject *m = PyModule_Create(&ModuleDef);
655719
if (m == NULL) {
656720
return NULL;
657721
}
658722

659723
PyModule_AddFunctions(m, ModuleMethods);
724+
Py_INCREF(&PyCUtensorMapType);
725+
PyModule_AddObject(m, "PyCUtensorMap", (PyObject *)&PyCUtensorMapType);
660726

661727
return m;
662728
}

0 commit comments

Comments (0)