Skip to content

Commit 8286a89

Browse files
committed
work in progress
1 parent d6ba6ad commit 8286a89

File tree

5 files changed

+60
-16
lines changed

5 files changed

+60
-16
lines changed

pyccel/ast/cudatypes.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,6 @@ class CudaArrayType(HomogeneousContainerType, metaclass = ArgumentSingleton):
3232
"""
3333
__slots__ = ('_element_type', '_container_rank', '_order', '_memory_location')
3434

35-
# def __new__(cls, dtype, rank, order, memory_location):
36-
# if rank == 0:
37-
# return dtype
38-
# else:
39-
# return super().__new__(cls, dtype, rank, order)
4035
def __init__(self, dtype, rank, order, memory_location):
4136
assert isinstance(rank, int)
4237
assert order in (None, 'C', 'F')

pyccel/codegen/printing/ccode.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,8 +1315,7 @@ def get_declare_type(self, expr):
13151315
self.add_import(c_imports['ndarrays'])
13161316
dtype = 't_ndarray'
13171317
elif isinstance(expr.class_type, CudaArrayType):
1318-
self.add_import(c_imports['ndarrays'])
1319-
dtype = 't_ndarray'
1318+
dtype = 't_cuda_ndarray'
13201319

13211320
else:
13221321
errors.report(PYCCEL_RESTRICTION_TODO+' (rank>0)', symbol=expr, severity='fatal')

pyccel/codegen/printing/cucode.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,6 @@ def _print_Allocate(self, expr):
163163
else:
164164
memory_location = 'managedMemory'
165165
self.add_import(c_imports['cuda_ndarrays'])
166-
self.add_import(c_imports['ndarrays'])
167166
alloc_code = f"{self._print(expr.variable)} = cuda_array_create({variable.rank}, shape_Assign, {dtype}, {is_view},{memory_location});\n"
168167
return f'{shape_Assign} {alloc_code}'
169168

pyccel/stdlib/cuda_ndarrays/cuda_ndarrays.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ void host_memory(void** devPtr, size_t size)
1414
{
1515
cudaMallocHost(devPtr, size);
1616
}
17-
t_ndarray cuda_array_create(enum e_memory_locations location, int32_t nd, int64_t *shape,
17+
t_cuda_ndarray cuda_array_create(enum e_memory_locations location, int32_t nd, int64_t *shape,
1818
enum e_types type, bool is_view)
1919
{
20-
t_ndarray arr;
20+
t_cuda_ndarray arr;
2121
void (*fun_ptr_arr[])(void**, size_t) = {managed_memory, host_memory, device_memory};
2222

2323
arr.nd = nd;
@@ -61,7 +61,7 @@ t_ndarray cuda_array_create(enum e_memory_locations location, int32_t nd, int6
6161
return (arr);
6262
}
6363

64-
int32_t cuda_free_host(t_ndarray arr)
64+
int32_t cuda_free_host(t_cuda_ndarray arr)
6565
{
6666
if (arr.shape == NULL)
6767
return (0);
@@ -75,7 +75,7 @@ int32_t cuda_free_host(t_ndarray arr)
7575
}
7676

7777
__host__ __device__
78-
int32_t cuda_free(t_ndarray arr)
78+
int32_t cuda_free(t_cuda_ndarray arr)
7979
{
8080
if (arr.shape == NULL)
8181
return (0);
@@ -87,7 +87,7 @@ int32_t cuda_free(t_ndarray arr)
8787
}
8888

8989
__host__ __device__
90-
int32_t cuda_free_pointer(t_ndarray arr)
90+
int32_t cuda_free_pointer(t_cuda_ndarray arr)
9191
{
9292
if (arr.is_view == false || arr.shape == NULL)
9393
return (0);

pyccel/stdlib/cuda_ndarrays/cuda_ndarrays.h

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,62 @@
33

44
/*
 * Public interface for pyccel's CUDA ndarray container.
 * Declares the element-type / memory-location / order enums, the
 * t_cuda_ndarray descriptor struct, and the create/free entry points
 * implemented in cuda_ndarrays.cu.
 *
 * NOTE(review): relies on int32_t/int64_t/bool being in scope; only
 * <cuda_runtime.h> and <iostream> are included here — confirm they are
 * pulled in transitively, or add <cstdint> explicitly.
 */
# include <cuda_runtime.h>
# include <iostream>

/* Element-type tags for array contents.
 * NOTE(review): the values are non-sequential — presumably kept in sync
 * with pyccel's CPU-side ndarray type enum; confirm before changing. */
typedef enum e_types
{
    nd_bool = 0,
    nd_int8 = 1,
    nd_int16 = 3,
    nd_int32 = 5,
    nd_int64 = 7,
    nd_float = 11,
    nd_double = 12,
    nd_cfloat = 14,
    nd_cdouble = 15
} t_types;


/* Where the array's buffer lives; selects the allocator used by
 * cuda_array_create (managed / pinned-host / device memory). */
enum e_memory_locations
{
    managedMemory,
    allocateMemoryOnHost,
    allocateMemoryOnDevice
};

/* Element ordering of the array: Fortran (column-major) or C (row-major). */
typedef enum e_order
{
    order_f,
    order_c,
} t_order;

/* Descriptor for a CUDA-backed n-dimensional array. */
typedef struct s_cuda_ndarray
{
    /* raw pointer to the underlying buffer (location per e_memory_locations) */
    void *raw_data;
    /* number of dimensions */
    int32_t nd;
    /* shape 'size of each dimension' */
    int64_t *shape;
    /* strides 'number of elements to skip to get the next element' */
    int64_t *strides;
    /* type of the array elements */
    t_types type;
    /* type size of the array elements */
    int32_t type_size;
    /* number of elements in the array */
    int32_t length;
    /* size of the array */
    int32_t buffer_size;
    /* True if the array does not own the data */
    bool is_view;
    /* stores the order of the array: order_f or order_c */
    t_order order;
} t_cuda_ndarray;


/* Allocate and initialise a t_cuda_ndarray descriptor and its buffer.
 * NOTE(review): the definition in cuda_ndarrays.cu currently takes
 * `location` as its FIRST parameter, while this prototype (and the call
 * emitted by cucode.py's _print_Allocate) puts it LAST — the two must be
 * reconciled; confirm which order is intended. */
t_cuda_ndarray cuda_array_create(int32_t nd, int64_t *shape, enum e_types type, bool is_view ,
            enum e_memory_locations location);
/* Release a host-pinned array's buffer; returns 0 when there is nothing to free. */
int32_t cuda_free_host(t_cuda_ndarray arr);



/* NOTE(review): `using namespace std;` at header scope leaks the std
 * namespace into every includer — consider moving it into the .cu file. */
using namespace std;

0 commit comments

Comments
 (0)