Skip to content

Commit 267662e

Browse files
authored
Clean STC imports (pyccel#1930)
Handling imports by separating concepts with limiters is potentially error prone. It also results in lots of special cases being written to handle each individual import (e.g. vec vs. Set_extensions). This PR unifies the strategy by using existing Pyccel concepts. An `Import` now takes the source file as the source argument (instead of the string containing limiters) and an `AsName` as the target. The object in the `AsName` is the PyccelType being handled by this import (described by a VariableTypeAnnotation) and the `target` (local name) of the `AsName` is the C name of this object. **Commit Summary** - Improve `AsName.__repr__` to help debugging when the object has no name - Ensure `Import` targets are ordered - Unify the header guard strategy - Use Pyccel objects to group concepts instead of limiters
1 parent e4dfd4f commit 267662e

File tree

6 files changed

+83
-53
lines changed

6 files changed

+83
-53
lines changed

.dict_custom.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,5 @@ datatyping
117117
datatypes
118118
indexable
119119
traceback
120+
STC
121+
gFTL

CHANGELOG.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ All notable changes to this project will be documented in this file.
1616
- #1750 : Add Python support for set method `remove()`.
1717
- #1743 : Add Python support for set method `discard()`.
1818
- #1754 : Add Python support for set method `update()`.
19-
- #1787 : Ensure `STC` is installed with Pyccel.
20-
- #1656 : Ensure `gFTL` is installed with Pyccel.
19+
- #1787 : Ensure STC is installed with Pyccel.
20+
- #1656 : Ensure gFTL is installed with Pyccel.
2121
- #1844 : Add line numbers and code to errors from built-in function calls.
22-
- #1655 : Add the appropriate C language equivalent for declaring a Python `list` container using the `STC` library.
23-
- #1659 : Add the appropriate C language equivalent for declaring a Python `set` container using the `STC` library.
22+
- #1655 : Add the appropriate C language equivalent for declaring a Python `list` container using the STC library.
23+
- #1659 : Add the appropriate C language equivalent for declaring a Python `set` container using the STC library.
2424
- #1893 : Add Python support for set initialisation with `set()`.
2525
- #1877 : Add C Support for set method `pop()`.
2626
- \[INTERNALS\] Added `container_rank` property to `ast.datatypes.PyccelType` objects.
@@ -44,6 +44,7 @@ All notable changes to this project will be documented in this file.
4444
- #1913 : Fix function calls to renamed functions.
4545
- #1927 : Improve error Message for missing target language compiler in Pyccel
4646
- #1933 : Improve code printing speed.
47+
- #1930 : Preserve ordering of import targets.
4748

4849
### Changed
4950

@@ -71,6 +72,7 @@ All notable changes to this project will be documented in this file.
7172
- \[INTERNALS\] `PyccelFunction` objects which do not represent objects in memory have the type `SymbolicType`.
7273
- \[INTERNALS\] Rename `_visit` functions called from a `FunctionCall` which don't match the documented naming pattern to `_build` functions.
7374
- \[INTERNALS\] Remove unnecessary argument `kind` to `Errors.set_target`.
75+
- \[INTERNALS\] Handle STC imports with Pyccel objects.
7476

7577
### Deprecated
7678

pyccel/ast/core.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def object(self):
164164
return self._obj
165165

166166
def __repr__(self):
167-
return f'{self.name} as {self.target}'
167+
return f'{self.object} as {self.target}'
168168

169169
def __eq__(self, string):
170170
if isinstance(string, str):
@@ -3713,7 +3713,7 @@ def __init__(self, source, target = None, ignore_at_print = False, mod = None):
37133713
source = Import._format(source)
37143714

37153715
self._source = source
3716-
self._target = set()
3716+
self._target = {} # Dict is used as Python doesn't have an ordered set
37173717
self._source_mod = mod
37183718
self._ignore_at_print = ignore_at_print
37193719

@@ -3729,14 +3729,14 @@ def __init__(self, source, target = None, ignore_at_print = False, mod = None):
37293729
target = [target]
37303730
if pyccel_stage == "syntactic":
37313731
for i in target:
3732-
self._target.add(Import._format(i))
3732+
self._target[Import._format(i)] = None
37333733
else:
37343734
for i in target:
37353735
assert isinstance(i, (AsName, Module))
37363736
if isinstance(i, Module):
3737-
self._target.add(AsName(i,source))
3737+
self._target[AsName(i,source)] = None
37383738
else:
3739-
self._target.add(i)
3739+
self._target[i] = None
37403740
super().__init__()
37413741

37423742
@staticmethod
@@ -3775,7 +3775,12 @@ def _format(i):
37753775

37763776
@property
37773777
def target(self):
3778-
return self._target
3778+
"""
3779+
Get the objects that are being imported.
3780+
3781+
Get the objects that are being imported.
3782+
"""
3783+
return self._target.keys()
37793784

37803785
@property
37813786
def source(self):
@@ -3801,23 +3806,26 @@ def __str__(self):
38013806

38023807
def define_target(self, new_target):
38033808
"""
3804-
Add an additional target to the imports
3809+
Add an additional target to the imports.
3810+
3811+
Add an additional target to the imports.
38053812
I.e. if imp is an Import defined as:
38063813
>>> from numpy import ones
38073814
38083815
and we call imp.define_target('cos')
38093816
then it becomes:
38103817
>>> from numpy import ones, cos
38113818
3812-
Parameter
3813-
---------
3814-
new_target: str/AsName/iterable of str/AsName
3815-
The new import target
3819+
Parameters
3820+
----------
3821+
new_target : str | AsName | iterable[str | AsName]
3822+
The new import target.
38163823
"""
3824+
assert pyccel_stage != "syntactic"
38173825
if iterable(new_target):
3818-
self._target.update(new_target)
3826+
self._target.update({t: None for t in new_target})
38193827
else:
3820-
self._target.add(new_target)
3828+
self._target[new_target] = None
38213829

38223830
def find_module_target(self, new_target):
38233831
for t in self._target:

pyccel/codegen/printing/ccode.py

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
from pyccel.ast.numpytypes import NumpyFloat32Type, NumpyFloat64Type, NumpyComplex64Type, NumpyComplex128Type
4747
from pyccel.ast.numpytypes import NumpyNDArrayType, numpy_precision_map
4848

49+
from pyccel.ast.type_annotations import VariableTypeAnnotation
50+
4951
from pyccel.ast.utilities import expand_to_loops
5052

5153
from pyccel.ast.variable import IndexedElement
@@ -239,6 +241,8 @@
239241
'assert',
240242
'numpy_c']}
241243

244+
import_header_guard_prefix = {'Set_extensions' : '_TOOLS_SET'}
245+
242246
class CCodePrinter(CodePrinter):
243247
"""
244248
A printer for printing code in C.
@@ -312,8 +316,22 @@ def get_additional_imports(self):
312316
return self._additional_imports.keys()
313317

314318
def add_import(self, import_obj):
319+
"""
320+
Add a new import to the current context.
321+
322+
Add a new import to the current context. This allows the import to be recognised
323+
at the compiling/linking stage. If the source of the import is not new then any
324+
new targets are added to the Import object.
325+
326+
Parameters
327+
----------
328+
import_obj : Import
329+
The AST node describing the import.
330+
"""
315331
if import_obj.source not in self._additional_imports:
316332
self._additional_imports[import_obj.source] = import_obj
333+
elif import_obj.target:
334+
self._additional_imports[import_obj.source].define_target(import_obj.target)
317335

318336
def _get_statement(self, codestring):
319337
return "%s;\n" % codestring
@@ -996,24 +1014,21 @@ def _print_Import(self, expr):
9961014
source = source.name[-1]
9971015
else:
9981016
source = self._print(source)
999-
if source.startswith('stc/'):
1000-
stc_name, container_type, container_key = source.split("/")
1001-
container = container_type.split("_")
1002-
return '\n'.join((f'#ifndef _{container_type.upper()}',
1003-
f'#define _{container_type.upper()}',
1004-
f'#define i_type {container_type}',
1005-
f'#define i_key {container_key}',
1006-
f'#include "{stc_name + "/" + container[0]}.h"',
1007-
'#endif\n'))
1008-
elif source.startswith('Set_pop'):
1009-
_ , i_type, i_key = source.split('/')
1010-
self.add_import(Import('STC_Extensions', Module('STC_Extensions', (), ())))
1011-
return '\n'.join((
1012-
f'#ifndef TOOLS_SET_{str(i_key).upper()}\n'
1013-
f'#define TOOLS_SET_{str(i_key).upper()}\n'
1014-
f'#define i_type {i_type}',
1015-
f'#define i_key {i_key}\n',
1016-
'#include "Set_extensions.h"\n#endif\n'))
1017+
if source.startswith('stc/') or source in import_header_guard_prefix:
1018+
code = ''
1019+
for t in expr.target:
1020+
dtype = t.object.class_type
1021+
container_type = t.target
1022+
container_key = self.get_c_type(dtype.element_type)
1023+
header_guard_prefix = import_header_guard_prefix.get(source, '')
1024+
header_guard = f'{header_guard_prefix}_{container_type.upper()}'
1025+
code += (f'#ifndef {header_guard}\n'
1026+
f'#define {header_guard}\n'
1027+
f'#define i_type {container_type}\n'
1028+
f'#define i_key {container_key}\n'
1029+
f'#include <{source}.h>\n'
1030+
f'#endif // {header_guard}\n\n')
1031+
return code
10171032
# Get with a default value is not used here as it is
10181033
# slower and on most occasions the import will not be in the
10191034
# dictionary
@@ -1231,10 +1246,10 @@ def get_c_type(self, dtype):
12311246

12321247
key = (primitive_type, dtype.precision)
12331248
elif isinstance(dtype, (HomogeneousSetType, HomogeneousListType)):
1234-
container_type = 'hset_' if dtype.name == 'set' else 'vec_'
1235-
element_type = self.get_c_type(dtype.element_type)
1236-
i_type = container_type + element_type.replace(' ', '_')
1237-
self.add_import(Import(f'stc/{i_type}/{element_type}', Module(f'stc/{i_type}', (), ())))
1249+
container_type = 'hset' if dtype.name == 'set' else 'vec'
1250+
element_type = self.get_c_type(dtype.element_type).replace(' ', '_')
1251+
i_type = f'{container_type}_{element_type}'
1252+
self.add_import(Import(f'stc/{container_type}', AsName(VariableTypeAnnotation(dtype), i_type)))
12381253
return i_type
12391254
else:
12401255
key = dtype
@@ -2193,9 +2208,9 @@ def _print_Assign(self, expr):
21932208
return prefix_code+'{} = {};\n'.format(lhs, rhs)
21942209

21952210
def _print_SetPop(self, expr):
2196-
var_type = self.get_declare_type(expr.set_variable)
2197-
element_type = self.get_c_type(expr.set_variable.class_type.element_type)
2198-
self.add_import(Import(f'Set_pop_macro/{var_type}/{element_type}', Module(f'Set_pop_macro/{var_type}/{element_type}', (), ())))
2211+
dtype = expr.set_variable.class_type
2212+
var_type = self.get_c_type(dtype)
2213+
self.add_import(Import('Set_extensions', AsName(VariableTypeAnnotation(dtype), var_type)))
21992214
set_var = self._print(ObjectAddress(expr.set_variable))
22002215
return f'{var_type}_pop({set_var})'
22012216

pyccel/codegen/utilities.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@
3636
# map internal libraries to their folders inside pyccel/stdlib and their compile objects
3737
# The compile object folder will be in the pyccel dirpath
3838
internal_libs = {
39-
"ndarrays" : ("ndarrays", CompileObj("ndarrays.c",folder="ndarrays")),
40-
"pyc_math_f90" : ("math", CompileObj("pyc_math_f90.f90",folder="math")),
41-
"pyc_math_c" : ("math", CompileObj("pyc_math_c.c",folder="math")),
42-
"cwrapper" : ("cwrapper", CompileObj("cwrapper.c",folder="cwrapper", accelerators=('python',))),
43-
"numpy_f90" : ("numpy", CompileObj("numpy_f90.f90",folder="numpy")),
44-
"numpy_c" : ("numpy", CompileObj("numpy_c.c",folder="numpy")),
45-
"STC_Extensions" : ("STC_Extensions", CompileObj("Set_Extensions.h",folder="STC_Extensions", has_target_file = False)),
39+
"ndarrays" : ("ndarrays", CompileObj("ndarrays.c",folder="ndarrays")),
40+
"pyc_math_f90" : ("math", CompileObj("pyc_math_f90.f90",folder="math")),
41+
"pyc_math_c" : ("math", CompileObj("pyc_math_c.c",folder="math")),
42+
"cwrapper" : ("cwrapper", CompileObj("cwrapper.c",folder="cwrapper", accelerators=('python',))),
43+
"numpy_f90" : ("numpy", CompileObj("numpy_f90.f90",folder="numpy")),
44+
"numpy_c" : ("numpy", CompileObj("numpy_c.c",folder="numpy")),
45+
"Set_extensions" : ("STC_Extensions", CompileObj("Set_Extensions.h", folder="STC_Extensions", has_target_file = False)),
4646
}
4747
internal_libs["cwrapper_ndarrays"] = ("cwrapper_ndarrays", CompileObj("cwrapper_ndarrays.c",folder="cwrapper_ndarrays",
4848
accelerators = ('python',),

pyccel/parser/semantic.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4367,7 +4367,7 @@ def _insert_obj(location, target, obj):
43674367
__module_name__ = p.metavars.get('module_name', None)
43684368

43694369
if source_target in container['imports']:
4370-
targets = list(container['imports'][source_target].target.union(targets))
4370+
targets.extend(container['imports'][source_target].target)
43714371

43724372
if import_init:
43734373
old_name = import_init.name
@@ -4663,11 +4663,12 @@ def _build_MathSqrt(self, func_call):
46634663
fabs_name = self.scope.get_new_name('fabs')
46644664
imp_name = AsName('fabs', fabs_name)
46654665
new_import = Import('math',imp_name)
4666-
self._visit(new_import)
46674666
new_call = FunctionCall(fabs_name, [mul1])
46684667

46694668
pyccel_stage.set_stage('semantic')
46704669

4670+
self._visit(new_import)
4671+
46714672
return self._visit(new_call)
46724673
elif isinstance(arg.value, PyccelPow):
46734674
base, exponent = arg.value.args
@@ -4677,11 +4678,12 @@ def _build_MathSqrt(self, func_call):
46774678
fabs_name = self.scope.get_new_name('fabs')
46784679
imp_name = AsName('fabs', fabs_name)
46794680
new_import = Import('math',imp_name)
4680-
self._visit(new_import)
46814681
new_call = FunctionCall(fabs_name, [base])
46824682

46834683
pyccel_stage.set_stage('semantic')
46844684

4685+
self._visit(new_import)
4686+
46854687
return self._visit(new_call)
46864688

46874689
return self._handle_function(func_call, func, (arg,), use_build_functions = False)
@@ -4723,11 +4725,12 @@ def _build_CmathSqrt(self, func_call):
47234725
abs_name = self.scope.get_new_name('abs')
47244726
imp_name = AsName('abs', abs_name)
47254727
new_import = Import('numpy',imp_name)
4726-
self._visit(new_import)
47274728
new_call = FunctionCall(abs_name, [abs_arg])
47284729

47294730
pyccel_stage.set_stage('semantic')
47304731

4732+
self._visit(new_import)
4733+
47314734
# Cast to preserve final dtype
47324735
return PythonComplex(self._visit(new_call))
47334736

0 commit comments

Comments
 (0)