Skip to content

Commit dfe48c7

Browse files
authored
[DICT] Add Python support for dict initialisation (pyccel#1896)
Add support for the initialisation of a dictionary via a call to `{}`. Syntactic, semantic and Python printing support is added. The `DictType` datatype is expanded to match the expected description from the docs. Fixes pyccel#1895
1 parent 267662e commit dfe48c7

File tree

12 files changed

+209
-25
lines changed

12 files changed

+209
-25
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ All notable changes to this project will be documented in this file.
2323
- #1659 : Add the appropriate C language equivalent for declaring a Python `set` container using the STC library.
2424
- #1893 : Add Python support for set initialisation with `set()`.
2525
- #1877 : Add C Support for set method `pop()`.
26+
- #1895 : Add Python support for dict initialisation with `{}`.
2627
- \[INTERNALS\] Added `container_rank` property to `ast.datatypes.PyccelType` objects.
2728
- \[DEVELOPER\] Added an improved traceback to the developer-mode errors for errors in function calls.
2829

ci_tools/check_pylint_commands.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
re.compile('tests/codegen/fcode/scripts/precision.py'):['unused-variable'],
2121
re.compile('tests/semantic/scripts/expressions.py'):['unused-variable'],
2222
re.compile('tests/semantic/scripts/calls.py'):['unused-variable'],
23-
re.compile('tests/pyccel/project_class_imports/.*'):['relative-beyond-top-level'] # ignore Codacy bad pylint call
23+
re.compile('tests/pyccel/project_class_imports/.*'):['relative-beyond-top-level'], # ignore Codacy bad pylint call
24+
re.compile('tests/errors/syntax_errors/import_star.py'):['wildcard-import']
2425
}
2526

2627
def run_pylint(file, flag, messages):

pyccel/ast/builtins.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .datatypes import HomogeneousTupleType, InhomogeneousTupleType
2121
from .datatypes import HomogeneousListType, HomogeneousContainerType
2222
from .datatypes import FixedSizeNumericType, HomogeneousSetType, SymbolicType
23+
from .datatypes import DictType
2324
from .internals import PyccelFunction, Slice, PyccelArrayShapeElement
2425
from .literals import LiteralInteger, LiteralFloat, LiteralComplex, Nil
2526
from .literals import Literal, LiteralImaginaryUnit, convert_to_literal
@@ -37,6 +38,7 @@
3738
'PythonComplex',
3839
'PythonComplexProperty',
3940
'PythonConjugate',
41+
'PythonDict',
4042
'PythonEnumerate',
4143
'PythonFloat',
4244
'PythonImag',
@@ -805,6 +807,81 @@ def __init__(self, copied_obj):
805807
self._shape = copied_obj.shape
806808
super().__init__(copied_obj)
807809

810+
#==============================================================================
811+
class PythonDict(PyccelFunction):
812+
"""
813+
Class representing a call to Python's `{}` function.
814+
815+
Class representing a call to Python's `{}` function which generates a
816+
literal Python dict. This operator does not handle `**a` expressions.
817+
818+
Parameters
819+
----------
820+
keys : iterable[TypedAstNode]
821+
The keys of the new dictionary.
822+
values : iterable[TypedAstNode]
823+
The values of the new dictionary.
824+
"""
825+
__slots__ = ('_keys', '_values', '_shape', '_class_type')
826+
_attribute_nodes = ('_keys', '_values')
827+
_rank = 1
828+
829+
def __init__(self, keys, values):
830+
self._keys = keys
831+
self._values = values
832+
super().__init__()
833+
if pyccel_stage == 'syntactic':
834+
return
835+
elif len(keys) != len(values):
836+
raise TypeError("Unpacking values in a dictionary is not yet supported.")
837+
elif len(keys) == 0:
838+
self._shape = (LiteralInteger(0),)
839+
self._class_type = DictType(GenericType(), GenericType())
840+
return
841+
842+
key0 = keys[0]
843+
val0 = values[0]
844+
homogeneous_keys = all(k.class_type is not GenericType() for k in keys) and \
845+
all(key0.class_type == k.class_type for k in keys[1:])
846+
homogeneous_vals = all(v.class_type is not GenericType() for v in values) and \
847+
all(val0.class_type == v.class_type for v in values[1:])
848+
849+
if homogeneous_keys and homogeneous_vals:
850+
self._class_type = DictType(key0.class_type, val0.class_type)
851+
852+
self._shape = (LiteralInteger(len(keys)), )
853+
else:
854+
raise TypeError("Can't create an inhomogeneous dict")
855+
856+
def __iter__(self):
857+
return zip(self._keys, self._values)
858+
859+
def __str__(self):
860+
args = ', '.join(f'{k}: {v}' for k,v in self)
861+
return f'{{{args}}}'
862+
863+
def __repr__(self):
864+
args = ', '.join(f'{repr(k)}: {repr(v)}' for k,v in self)
865+
return f'PythonDict({args})'
866+
867+
@property
868+
def keys(self):
869+
"""
870+
The keys of the new dictionary.
871+
872+
The keys of the new dictionary.
873+
"""
874+
return self._keys
875+
876+
@property
877+
def values(self):
878+
"""
879+
The values of the new dictionary.
880+
881+
The values of the new dictionary.
882+
"""
883+
return self._values
884+
808885
#==============================================================================
809886
class PythonMap(PyccelFunction):
810887
"""

pyccel/ast/class_defs.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,25 @@
1515
from .core import ClassDef, PyccelFunctionDef
1616
from .datatypes import (PythonNativeBool, PythonNativeInt, PythonNativeFloat,
1717
PythonNativeComplex, StringType, TupleType, CustomDataType,
18-
HomogeneousListType, HomogeneousSetType)
18+
HomogeneousListType, HomogeneousSetType, DictType)
1919
from .numpyext import (NumpyShape, NumpySum, NumpyAmin, NumpyAmax,
2020
NumpyImag, NumpyReal, NumpyTranspose,
2121
NumpyConjugate, NumpySize, NumpyResultType, NumpyArray)
2222
from .numpytypes import NumpyNumericType, NumpyNDArrayType
2323

2424
__all__ = (
2525
'BooleanClass',
26-
'IntegerClass',
27-
'FloatClass',
2826
'ComplexClass',
27+
'DictClass',
28+
'FloatClass',
29+
'IntegerClass',
30+
'ListClass',
31+
'NumpyArrayClass',
2932
'SetClass',
3033
'StringClass',
31-
'NumpyArrayClass',
3234
'TupleClass',
33-
'ListClass',
34-
'literal_classes',
3535
'get_cls_base',
36+
'literal_classes',
3637
)
3738

3839
#=======================================================================================
@@ -166,6 +167,12 @@
166167

167168
#=======================================================================================
168169

170+
DictClass = ClassDef('dict',
171+
methods=[
172+
])
173+
174+
#=======================================================================================
175+
169176
TupleClass = ClassDef('tuple',
170177
methods=[
171178
#index
@@ -254,6 +261,8 @@ def get_cls_base(class_type):
254261
return ListClass
255262
elif isinstance(class_type, HomogeneousSetType):
256263
return SetClass
264+
elif isinstance(class_type, DictType):
265+
return DictClass
257266
else:
258267
raise NotImplementedError(f"No class definition found for type {class_type}")
259268

pyccel/ast/datatypes.py

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -977,23 +977,23 @@ class DictType(ContainerType, metaclass = ArgumentSingleton):
977977
978978
Parameters
979979
----------
980-
index_type : PyccelType
980+
key_type : PyccelType
981981
The type of the keys of the homogeneous dictionary.
982982
value_type : PyccelType
983983
The type of the values of the homogeneous dictionary.
984984
"""
985-
__slots__ = ('_index_type', '_value_type')
986-
_name = 'map'
985+
__slots__ = ('_key_type', '_value_type')
986+
_name = 'dict'
987987
_container_rank = 1
988988
_order = None
989989

990-
def __init__(self, index_type, value_type):
991-
self._index_type = index_type
990+
def __init__(self, key_type, value_type):
991+
self._key_type = key_type
992992
self._value_type = value_type
993993
super().__init__()
994994

995995
def __str__(self):
996-
return f'map[{self._index_type.name}, {self._value_type.name}]'
996+
return f'dict[{self._key_type}, {self._value_type}]'
997997

998998
def __reduce__(self):
999999
"""
@@ -1009,7 +1009,7 @@ def __reduce__(self):
10091009
args
10101010
A tuple containing any arguments to be passed to the callable.
10111011
"""
1012-
return (self.__class__, (self._index_type, self._value_type))
1012+
return (self.__class__, (self._key_type, self._value_type))
10131013

10141014
@property
10151015
def datatype(self):
@@ -1018,7 +1018,56 @@ def datatype(self):
10181018
10191019
The datatype of the object.
10201020
"""
1021-
return self._index_type.datatype
1021+
return self._key_type.datatype
1022+
1023+
@property
1024+
def key_type(self):
1025+
"""
1026+
The type of the keys of the object.
1027+
1028+
The type of the keys of the object.
1029+
"""
1030+
return self._key_type
1031+
1032+
@property
1033+
def value_type(self):
1034+
"""
1035+
The type of the values of the object.
1036+
1037+
The type of the values of the object.
1038+
"""
1039+
return self._value_type
1040+
1041+
@property
1042+
def container_rank(self):
1043+
"""
1044+
Number of dimensions of the container.
1045+
1046+
Number of dimensions of the object described by the container. This is
1047+
equal to the number of values required to index an element of this container.
1048+
"""
1049+
return 1
1050+
1051+
@property
1052+
def rank(self):
1053+
"""
1054+
Number of dimensions of the object.
1055+
1056+
Number of dimensions of the object. If the object is a scalar then
1057+
this is equal to 0.
1058+
"""
1059+
return self._container_rank
1060+
1061+
@property
1062+
def order(self):
1063+
"""
1064+
The data layout ordering in memory.
1065+
1066+
Indicates whether the data is stored in row-major ('C') or column-major
1067+
('F') format. This is only relevant if rank > 1. When it is not relevant
1068+
this function returns None.
1069+
"""
1070+
return None
10221071

10231072
#==============================================================================
10241073

pyccel/codegen/printing/pycode.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,10 @@ def _print_PythonSet(self, expr):
430430
args = ', '.join(self._print(i) for i in expr.args)
431431
return '{'+args+'}'
432432

433+
def _print_PythonDict(self, expr):
434+
args = ', '.join(f'{self._print(k)}: {self._print(v)}' for k,v in expr)
435+
return '{'+args+'}'
436+
433437
def _print_PythonBool(self, expr):
434438
return 'bool({})'.format(self._print(expr.arg))
435439

pyccel/parser/semantic.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from pyccel.ast.basic import PyccelAstNode, TypedAstNode, ScopedAstNode
2323

2424
from pyccel.ast.builtins import PythonPrint, PythonTupleFunction, PythonSetFunction
25-
from pyccel.ast.builtins import PythonComplex
25+
from pyccel.ast.builtins import PythonComplex, PythonDict
2626
from pyccel.ast.builtins import builtin_functions_dict, PythonImag, PythonReal
2727
from pyccel.ast.builtins import PythonList, PythonConjugate , PythonSet
2828
from pyccel.ast.builtins import (PythonRange, PythonZip, PythonEnumerate,
@@ -2249,6 +2249,16 @@ def _visit_PythonSet(self, expr):
22492249
severity='fatal')
22502250
return expr
22512251

2252+
def _visit_PythonDict(self, expr):
2253+
keys = [self._visit(k) for k in expr.keys]
2254+
vals = [self._visit(v) for v in expr.values]
2255+
try:
2256+
expr = PythonDict(keys, vals)
2257+
except TypeError as e:
2258+
errors.report(str(e), symbol=expr,
2259+
severity='fatal')
2260+
return expr
2261+
22522262
def _visit_FunctionCallArgument(self, expr):
22532263
value = self._visit(expr.value)
22542264
a = FunctionCallArgument(value, expr.keyword)

pyccel/parser/syntactic.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
from pyccel.ast.operators import IfTernaryOperator
5252
from pyccel.ast.numpyext import NumpyMatmul
5353

54-
from pyccel.ast.builtins import PythonTuple, PythonList, PythonSet
54+
from pyccel.ast.builtins import PythonTuple, PythonList, PythonSet, PythonDict
5555
from pyccel.ast.builtins import PythonPrint, Lambda
5656
from pyccel.ast.headers import MetaVariable, FunctionHeader, MethodHeader
5757
from pyccel.ast.literals import LiteralInteger, LiteralFloat, LiteralComplex
@@ -399,8 +399,7 @@ def _visit_alias(self, stmt):
399399
return old
400400

401401
def _visit_Dict(self, stmt):
402-
errors.report(PYCCEL_RESTRICTION_TODO,
403-
symbol=stmt, severity='error')
402+
return PythonDict(self._visit(stmt.keys), self._visit(stmt.values))
404403

405404
def _visit_NoneType(self, stmt):
406405
return Nil()
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# pylint: disable=missing-function-docstring, missing-module-docstring
2+
import pytest
3+
from pyccel import epyccel
4+
5+
@pytest.fixture( params=[
6+
pytest.param("fortran", marks = [
7+
pytest.mark.skip(reason="dict methods not implemented in fortran"),
8+
pytest.mark.fortran]),
9+
pytest.param("c", marks = [
10+
pytest.mark.skip(reason="dict methods not implemented in c"),
11+
pytest.mark.c]),
12+
pytest.param("python", marks = pytest.mark.python)
13+
],
14+
scope = "module"
15+
)
16+
def language(request):
17+
return request.param
18+
19+
def test_dict_init(language):
20+
def dict_init():
21+
a = {1:1.0, 2:2.0}
22+
return a
23+
epyc_dict_init = epyccel(dict_init, language = language)
24+
pyccel_result = epyc_dict_init()
25+
python_result = dict_init()
26+
assert isinstance(python_result, type(pyccel_result))
27+
assert python_result == pyccel_result
28+
29+
def test_dict_str_keys(language):
30+
def dict_str_keys():
31+
a = {'a':1, 'b':2}
32+
return a
33+
epyc_str_keys = epyccel(dict_str_keys, language = language)
34+
pyccel_result = epyc_str_keys()
35+
python_result = dict_str_keys()
36+
assert isinstance(python_result, type(pyccel_result))
37+
assert python_result == pyccel_result

tests/errors/syntax_blockers/ex1.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# pylint: disable=missing-function-docstring, missing-module-docstring
2-
{a: 2, 'b':4}
32
~a
43

54
#$ this is invalid comment

0 commit comments

Comments
 (0)