|
65 | 65 | from pyccel.ast.datatypes import PythonNativeBool, PythonNativeInt, PythonNativeFloat
|
66 | 66 | from pyccel.ast.datatypes import DataTypeFactory, PrimitiveFloatingPointType
|
67 | 67 | from pyccel.ast.datatypes import InhomogeneousTupleType, HomogeneousTupleType, HomogeneousSetType, HomogeneousListType
|
68 |
| -from pyccel.ast.datatypes import PrimitiveComplexType, FixedSizeNumericType, DictType |
| 68 | +from pyccel.ast.datatypes import PrimitiveComplexType, FixedSizeNumericType, DictType, TypeAlias |
69 | 69 |
|
70 | 70 | from pyccel.ast.functionalexpr import FunctionalSum, FunctionalMax, FunctionalMin, GeneratorComprehension, FunctionalFor
|
71 | 71 |
|
@@ -368,7 +368,7 @@ def check_for_variable(self, name):
|
368 | 368 | class_def = prefix.cls_base
|
369 | 369 | except AttributeError:
|
370 | 370 | class_def = get_cls_base(prefix.class_type) or \
|
371 |
| - self.scope.find(prefix.class_type.name, 'classes') |
| 371 | + self.scope.find(str(prefix.class_type), 'classes') |
372 | 372 |
|
373 | 373 | attr_name = name.name[-1]
|
374 | 374 | class_scope = class_def.scope
|
@@ -689,6 +689,9 @@ def _infer_type(self, expr):
|
689 | 689 | dict
|
690 | 690 | Dictionary containing all the type information which was inferred.
|
691 | 691 | """
|
| 692 | + if not isinstance(expr, TypedAstNode): |
| 693 | + return {'class_type' : SymbolicType()} |
| 694 | + |
692 | 695 | d_var = {
|
693 | 696 | 'class_type' : expr.class_type,
|
694 | 697 | 'shape' : expr.shape,
|
@@ -1410,6 +1413,10 @@ def _assign_lhs_variable(self, lhs, d_var, rhs, new_expressions, is_augassign,ar
|
1410 | 1413 | else:
|
1411 | 1414 | var = None
|
1412 | 1415 | else:
|
| 1416 | + symbolic_var = self.scope.find(lhs, 'symbolic_alias') |
| 1417 | + if symbolic_var: |
| 1418 | + errors.report(f"{lhs} variable represents a symbolic concept. Its value cannot be changed.", |
| 1419 | + severity='fatal') |
1413 | 1420 | var = self.scope.find(lhs)
|
1414 | 1421 |
|
1415 | 1422 | # Variable not yet declared (hence array not yet allocated)
|
@@ -1936,6 +1943,8 @@ def _get_indexed_type(self, base, args, expr):
|
1936 | 1943 | for t in annotation.type_list:
|
1937 | 1944 | t.is_const = True
|
1938 | 1945 | return annotation
|
| 1946 | + elif isinstance(base, UnionTypeAnnotation): |
| 1947 | + return UnionTypeAnnotation(*[self._get_indexed_type(t, args, expr) for t in base.type_list]) |
1939 | 1948 |
|
1940 | 1949 | if all(isinstance(a, Slice) for a in args):
|
1941 | 1950 | rank = len(args)
|
@@ -2472,7 +2481,7 @@ def _visit_Slice(self, expr):
|
2472 | 2481 | def _visit_IndexedElement(self, expr):
|
2473 | 2482 | var = self._visit(expr.base)
|
2474 | 2483 |
|
2475 |
| - if isinstance(var, (PyccelFunctionDef, VariableTypeAnnotation)): |
| 2484 | + if isinstance(var, (PyccelFunctionDef, VariableTypeAnnotation, UnionTypeAnnotation)): |
2476 | 2485 | return self._get_indexed_type(var, expr.indices, expr)
|
2477 | 2486 |
|
2478 | 2487 | # TODO check consistency of indices with shape/rank
|
@@ -2584,7 +2593,7 @@ def _visit_AnnotatedPyccelSymbol(self, expr):
|
2584 | 2593 | possible_args.append(address)
|
2585 | 2594 | elif isinstance(t, VariableTypeAnnotation):
|
2586 | 2595 | class_type = t.class_type
|
2587 |
| - cls_base = get_cls_base(class_type) or self.scope.find(class_type.name, 'classes') |
| 2596 | + cls_base = self.scope.find(str(class_type), 'classes') or get_cls_base(class_type) |
2588 | 2597 | v = var_class(class_type, name, cls_base = cls_base,
|
2589 | 2598 | shape = None,
|
2590 | 2599 | is_const = t.is_const, is_optional = False,
|
@@ -2628,6 +2637,8 @@ def _visit_SyntacticTypeAnnotation(self, expr):
|
2628 | 2637 | raise errors.report(PYCCEL_RESTRICTION_TODO + ' Could not deduce type information',
|
2629 | 2638 | severity='fatal', symbol=expr)
|
2630 | 2639 |
|
| 2640 | + def _visit_VariableTypeAnnotation(self, expr): |
| 2641 | + return expr |
2631 | 2642 |
|
2632 | 2643 | def _visit_DottedName(self, expr):
|
2633 | 2644 |
|
@@ -2698,7 +2709,7 @@ def _visit_DottedName(self, expr):
|
2698 | 2709 | class_type = d_var['class_type']
|
2699 | 2710 | cls_base = get_cls_base(class_type)
|
2700 | 2711 | if cls_base is None:
|
2701 |
| - cls_base = self.scope.find(class_type.name, 'classes') |
| 2712 | + cls_base = self.scope.find(str(class_type), 'classes') |
2702 | 2713 |
|
2703 | 2714 | # look for a class method
|
2704 | 2715 | if isinstance(rhs, FunctionCall):
|
@@ -2980,6 +2991,35 @@ def _visit_Assign(self, expr):
|
2980 | 2991 | rhs = expr.rhs
|
2981 | 2992 | lhs = expr.lhs
|
2982 | 2993 |
|
| 2994 | + if isinstance(lhs, AnnotatedPyccelSymbol): |
| 2995 | + semantic_lhs = self._visit(lhs) |
| 2996 | + if len(semantic_lhs) != 1: |
| 2997 | + errors.report("Cannot declare variable with multiple types", |
| 2998 | + symbol=expr, severity='error') |
| 2999 | + semantic_lhs_var = semantic_lhs[0] |
| 3000 | + if isinstance(semantic_lhs_var, DottedVariable): |
| 3001 | + cls_def = semantic_lhs_var.lhs.cls_base |
| 3002 | + insert_scope = cls_def.scope |
| 3003 | + cls_def.add_new_attribute(semantic_lhs_var) |
| 3004 | + else: |
| 3005 | + insert_scope = self.scope |
| 3006 | + |
| 3007 | + lhs = lhs.name |
| 3008 | + if semantic_lhs_var.class_type is TypeAlias(): |
| 3009 | + if not isinstance(rhs, SyntacticTypeAnnotation): |
| 3010 | + pyccel_stage.set_stage('syntactic') |
| 3011 | + rhs = SyntacticTypeAnnotation(rhs) |
| 3012 | + pyccel_stage.set_stage('semantic') |
| 3013 | + type_annot = self._visit(rhs) |
| 3014 | + self.scope.insert_symbolic_alias(lhs, type_annot) |
| 3015 | + return EmptyNode() |
| 3016 | + |
| 3017 | + try: |
| 3018 | + insert_scope.insert_variable(semantic_lhs_var) |
| 3019 | + except RuntimeError as e: |
| 3020 | + errors.report(e, symbol=expr, severity='error') |
| 3021 | + |
| 3022 | + |
2983 | 3023 | # Steps before visiting
|
2984 | 3024 | if isinstance(rhs, GeneratorComprehension):
|
2985 | 3025 | rhs.substitute(rhs.lhs, lhs)
|
@@ -3024,10 +3064,12 @@ def _visit_Assign(self, expr):
|
3024 | 3064 | d_m_args = {arg.value.name:arg.value for arg in macro.master_arguments
|
3025 | 3065 | if isinstance(arg.value, Variable)}
|
3026 | 3066 |
|
3027 |
| - if not sympy_iterable(lhs): |
3028 |
| - lhs = [lhs] |
| 3067 | + lhs_iter = lhs |
| 3068 | + |
| 3069 | + if not sympy_iterable(lhs_iter): |
| 3070 | + lhs_iter = [lhs] |
3029 | 3071 | results_shapes = macro.get_results_shapes(args)
|
3030 |
| - for m_result, shape, result in zip(macro.results, results_shapes, lhs): |
| 3072 | + for m_result, shape, result in zip(macro.results, results_shapes, lhs_iter): |
3031 | 3073 | if m_result in d_m_args and not result in args_names:
|
3032 | 3074 | d_result = self._infer_type(d_m_args[m_result])
|
3033 | 3075 | d_result['shape'] = shape
|
@@ -3071,14 +3113,6 @@ def _visit_Assign(self, expr):
|
3071 | 3113 | return rhs
|
3072 | 3114 | if isinstance(rhs, ConstructorCall):
|
3073 | 3115 | return rhs
|
3074 |
| - elif isinstance(rhs, FunctionDef): |
3075 |
| - |
3076 |
| - # case of lambdify |
3077 |
| - |
3078 |
| - rhs = rhs.rename(expr.lhs.name) |
3079 |
| - for i in rhs.body: |
3080 |
| - i.set_current_ast(python_ast) |
3081 |
| - return rhs |
3082 | 3116 |
|
3083 | 3117 | elif isinstance(rhs, CodeBlock) and len(rhs.body)>1 and isinstance(rhs.body[1], FunctionalFor):
|
3084 | 3118 | return rhs
|
@@ -3148,25 +3182,6 @@ def _visit_Assign(self, expr):
|
3148 | 3182 | # case of rhs is a target variable the lhs must be a pointer
|
3149 | 3183 | d['memory_handling'] = 'alias'
|
3150 | 3184 |
|
3151 |
| - lhs = expr.lhs |
3152 |
| - if isinstance(lhs, AnnotatedPyccelSymbol): |
3153 |
| - semantic_lhs = self._visit(lhs) |
3154 |
| - if len(semantic_lhs) != 1: |
3155 |
| - errors.report("Cannot declare variable with multiple types", |
3156 |
| - symbol=expr, severity='error') |
3157 |
| - semantic_lhs_var = semantic_lhs[0] |
3158 |
| - if isinstance(semantic_lhs_var, DottedVariable): |
3159 |
| - cls_def = semantic_lhs_var.lhs.cls_base |
3160 |
| - insert_scope = cls_def.scope |
3161 |
| - cls_def.add_new_attribute(semantic_lhs_var) |
3162 |
| - else: |
3163 |
| - insert_scope = self.scope |
3164 |
| - try: |
3165 |
| - insert_scope.insert_variable(semantic_lhs_var) |
3166 |
| - except RuntimeError as e: |
3167 |
| - errors.report(e, symbol=expr, severity='error') |
3168 |
| - lhs = lhs.name |
3169 |
| - |
3170 | 3185 | if isinstance(lhs, (PyccelSymbol, DottedName)):
|
3171 | 3186 | if isinstance(d_var, list):
|
3172 | 3187 | if len(d_var) == 1:
|
@@ -3321,12 +3336,9 @@ def _visit_Assign(self, expr):
|
3321 | 3336 | # it is then treated as a def node
|
3322 | 3337 |
|
3323 | 3338 | F = self.scope.find(l, 'symbolic_functions')
|
3324 |
| - if F is None: |
3325 |
| - self.insert_symbolic_function(new_expr) |
3326 |
| - else: |
3327 |
| - errors.report(PYCCEL_RESTRICTION_TODO, |
3328 |
| - bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset), |
3329 |
| - severity='fatal') |
| 3339 | + errors.report(PYCCEL_RESTRICTION_TODO, |
| 3340 | + bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset), |
| 3341 | + severity='fatal') |
3330 | 3342 |
|
3331 | 3343 | new_expressions.append(new_expr)
|
3332 | 3344 |
|
|
0 commit comments