|
13 | 13 | from pyccel.ast.bind_c import CLocFunc, BindCModule, BindCVariable
|
14 | 14 | from pyccel.ast.bind_c import BindCArrayVariable, BindCClassDef, DeallocatePointer
|
15 | 15 | from pyccel.ast.bind_c import BindCClassProperty
|
| 16 | +from pyccel.ast.builtins import VariableIterator |
16 | 17 | from pyccel.ast.core import Assign, FunctionCall, FunctionCallArgument
|
17 | 18 | from pyccel.ast.core import Allocate, EmptyNode, FunctionAddress
|
18 | 19 | from pyccel.ast.core import If, IfSection, Import, Interface, FunctionDefArgument
|
19 | 20 | from pyccel.ast.core import AsName, Module, AliasAssign, FunctionDefResult
|
| 21 | +from pyccel.ast.core import For |
20 | 22 | from pyccel.ast.datatypes import CustomDataType, FixedSizeNumericType
|
21 | 23 | from pyccel.ast.datatypes import HomogeneousTupleType, TupleType
|
| 24 | +from pyccel.ast.datatypes import HomogeneousSetType, PythonNativeInt |
| 25 | +from pyccel.ast.datatypes import HomogeneousListType |
22 | 26 | from pyccel.ast.internals import Slice
|
23 | 27 | from pyccel.ast.literals import LiteralInteger, Nil, LiteralTrue
|
24 | 28 | from pyccel.ast.numpytypes import NumpyNDArrayType
|
25 |
| -from pyccel.ast.operators import PyccelIsNot, PyccelMul |
| 29 | +from pyccel.ast.operators import PyccelIsNot, PyccelMul, PyccelAdd |
26 | 30 | from pyccel.ast.variable import Variable, IndexedElement, DottedVariable
|
27 | 31 | from pyccel.ast.numpyext import NumpyNDArrayType
|
28 | 32 | from pyccel.errors.errors import Errors
|
@@ -419,14 +423,33 @@ def _wrap_FunctionDefResult(self, expr):
|
419 | 423 |
|
420 | 424 | if not (var.is_alias or wrap_dotted):
|
421 | 425 | # Create an array variable which can be passed to CLocFunc
|
422 |
| - ptr_var = var.clone(scope.get_new_name(name+'_ptr'), |
| 426 | + ptr_var = Variable(NumpyNDArrayType(var.dtype, var.rank, var.order), scope.get_new_name(name+'_ptr'), |
423 | 427 | memory_handling='alias')
|
424 | 428 | scope.insert_variable(ptr_var)
|
425 | 429 |
|
426 | 430 | # Define the additional steps necessary to define and fill ptr_var
|
427 | 431 | alloc = Allocate(ptr_var, shape=result.shape, status='unallocated')
|
428 |
| - copy = Assign(ptr_var, local_var) |
429 |
| - self._additional_exprs.extend([alloc, copy]) |
| 432 | + if isinstance(local_var.class_type, (NumpyNDArrayType, HomogeneousTupleType, CustomDataType)): |
| 433 | + copy = Assign(ptr_var, local_var) |
| 434 | + self._additional_exprs.extend([alloc, copy]) |
| 435 | + elif isinstance(local_var.class_type, (HomogeneousSetType, HomogeneousListType)): |
| 436 | + iterator = VariableIterator(local_var) |
| 437 | + elem = Variable(var.class_type.element_type, self.scope.get_new_name()) |
| 438 | + idx = Variable(PythonNativeInt(), self.scope.get_new_name()) |
| 439 | + self.scope.insert_variable(elem) |
| 440 | + assign = Assign(idx, LiteralInteger(0)) |
| 441 | + for_scope = self.scope.create_new_loop_scope() |
| 442 | + for_body = [Assign(IndexedElement(ptr_var, idx), elem)] |
| 443 | + if isinstance(local_var.class_type, HomogeneousSetType): |
| 444 | + self.scope.insert_variable(idx) |
| 445 | + for_body.append(Assign(idx, PyccelAdd(idx, LiteralInteger(1)))) |
| 446 | + else: |
| 447 | + iterator.set_loop_counter(idx) |
| 448 | + fill_for = For((elem,), iterator, for_body, scope = for_scope) |
| 449 | + self._additional_exprs.extend([alloc, assign, fill_for]) |
| 450 | + else: |
| 451 | + raise errors.report(f"Don't know how to return an object of type {local_var.class_type} to C code.", |
| 452 | + severity='fatal', symbol = var) |
430 | 453 | else:
|
431 | 454 | ptr_var = var
|
432 | 455 |
|
|
0 commit comments