
Commit 96c3f29

work in progress
1 parent 8286a89 commit 96c3f29

3 files changed: +15 −8 lines changed

pyccel/ast/numpyext.py

Lines changed: 0 additions & 1 deletion
@@ -626,7 +626,6 @@ def __init__(self, *args, class_type, init_dtype = None):
         assert isinstance(class_type, NumpyNDArrayType)
         self._init_dtype = init_dtype
         self._class_type = class_type # pylint: disable=no-member
-        print(*args)
         super().__init__(*args)

     @property

pyccel/codegen/printing/ccode.py

Lines changed: 0 additions & 3 deletions
@@ -2182,9 +2182,6 @@ def _print_Assign(self, expr):
         # Inhomogenous tuples are unravelled and therefore do not exist in the c printer
         if isinstance(rhs, (NumpyArray, PythonTuple)):
             return prefix_code+self.copy_NumpyArray_Data(expr)
-        if(isinstance(rhs, (CudaFull))):
-            # TODO add support for CudaFull
-            return " \n"
         if isinstance(rhs, (NumpyFull)):
             return prefix_code+self.arrayFill(expr)
         lhs = self._print(expr.lhs)
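Note: this hunk drops the CudaFull special case from the generic C printer; the cucode.py hunks below move it into the CUDA printer, which now defers every non-CUDA node back to this class via super().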

pyccel/codegen/printing/cucode.py

Lines changed: 15 additions & 4 deletions
@@ -15,11 +15,10 @@
 from pyccel.ast.literals import Nil

 from pyccel.errors.errors import Errors
-from pyccel.ast.core import Allocate, Deallocate
-from pyccel.ast.numpytypes import NumpyInt64Type
 from pyccel.ast.cudatypes import CudaArrayType
 from pyccel.ast.datatypes import HomogeneousContainerType
-from pyccel.ast.numpytypes import NumpyNDArrayType, numpy_precision_map
+from pyccel.ast.numpytypes import numpy_precision_map
+from pyccel.ast.cudaext import CudaFull



@@ -147,14 +146,15 @@ def _print_ModuleHeader(self, expr):
                 "#endif // {name.upper()}_H\n"))
     def _print_Allocate(self, expr):
         variable = expr.variable
+        if not isinstance(variable.class_type, CudaArrayType):
+            return super()._print_Allocate(expr)
         shape = ", ".join(self._print(i) for i in expr.shape)
         if isinstance(variable.class_type, CudaArrayType):
             dtype = self.find_in_ndarray_type_registry(variable.dtype)
         elif isinstance(variable.class_type, HomogeneousContainerType):
             dtype = self.find_in_ndarray_type_registry(numpy_precision_map[(variable.dtype.primitive_type, variable.dtype.precision)])
         else:
             raise NotImplementedError(f"Don't know how to index {variable.class_type} type")
-        shape_dtype = self.get_c_type(NumpyInt64Type())
         shape_Assign = "int64_t shape_Assign [] = {" + shape + "};\n"
         is_view = 'false' if variable.on_heap else 'true'
         memory_location = expr.variable.memory_location
@@ -169,8 +169,19 @@
     def _print_Deallocate(self, expr):
         var_code = self._print(expr.variable)

+        if not isinstance(expr.variable.class_type, CudaArrayType):
+            return super()._print_Deallocate(expr)
+
         if expr.variable.memory_location == 'host':
             return f"cuda_free_host({var_code});\n"
         else:
             return f"cuda_free({var_code});\n"

+    def _print_Assign(self, expr):
+        rhs = expr.rhs
+        if not isinstance(rhs.class_type, CudaArrayType):
+            return super()._print_Assign(expr)
+        if(isinstance(rhs, (CudaFull))):
+            # TODO add support for CudaFull
+            return " \n"
+
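The recurring pattern in these hunks is guard-and-delegate: each overridden _print_* method first checks whether the node involves a CudaArrayType and otherwise hands the work back to the parent C printer through super(). Below is a minimal, self-contained sketch of that pattern; the Var and Deallocate stand-ins and both printer classes are simplified for illustration and are not pyccel's real node or printer APIs.

# Minimal sketch of the guard-and-delegate pattern used in cucode.py
# (hypothetical stand-in classes, not pyccel's actual API).
from dataclasses import dataclass

class CudaArrayType:
    """Stand-in marker for device array types."""

@dataclass
class Var:
    """Stand-in for a pyccel variable node."""
    name: str
    class_type: object
    memory_location: str = 'device'

@dataclass
class Deallocate:
    """Stand-in for pyccel.ast.core.Deallocate."""
    variable: Var

class CCodePrinter:
    """Generic C printer: knows nothing about CUDA."""
    def _print_Deallocate(self, expr):
        return f"free({expr.variable.name});\n"

class CudaCodePrinter(CCodePrinter):
    """Handles CUDA arrays only; everything else falls back to super()."""
    def _print_Deallocate(self, expr):
        if not isinstance(expr.variable.class_type, CudaArrayType):
            return super()._print_Deallocate(expr)  # plain C variable
        if expr.variable.memory_location == 'host':
            return f"cuda_free_host({expr.variable.name});\n"
        return f"cuda_free({expr.variable.name});\n"

printer = CudaCodePrinter()
print(printer._print_Deallocate(Deallocate(Var('a', object()))), end='')         # free(a);
print(printer._print_Deallocate(Deallocate(Var('b', CudaArrayType()))), end='')  # cuda_free(b);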
