Skip to content

Commit 2ec91c4

Browse files
authored
Add Fortran support for dict initialisation (pyccel#2005)
Add Fortran support for dict initialisation in a similar way to lists. Fixes pyccel#1944
1 parent 0f25459 commit 2ec91c4

File tree

4 files changed

+126
-11
lines changed

4 files changed

+126
-11
lines changed

pyccel/ast/builtins.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,9 @@ def __repr__(self):
908908
args = ', '.join(f'{repr(k)}: {repr(v)}' for k,v in self)
909909
return f'PythonDict({args})'
910910

911+
def __len__(self):
912+
return len(self._keys)
913+
911914
@property
912915
def keys(self):
913916
"""

pyccel/ast/low_level_tools.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .datatypes import PyccelType
1212

1313
__all__ = ('IteratorType',
14+
'PairType',
1415
'MacroDefinition')
1516

1617
#------------------------------------------------------------------------------
@@ -40,6 +41,51 @@ def iterable_type(self):
4041
"""
4142
return self._iterable_type
4243

44+
#------------------------------------------------------------------------------
45+
class PairType(PyccelType, metaclass=ArgumentSingleton):
46+
"""
47+
The type of an element of a dictionary type.
48+
49+
The type of an element of a dictionary type.
50+
51+
Parameters
52+
----------
53+
key_type : PyccelType
54+
The type of the keys of the homogeneous dictionary.
55+
value_type : PyccelType
56+
The type of the values of the homogeneous dictionary.
57+
"""
58+
__slots__ = ('_key_type', '_value_type')
59+
_name = 'pair'
60+
_container_rank = 0
61+
_order = None
62+
63+
def __init__(self, key_type, value_type):
64+
self._key_type = key_type
65+
self._value_type = value_type
66+
super().__init__()
67+
68+
@property
69+
def key_type(self):
70+
"""
71+
The type of the keys of the object.
72+
73+
The type of the keys of the object.
74+
"""
75+
return self._key_type
76+
77+
@property
78+
def value_type(self):
79+
"""
80+
The type of the values of the object.
81+
82+
The type of the values of the object.
83+
"""
84+
return self._value_type
85+
86+
def __str__(self):
87+
return f'pair[{self._key_type}, {self._value_type}]'
88+
4389
#------------------------------------------------------------------------------
4490
class MacroDefinition(PyccelAstNode):
4591
"""

pyccel/codegen/printing/fcode.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
from pyccel.ast.literals import LiteralTrue, LiteralFalse, LiteralString
5252
from pyccel.ast.literals import Nil
5353

54-
from pyccel.ast.low_level_tools import MacroDefinition, IteratorType
54+
from pyccel.ast.low_level_tools import MacroDefinition, IteratorType, PairType
5555

5656
from pyccel.ast.mathext import math_constants
5757

@@ -584,6 +584,37 @@ def _build_gFTL_module(self, expr_type):
584584
raise NotImplementedError("Support for sets of types which define their own < operator is not yet implemented")
585585
imports_and_macros.append(MacroDefinition('Set', expr_type))
586586
imports_and_macros.append(MacroDefinition('SetIterator', IteratorType(expr_type)))
587+
elif isinstance(expr_type, DictType):
588+
include = Import(LiteralString('map/template.inc'), Module('_', (), ()))
589+
key_type = expr_type.key_type
590+
value_type = expr_type.value_type
591+
imports_and_macros = []
592+
if isinstance(key_type, FixedSizeNumericType):
593+
tmpVar_x = Variable(key_type, 'x')
594+
tmpVar_y = Variable(key_type, 'y')
595+
if isinstance(key_type.primitive_type, PrimitiveComplexType):
596+
complex_tool_import = Import('pyc_tools_f90', Module('pyc_tools_f90',(),()))
597+
self.add_import(complex_tool_import)
598+
imports_and_macros.append(complex_tool_import)
599+
compare_func = FunctionDef('complex_comparison',
600+
[FunctionDefArgument(tmpVar_x), FunctionDefArgument(tmpVar_y)],
601+
[FunctionDefResult(Variable(PythonNativeBool(), 'c'))], [])
602+
lt_def = FunctionCall(compare_func, [tmpVar_x, tmpVar_y])
603+
else:
604+
lt_def = PyccelAssociativeParenthesis(PyccelLt(tmpVar_x, tmpVar_y))
605+
imports_and_macros.extend([MacroDefinition('Key', key_type.primitive_type),
606+
MacroDefinition('Key_KINDLEN(context)', KindSpecification(key_type)),
607+
MacroDefinition('Key_LT(x,y)', lt_def)])
608+
else:
609+
raise NotImplementedError("Support for dicts whose keys define their own < operator is not yet implemented")
610+
if isinstance(value_type, FixedSizeNumericType):
611+
imports_and_macros.extend([MacroDefinition('T', value_type.primitive_type),
612+
MacroDefinition('T_KINDLEN(context)', KindSpecification(value_type))])
613+
else:
614+
raise NotImplementedError(f"Support for dictionary values of type {value_type} not yet implemented")
615+
imports_and_macros.append(MacroDefinition('Pair', PairType(key_type, value_type)))
616+
imports_and_macros.append(MacroDefinition('Map', expr_type))
617+
imports_and_macros.append(MacroDefinition('MapIterator', IteratorType(expr_type)))
587618
else:
588619
raise NotImplementedError(f"Unkown gFTL import for type {expr_type}")
589620

@@ -1130,6 +1161,24 @@ def _print_PythonSet(self, expr):
11301161
set_type = self._print(expr.class_type)
11311162
return f'{set_type}({list_arg})'
11321163

1164+
def _print_PythonDict(self, expr):
1165+
if len(expr) == 0:
1166+
list_arg = ''
1167+
assign = expr.get_direct_user_nodes(lambda a : isinstance(a, Assign))
1168+
if assign:
1169+
dict_type = self._print(assign[0].lhs.class_type)
1170+
else:
1171+
raise errors.report("Can't use an empty dict without assigning it to a variable as the type cannot be deduced",
1172+
severity='fatal', symbol=expr)
1173+
1174+
else:
1175+
class_type = expr.class_type
1176+
pair_type = self._print(PairType(class_type.key_type, class_type.value_type))
1177+
args = ', '.join(f'{pair_type}({self._print(k)}, {self._print(v)})' for k,v in expr)
1178+
list_arg = f'[{args}]'
1179+
dict_type = self._print(class_type)
1180+
return f'{dict_type}({list_arg})'
1181+
11331182
def _print_InhomogeneousTupleVariable(self, expr):
11341183
fs = ', '.join(self._print(f) for f in expr)
11351184
return '[{0}]'.format(fs)
@@ -1988,7 +2037,7 @@ def _print_Allocate(self, expr):
19882037

19892038
return code
19902039

1991-
elif isinstance(class_type, HomogeneousContainerType):
2040+
elif isinstance(class_type, (HomogeneousContainerType, DictType)):
19922041
return ''
19932042

19942043
else:
@@ -2006,7 +2055,7 @@ def _print_Deallocate(self, expr):
20062055
Pyccel_del_args = [FunctionCallArgument(var)]
20072056
return self._print(FunctionCall(Pyccel__del, Pyccel_del_args))
20082057

2009-
if var.is_alias or isinstance(class_type, (HomogeneousListType, HomogeneousSetType)):
2058+
if var.is_alias or isinstance(class_type, (HomogeneousListType, HomogeneousSetType, DictType)):
20102059
return ''
20112060
elif isinstance(class_type, (NumpyNDArrayType, HomogeneousTupleType, StringType)):
20122061
var_code = self._print(var)
@@ -2053,9 +2102,16 @@ def _print_HomogeneousListType(self, expr):
20532102
def _print_HomogeneousSetType(self, expr):
20542103
return 'Set_'+self._print(expr.element_type)
20552104

2105+
def _print_PairType(self, expr):
2106+
return 'Pair_'+self._print(expr.key_type)+'__'+self._print(expr.value_type)
2107+
2108+
def _print_DictType(self, expr):
2109+
return 'Map_'+self._print(expr.key_type)+'__'+self._print(expr.value_type)
2110+
20562111
def _print_IteratorType(self, expr):
20572112
iterable_type = self._print(expr.iterable_type)
20582113
return f"{iterable_type}_Iterator"
2114+
20592115
def _print_DataType(self, expr):
20602116
return self._print(expr.name)
20612117

tests/epyccel/test_epyccel_variable_annotations.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -301,22 +301,32 @@ def homogeneous_list_annotation():
301301
assert epyc_homogeneous_list_annotation() == homogeneous_list_annotation()
302302
assert isinstance(epyc_homogeneous_list_annotation(), type(homogeneous_list_annotation()))
303303

304-
def test_dict_int_float(stc_language):
304+
def test_dict_int_float(language):
305305
def dict_int_float():
306306
# Not valid in Python 3.8
307307
a : dict[int, float] #pylint: disable=unsubscriptable-object,unused-variable
308308
a = {1:1.0, 2:2.0}
309+
return len(a)
309310

310-
epyc_dict_int_float = epyccel(dict_int_float, language = stc_language)
311-
epyc_dict_int_float()
312-
dict_int_float()
311+
epyc_dict_int_float = epyccel(dict_int_float, language = language)
312+
assert epyc_dict_int_float() == dict_int_float()
313313

314-
def test_dict_empty_init(stc_language):
314+
def test_dict_empty_init(language):
315315
def dict_empty_init():
316316
# Not valid in Python 3.8
317317
a : dict[int, float] #pylint: disable=unsubscriptable-object,unused-variable
318318
a = {}
319+
return len(a)
320+
321+
epyc_dict_empty_init = epyccel(dict_empty_init, language = language)
322+
assert epyc_dict_empty_init() == dict_empty_init()
323+
324+
def test_dict_complex_float(language):
325+
def dict_int_float():
326+
# Not valid in Python 3.8
327+
a : dict[complex, float] #pylint: disable=unsubscriptable-object,unused-variable
328+
a = {1j:1.0, -1j:2.0}
329+
return len(a)
319330

320-
epyc_dict_empty_init = epyccel(dict_empty_init, language = stc_language)
321-
epyc_dict_empty_init()
322-
dict_empty_init()
331+
epyc_dict_int_float = epyccel(dict_int_float, language = language)
332+
assert epyc_dict_int_float() == dict_int_float()

0 commit comments

Comments
 (0)