|
11 | 11 |
|
12 | 12 | from pyccel.ast.builtins import PythonRange, PythonComplex
|
13 | 13 | from pyccel.ast.builtins import PythonPrint, PythonType
|
14 |
| -from pyccel.ast.builtins import PythonList, PythonTuple, PythonSet |
| 14 | +from pyccel.ast.builtins import PythonList, PythonTuple, PythonSet, PythonDict |
15 | 15 |
|
16 | 16 | from pyccel.ast.core import Declare, For, CodeBlock
|
17 | 17 | from pyccel.ast.core import FuncAddressDeclare, FunctionCall, FunctionCallArgument
|
|
29 | 29 | from pyccel.ast.datatypes import TupleType, FixedSizeNumericType
|
30 | 30 | from pyccel.ast.datatypes import CustomDataType, StringType, HomogeneousTupleType, HomogeneousListType, HomogeneousSetType
|
31 | 31 | from pyccel.ast.datatypes import PrimitiveBooleanType, PrimitiveIntegerType, PrimitiveFloatingPointType, PrimitiveComplexType
|
32 |
| -from pyccel.ast.datatypes import HomogeneousContainerType |
| 32 | +from pyccel.ast.datatypes import HomogeneousContainerType, DictType |
33 | 33 |
|
34 | 34 | from pyccel.ast.internals import Slice, PrecomputedCode, PyccelArrayShapeElement
|
35 | 35 |
|
@@ -678,8 +678,13 @@ def init_stc_container(self, expr, assignment_var):
|
678 | 678 | The generated C code for the container initialization.
|
679 | 679 | """
|
680 | 680 |
|
681 |
| - dtype = self.get_c_type(assignment_var.lhs.class_type) |
682 |
| - keyraw = '{' + ', '.join(self._print(a) for a in expr.args) + '}' |
| 681 | + class_type = assignment_var.lhs.class_type |
| 682 | + dtype = self.get_c_type(class_type) |
| 683 | + if isinstance(expr, PythonDict): |
| 684 | + dict_item_strs = [(self._print(k), self._print(v)) for k,v in zip(expr.keys, expr.values)] |
| 685 | + keyraw = '{' + ', '.join(f'{{{k}, {v}}}' for k,v in dict_item_strs) + '}' |
| 686 | + else: |
| 687 | + keyraw = '{' + ', '.join(self._print(a) for a in expr.args) + '}' |
683 | 688 | container_name = self._print(assignment_var.lhs)
|
684 | 689 | init = f'{container_name} = c_init({dtype}, {keyraw});\n'
|
685 | 690 | return init
|
@@ -1013,15 +1018,22 @@ def _print_Import(self, expr):
|
1013 | 1018 | for t in expr.target:
|
1014 | 1019 | dtype = t.object.class_type
|
1015 | 1020 | container_type = t.target
|
1016 |
| - container_key = self.get_c_type(dtype.element_type) |
| 1021 | + if isinstance(dtype, DictType): |
| 1022 | + container_key_key = self.get_c_type(dtype.key_type) |
| 1023 | + container_val_key = self.get_c_type(dtype.value_type) |
| 1024 | + container_key = f'{container_key_key}_{container_val_key}' |
| 1025 | + element_decl = f'#define i_key {container_key_key}\n#define i_val {container_val_key}\n' |
| 1026 | + else: |
| 1027 | + container_key = self.get_c_type(dtype.element_type) |
| 1028 | + element_decl = f'#define i_key {container_key}\n' |
1017 | 1029 | header_guard_prefix = import_header_guard_prefix.get(source, '')
|
1018 | 1030 | header_guard = f'{header_guard_prefix}_{container_type.upper()}'
|
1019 |
| - code += (f'#ifndef {header_guard}\n' |
1020 |
| - f'#define {header_guard}\n' |
1021 |
| - f'#define i_type {container_type}\n' |
1022 |
| - f'#define i_key {container_key}\n' |
1023 |
| - f'#include <{source}.h>\n' |
1024 |
| - f'#endif // {header_guard}\n\n') |
| 1031 | + code += ''.join((f'#ifndef {header_guard}\n', |
| 1032 | + f'#define {header_guard}\n', |
| 1033 | + f'#define i_type {container_type}\n', |
| 1034 | + element_decl, |
| 1035 | + f'#include <{source}.h>\n', |
| 1036 | + f'#endif // {header_guard}\n\n')) |
1025 | 1037 | return code
|
1026 | 1038 | # Get with a default value is not used here as it is
|
1027 | 1039 | # slower and on most occasions the import will not be in the
|
@@ -1245,6 +1257,13 @@ def get_c_type(self, dtype):
|
1245 | 1257 | i_type = f'{container_type}_{element_type}'
|
1246 | 1258 | self.add_import(Import(f'stc/{container_type}', AsName(VariableTypeAnnotation(dtype), i_type)))
|
1247 | 1259 | return i_type
|
| 1260 | + elif isinstance(dtype, DictType): |
| 1261 | + container_type = 'hmap' |
| 1262 | + key_type = self.get_c_type(dtype.key_type).replace(' ', '_') |
| 1263 | + val_type = self.get_c_type(dtype.value_type).replace(' ', '_') |
| 1264 | + i_type = f'{container_type}_{key_type}_{val_type}' |
| 1265 | + self.add_import(Import(f'stc/{container_type}', AsName(VariableTypeAnnotation(dtype), i_type))) |
| 1266 | + return i_type |
1248 | 1267 | else:
|
1249 | 1268 | key = dtype
|
1250 | 1269 |
|
@@ -1320,7 +1339,7 @@ def get_declare_type(self, expr):
|
1320 | 1339 | rank = expr.rank
|
1321 | 1340 |
|
1322 | 1341 | if rank > 0:
|
1323 |
| - if isinstance(expr.class_type, (HomogeneousSetType, HomogeneousListType)): |
| 1342 | + if isinstance(expr.class_type, (HomogeneousSetType, HomogeneousListType, DictType)): |
1324 | 1343 | dtype = self.get_c_type(expr.class_type)
|
1325 | 1344 | return dtype
|
1326 | 1345 | if isinstance(expr.class_type,(HomogeneousTupleType, NumpyNDArrayType)):
|
@@ -1610,7 +1629,7 @@ def _print_PyccelArrayShapeElement(self, expr):
|
1610 | 1629 | def _print_Allocate(self, expr):
|
1611 | 1630 | free_code = ''
|
1612 | 1631 | variable = expr.variable
|
1613 |
| - if isinstance(variable.class_type, (HomogeneousListType, HomogeneousSetType)): |
| 1632 | + if isinstance(variable.class_type, (HomogeneousListType, HomogeneousSetType, DictType)): |
1614 | 1633 | return ''
|
1615 | 1634 | if variable.rank > 0:
|
1616 | 1635 | #free the array if its already allocated and checking if its not null if the status is unknown
|
@@ -1646,7 +1665,7 @@ def _print_Allocate(self, expr):
|
1646 | 1665 | raise NotImplementedError(f"Allocate not implemented for {variable}")
|
1647 | 1666 |
|
1648 | 1667 | def _print_Deallocate(self, expr):
|
1649 |
| - if isinstance(expr.variable.class_type, (HomogeneousListType, HomogeneousSetType)): |
| 1668 | + if isinstance(expr.variable.class_type, (HomogeneousListType, HomogeneousSetType, DictType)): |
1650 | 1669 | variable_address = self._print(ObjectAddress(expr.variable))
|
1651 | 1670 | container_type = self.get_c_type(expr.variable.class_type)
|
1652 | 1671 | return f'{container_type}_drop({variable_address});\n'
|
@@ -2196,7 +2215,7 @@ def _print_Assign(self, expr):
|
2196 | 2215 | if isinstance(rhs, (NumpyFull)):
|
2197 | 2216 | return prefix_code+self.arrayFill(expr)
|
2198 | 2217 | lhs = self._print(expr.lhs)
|
2199 |
| - if isinstance(rhs, (PythonList, PythonSet)): |
| 2218 | + if isinstance(rhs, (PythonList, PythonSet, PythonDict)): |
2200 | 2219 | return prefix_code+self.init_stc_container(rhs, expr)
|
2201 | 2220 | rhs = self._print(expr.rhs)
|
2202 | 2221 | return prefix_code+'{} = {};\n'.format(lhs, rhs)
|
|
0 commit comments