Skip to content

Commit 196f110

Browse files
committed
Adding function to interop with pycuda
1 parent 5f4e860 commit 196f110

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

arrayfire/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@
3535
3636
"""
3737

38+
try:
39+
import pycuda.autoinit
40+
except:
41+
pass
42+
3843
from .library import *
3944
from .array import *
4045
from .data import *

arrayfire/interop.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
This module provides interoperability with the following python packages.
1414
1515
1. numpy
16+
2. pycuda
1617
"""
1718

1819
from .array import *
20+
from .device import lock_array
1921

2022
try:
2123
import numpy as np
@@ -60,3 +62,48 @@ def np_to_af_array(np_arr):
6062
from_ndarray = np_to_af_array
6163
except:
6264
AF_NUMPY_FOUND=False
65+
66+
try:
67+
import pycuda.gpuarray as CudaArray
68+
AF_PYCUDA_FOUND=True
69+
70+
def pycuda_to_af_array(pycu_arr):
71+
"""
72+
Convert pycuda.gpuarray to arrayfire.Array
73+
74+
Parameters
75+
-----------
76+
pycu_arr : pycuda.GPUArray()
77+
78+
Returns
79+
----------
80+
af_arr : arrayfire.Array()
81+
"""
82+
if (pycu_arr.flags.f_contiguous):
83+
res = Array(pycu_arr.ptr, pycu_arr.shape, pycu_arr.dtype.char, is_device=True)
84+
lock_array(res)
85+
return res
86+
elif (pycu_arr.flags.c_contiguous):
87+
if pycu_arr.ndim == 1:
88+
return Array(pycu_arr.ptr, pycu_arr.shape, pycu_arr.dtype.char, is_device=True)
89+
elif pycu_arr.ndim == 2:
90+
shape = (pycu_arr.shape[1], pycu_arr.shape[0])
91+
res = Array(pycu_arr.ptr, shape, pycu_arr.dtype.char, is_device=True)
92+
lock_array(res)
93+
return reorder(res, 1, 0)
94+
elif pycu_arr.ndim == 3:
95+
shape = (pycu_arr.shape[2], pycu_arr.shape[1], pycu_arr.shape[0])
96+
res = Array(pycu_arr.ptr, shape, pycu_arr.dtype.char, is_device=True)
97+
lock_array(res)
98+
return reorder(res, 2, 1, 0)
99+
elif pycu_arr.ndim == 4:
100+
shape = (pycu_arr.shape[3], pycu_arr.shape[2], pycu_arr.shape[1], pycu_arr.shape[0])
101+
res = Array(pycu_arr.ptr, shape, pycu_arr.dtype.char, is_device=True)
102+
lock_array(res)
103+
return reorder(res, 3, 2, 1, 0)
104+
else:
105+
raise RuntimeError("Unsupported ndim")
106+
else:
107+
return pycuda_to_af_array(pycu_arr.copy())
108+
except:
109+
AF_PYCUDA_FOUND=False

0 commit comments

Comments
 (0)