Skip to content

Commit 190c5a2

Browse files
committed
work in progress
1 parent 3afad1b commit 190c5a2

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

pyccel/codegen/printing/ccode.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from pyccel.ast.numpytypes import NumpyFloat32Type, NumpyFloat64Type, NumpyComplex64Type, NumpyComplex128Type
4747
from pyccel.ast.numpytypes import NumpyNDArrayType, numpy_precision_map
4848
from pyccel.ast.cudatypes import CudaArrayType
49+
from pyccel.ast.cudaext import CudaFull
4950

5051
from pyccel.ast.utilities import expand_to_loops
5152

@@ -59,6 +60,7 @@
5960

6061
from pyccel.codegen.printing.codeprinter import CodePrinter
6162

63+
6264
from pyccel.errors.errors import Errors
6365
from pyccel.errors.messages import (PYCCEL_RESTRICTION_TODO, INCOMPATIBLE_TYPEVAR_TO_FUNC,
6466
PYCCEL_RESTRICTION_IS_ISNOT, UNSUPPORTED_ARRAY_RANK)
@@ -2181,6 +2183,9 @@ def _print_Assign(self, expr):
21812183
# Inhomogenous tuples are unravelled and therefore do not exist in the c printer
21822184
if isinstance(rhs, (NumpyArray, PythonTuple)):
21832185
return prefix_code+self.copy_NumpyArray_Data(expr)
2186+
if(isinstance(rhs, (CudaFull))):
2187+
# TODO add support for CudaFull
2188+
return " \n"
21842189
if isinstance(rhs, (NumpyFull)):
21852190
return prefix_code+self.arrayFill(expr)
21862191
lhs = self._print(expr.lhs)

pyccel/codegen/printing/cucode.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,7 @@ def _print_CudaSynchronize(self, expr):
116116
return 'cudaDeviceSynchronize();\n'
117117

118118
def _print_CudaEmpty(self, expr):
119-
print(expr)
120-
return 'cudaDeviceSynchronize();\n'
119+
return 'cuda_array_create(1, (int64_t[]){INT64_C(10)}, nd_double, false,allocateMemoryOnHost);\n'
121120
def _print_ModuleHeader(self, expr):
122121
self.set_scope(expr.module.scope)
123122
self._in_header = True
@@ -158,7 +157,7 @@ def _print_Allocate(self, expr):
158157
else:
159158
raise NotImplementedError(f"Don't know how to index {variable.class_type} type")
160159
shape_dtype = self.get_c_type(NumpyInt64Type())
161-
shape_Assign = "("+ shape_dtype +"[]){" + shape + "}"
160+
shape_Assign = "int64_t shape_Assign [] = {" + shape + "};\n"
162161
is_view = 'false' if variable.on_heap else 'true'
163162
memory_location = expr.variable.memory_location
164163
if memory_location in ('device', 'host'):
@@ -167,8 +166,8 @@ def _print_Allocate(self, expr):
167166
memory_location = 'managedMemory'
168167
self.add_import(c_imports['cuda_ndarrays'])
169168
self.add_import(c_imports['ndarrays'])
170-
alloc_code = f"{self._print(expr.variable)} = cuda_array_create({variable.rank}, {shape_Assign}, {dtype}, {is_view},{memory_location});\n"
171-
return f'{alloc_code}'
169+
alloc_code = f"{self._print(expr.variable)} = cuda_array_create({variable.rank}, shape_Assign, {dtype}, {is_view},{memory_location});\n"
170+
return f'{shape_Assign} {alloc_code}'
172171

173172
def _print_Deallocate(self, expr):
174173
var_code = self._print(expr.variable)

pyccel/stdlib/cuda_ndarrays/cuda_ndarrays.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
# include <iostream>
66
#include "../ndarrays/ndarrays.h"
77

8+
t_ndarray cuda_array_create(int32_t nd, int64_t *shape, enum e_types type, bool is_view ,
9+
enum e_memory_locations location);
10+
int32_t cuda_free_host(t_ndarray arr);
11+
12+
813
using namespace std;
914

1015

Comments (0)