|
13 | 13 | This module provides interoperability with the following python packages.
|
14 | 14 |
|
15 | 15 | 1. numpy
|
| 16 | + 2. pycuda |
16 | 17 | """
|
17 | 18 |
|
18 | 19 | from .array import *
|
| 20 | +from .device import lock_array |
19 | 21 |
|
20 | 22 | try:
|
21 | 23 | import numpy as np
|
@@ -60,3 +62,48 @@ def np_to_af_array(np_arr):
|
60 | 62 | from_ndarray = np_to_af_array
|
61 | 63 | except:
|
62 | 64 | AF_NUMPY_FOUND=False
|
| 65 | + |
| 66 | +try: |
| 67 | + import pycuda.gpuarray as CudaArray |
| 68 | + AF_PYCUDA_FOUND=True |
| 69 | + |
| 70 | + def pycuda_to_af_array(pycu_arr): |
| 71 | + """ |
| 72 | + Convert pycuda.gpuarray to arrayfire.Array |
| 73 | +
|
| 74 | + Parameters |
| 75 | + ----------- |
| 76 | + pycu_arr : pycuda.GPUArray() |
| 77 | +
|
| 78 | + Returns |
| 79 | + ---------- |
| 80 | + af_arr : arrayfire.Array() |
| 81 | + """ |
| 82 | + if (pycu_arr.flags.f_contiguous): |
| 83 | + res = Array(pycu_arr.ptr, pycu_arr.shape, pycu_arr.dtype.char, is_device=True) |
| 84 | + lock_array(res) |
| 85 | + return res |
| 86 | + elif (pycu_arr.flags.c_contiguous): |
| 87 | + if pycu_arr.ndim == 1: |
| 88 | + return Array(pycu_arr.ptr, pycu_arr.shape, pycu_arr.dtype.char, is_device=True) |
| 89 | + elif pycu_arr.ndim == 2: |
| 90 | + shape = (pycu_arr.shape[1], pycu_arr.shape[0]) |
| 91 | + res = Array(pycu_arr.ptr, shape, pycu_arr.dtype.char, is_device=True) |
| 92 | + lock_array(res) |
| 93 | + return reorder(res, 1, 0) |
| 94 | + elif pycu_arr.ndim == 3: |
| 95 | + shape = (pycu_arr.shape[2], pycu_arr.shape[1], pycu_arr.shape[0]) |
| 96 | + res = Array(pycu_arr.ptr, shape, pycu_arr.dtype.char, is_device=True) |
| 97 | + lock_array(res) |
| 98 | + return reorder(res, 2, 1, 0) |
| 99 | + elif pycu_arr.ndim == 4: |
| 100 | + shape = (pycu_arr.shape[3], pycu_arr.shape[2], pycu_arr.shape[1], pycu_arr.shape[0]) |
| 101 | + res = Array(pycu_arr.ptr, shape, pycu_arr.dtype.char, is_device=True) |
| 102 | + lock_array(res) |
| 103 | + return reorder(res, 3, 2, 1, 0) |
| 104 | + else: |
| 105 | + raise RuntimeError("Unsupported ndim") |
| 106 | + else: |
| 107 | + return pycuda_to_af_array(pycu_arr.copy()) |
| 108 | +except: |
| 109 | + AF_PYCUDA_FOUND=False |
0 commit comments