@@ -15,11 +15,10 @@
 from pyccel.ast.literals import Nil
 
 from pyccel.errors.errors import Errors
-from pyccel.ast.core import Allocate, Deallocate
-from pyccel.ast.numpytypes import NumpyInt64Type
 from pyccel.ast.cudatypes import CudaArrayType
 from pyccel.ast.datatypes import HomogeneousContainerType
-from pyccel.ast.numpytypes import NumpyNDArrayType, numpy_precision_map
+from pyccel.ast.numpytypes import numpy_precision_map
+from pyccel.ast.cudaext import CudaFull
 
 
 
@@ -147,14 +146,15 @@ def _print_ModuleHeader(self, expr):
             "#endif // {name.upper()}_H\n"))
     def _print_Allocate(self, expr):
         variable = expr.variable
+        if not isinstance(variable.class_type, CudaArrayType):
+            return super()._print_Allocate(expr)
         shape = ", ".join(self._print(i) for i in expr.shape)
         if isinstance(variable.class_type, CudaArrayType):
             dtype = self.find_in_ndarray_type_registry(variable.dtype)
         elif isinstance(variable.class_type, HomogeneousContainerType):
             dtype = self.find_in_ndarray_type_registry(numpy_precision_map[(variable.dtype.primitive_type, variable.dtype.precision)])
         else:
             raise NotImplementedError(f"Don't know how to index {variable.class_type} type")
-        shape_dtype = self.get_c_type(NumpyInt64Type())
         shape_Assign = "int64_t shape_Assign[] = {" + shape + "};\n"
         is_view = 'false' if variable.on_heap else 'true'
         memory_location = expr.variable.memory_location
@@ -169,8 +169,19 @@ def _print_Allocate(self, expr):
     def _print_Deallocate(self, expr):
         var_code = self._print(expr.variable)
 
+        if not isinstance(expr.variable.class_type, CudaArrayType):
+            return super()._print_Deallocate(expr)
+
         if expr.variable.memory_location == 'host':
             return f"cuda_free_host({var_code});\n"
         else:
             return f"cuda_free({var_code});\n"
 
+    def _print_Assign(self, expr):
+        rhs = expr.rhs
+        if not isinstance(rhs.class_type, CudaArrayType):
+            return super()._print_Assign(expr)
+        if isinstance(rhs, CudaFull):
+            # TODO add support for CudaFull
+            return "\n"
+
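The recurring pattern in this diff is that each CUDA printer method handles only CudaArrayType and defers every other type to the parent C printer via super(). Below is a minimal, self-contained sketch of that delegation; the stub classes (Variable, Allocate, CCodePrinter, and the emitted comment strings) are hypothetical stand-ins, not pyccel's real API.

# Sketch of the type-based delegation used in _print_Allocate,
# _print_Deallocate and _print_Assign above. Stub names are assumptions.

class CudaArrayType:
    """Stub standing in for pyccel.ast.cudatypes.CudaArrayType."""

class Variable:
    def __init__(self, class_type):
        self.class_type = class_type

class Allocate:
    def __init__(self, variable):
        self.variable = variable

class CCodePrinter:
    def _print_Allocate(self, expr):
        # Parent printer: plain C allocation code.
        return "/* C allocation */\n"

class CudaCodePrinter(CCodePrinter):
    def _print_Allocate(self, expr):
        if not isinstance(expr.variable.class_type, CudaArrayType):
            # Non-CUDA arrays keep the existing C code generation.
            return super()._print_Allocate(expr)
        # CUDA arrays get device-aware allocation code instead.
        return "/* CUDA allocation */\n"

printer = CudaCodePrinter()
print(printer._print_Allocate(Allocate(Variable(CudaArrayType()))))  # CUDA path
print(printer._print_Allocate(Allocate(Variable(object()))))         # falls back to C

Keeping the fallback as the first statement of each overridden method means the CUDA printer never has to re-implement the C printer's logic for ordinary types.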