Skip to content

Commit 8445dd7

Browse files
EmilyBourneyguclu
andauthored
Fix translation for conflicting module name (pyccel#1854)
Fix the translation of a file whose name conflicts with Fortran keywords by ensuring that the original name is correctly extracted from the scope. Fixes pyccel#1853. **Commit Summary** - Ensure Python name of module exists in the scope by adding it in the syntactic stage. - Don't use `AsName` for module name - Save the Python name of a module as the name of the `PyModule` - Ensure Python names are used for imports - Remove hacky `set_name` function. - Remove unused `assign_to` argument of `CodePrinter.doprint`. --------- Co-authored-by: Yaman Güçlü <[email protected]>
1 parent d422cac commit 8445dd7

File tree

15 files changed

+97
-40
lines changed

15 files changed

+97
-40
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ All notable changes to this project will be documented in this file.
3939
- #1830 : Fix missing allocation when returning an annotated array expression.
4040
- #1821 : Ensure an error is raised when creating an ambiguous interface.
4141
- #1842 : Fix homogeneous tuples incorrectly identified as inhomogeneous.
42+
- #1853 : Fix translation of a file whose name conflicts with Fortran keywords.
4243

4344
### Changed
4445

@@ -83,6 +84,8 @@ All notable changes to this project will be documented in this file.
8384
- \[INTERNALS\] Remove `pyccel.ast.utilities.builtin_functions`.
8485
- \[INTERNALS\] Remove unused/unnecessary functions in `pyccel.parser.utilities` : `read_file`, `header_statement`, `accelerator_statement`, `get_module_name`, `view_tree`.
8586
- \[INTERNALS\] Remove unused functions `Errors.unset_target`, and `Errors.reset_target`.
87+
- \[INTERNALS\] Remove function `Module.set_name`.
88+
- \[INTERNALS\] Remove unused `assign_to` argument of `CodePrinter.doprint`.
8689

8790
## \[1.11.2\] - 2024-03-05
8891

pyccel/ast/core.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,8 +1023,8 @@ def __init__(
10231023
imports=(),
10241024
scope = None
10251025
):
1026-
if not isinstance(name, (str, AsName)):
1027-
raise TypeError('name must be a string or an AsName')
1026+
if not isinstance(name, str):
1027+
raise TypeError('name must be a string')
10281028

10291029
if not iterable(variables):
10301030
raise TypeError('variables must be an iterable')
@@ -1179,11 +1179,6 @@ def body(self):
11791179
"""
11801180
return self.interfaces + self.funcs + self.classes
11811181

1182-
def set_name(self, new_name):
1183-
""" Function for changing the name of a module
1184-
"""
1185-
self._name = new_name
1186-
11871182
def __getitem__(self, arg):
11881183
assert isinstance(arg, str)
11891184
args = arg.split('.')
@@ -3721,6 +3716,10 @@ def __init__(self, source, target = None, ignore_at_print = False, mod = None):
37213716
self._target = set()
37223717
self._source_mod = mod
37233718
self._ignore_at_print = ignore_at_print
3719+
3720+
if mod is None and isinstance(target, Module):
3721+
self._source_mod = target
3722+
37243723
if target is None:
37253724
if pyccel_stage == "syntactic":
37263725
target = []

pyccel/codegen/printing/ccode.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,7 @@ def _print_Header(self, expr):
759759

760760
def _print_ModuleHeader(self, expr):
761761
self.set_scope(expr.module.scope)
762+
self._current_module = expr.module
762763
self._in_header = True
763764
name = expr.module.name
764765
if isinstance(name, AsName):
@@ -792,6 +793,7 @@ def _print_ModuleHeader(self, expr):
792793

793794
self._in_header = False
794795
self.exit_scope()
796+
self._current_module = None
795797
return (f"#ifndef {name.upper()}_H\n \
796798
#define {name.upper()}_H\n\n \
797799
{imports}\n \
@@ -802,13 +804,13 @@ def _print_ModuleHeader(self, expr):
802804

803805
def _print_Module(self, expr):
804806
self.set_scope(expr.scope)
805-
self._current_module = expr.name
807+
self._current_module = expr
806808
body = ''.join(self._print(i) for i in expr.body)
807809

808810
global_variables = ''.join([self._print(d) for d in expr.declarations])
809811

810812
# Print imports last to be sure that all additional_imports have been collected
811-
imports = [Import(expr.name, Module(expr.name,(),())), *self._additional_imports.values()]
813+
imports = [Import(self.scope.get_python_name(expr.name), Module(expr.name,(),())), *self._additional_imports.values()]
812814
imports = ''.join(self._print(i) for i in imports)
813815

814816
code = ('{imports}\n'
@@ -819,6 +821,7 @@ def _print_Module(self, expr):
819821
body = body)
820822

821823
self.exit_scope()
824+
self._current_module = None
822825
return code
823826

824827
def _print_Break(self, expr):
@@ -951,7 +954,7 @@ def _print_Import(self, expr):
951954
if source in import_dict: # pylint: disable=consider-using-get
952955
source = import_dict[source]
953956

954-
if expr.source_module:
957+
if expr.source_module and expr.source_module is not self._current_module:
955958
for classDef in expr.source_module.classes:
956959
class_scope = classDef.scope
957960
for method in classDef.methods:
@@ -2426,6 +2429,8 @@ def _print_Omp_End_Clause(self, expr):
24262429
#=====================================
24272430

24282431
def _print_Program(self, expr):
2432+
mod = expr.get_direct_user_nodes(lambda x: isinstance(x, Module))[0]
2433+
self._current_module = mod
24292434
self.set_scope(expr.scope)
24302435
body = self._print(expr.body)
24312436
variables = self.scope.variables.values()
@@ -2435,6 +2440,7 @@ def _print_Program(self, expr):
24352440
imports = ''.join(self._print(i) for i in imports)
24362441

24372442
self.exit_scope()
2443+
self._current_module = None
24382444
return ('{imports}'
24392445
'int main()\n{{\n'
24402446
'{decs}'

pyccel/codegen/printing/codeprinter.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from pyccel.ast.basic import PyccelAstNode
99

10-
from pyccel.ast.core import Assign
10+
from pyccel.ast.core import Module, ModuleHeader, Program
1111
from pyccel.ast.internals import PyccelSymbol
1212

1313
from pyccel.errors.errors import Errors
@@ -22,31 +22,31 @@
2222
class CodePrinter:
2323
"""
2424
The base class for code-printing subclasses.
25+
26+
The base class from which code printers inherit. The sub-classes should define a language
27+
and `_print_X` functions.
2528
"""
2629
language = None
2730
def __init__(self):
2831
self._scope = None
2932

30-
def doprint(self, expr, assign_to=None):
33+
def doprint(self, expr):
3134
"""
3235
Print the expression as code.
3336
37+
Print the expression as code.
38+
39+
Parameters
40+
----------
3441
expr : Expression
3542
The expression to be printed.
3643
37-
assign_to : PyccelSymbol, MatrixSymbol, or string (optional)
38-
If provided, the printed code will set the expression to a
39-
variable with name ``assign_to``.
44+
Returns
45+
-------
46+
str
47+
The generated code.
4048
"""
41-
42-
if isinstance(assign_to, str):
43-
assign_to = PyccelSymbol(assign_to)
44-
elif not isinstance(assign_to, (PyccelAstNode, type(None))):
45-
raise TypeError("{0} cannot assign to object of type {1}".format(
46-
type(self).__name__, type(assign_to)))
47-
48-
if assign_to:
49-
expr = Assign(assign_to, expr)
49+
assert isinstance(expr, (Module, ModuleHeader, Program))
5050

5151
# Do the actual printing
5252
lines = self._print(expr).splitlines(True)

pyccel/codegen/printing/cwrappercode.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ def _print_PyModule_Create(self, expr):
260260

261261
def _print_ModuleHeader(self, expr):
262262
mod = expr.module
263+
self._current_module = expr.module
263264
name = mod.name
264265

265266
# Print imports last to be sure that all additional_imports have been collected
@@ -293,6 +294,7 @@ def _print_ModuleHeader(self, expr):
293294
static_import_decs = self._print(Declare(API_var, static=True))
294295
import_func = self._print(mod.import_func)
295296

297+
self._current_module = None
296298
header_id = f'{name.upper()}_WRAPPER'
297299
header_guard = f'{header_id}_H'
298300
return (f"#ifndef {header_guard}\n \
@@ -311,6 +313,7 @@ def _print_ModuleHeader(self, expr):
311313
def _print_PyModule(self, expr):
312314
scope = expr.scope
313315
self.set_scope(scope)
316+
self._current_module = expr
314317

315318
# Insert declared objects into scope
316319
variables = expr.original_module.variables if isinstance(expr, BindCModule) else expr.variables
@@ -322,7 +325,7 @@ def _print_PyModule(self, expr):
322325

323326
funcs = []
324327

325-
self._module_name = self.get_python_name(scope, expr)
328+
self._module_name = expr.name
326329
sep = self._print(SeparatorComment(40))
327330

328331
interface_funcs = [f.name for i in expr.interfaces for f in i.functions]
@@ -373,6 +376,7 @@ def _print_PyModule(self, expr):
373376
imports = ''.join(self._print(i) for i in imports)
374377

375378
self.exit_scope()
379+
self._current_module = None
376380

377381
return '\n'.join(['#define PY_ARRAY_UNIQUE_SYMBOL CWRAPPER_ARRAY_API',
378382
f'#define {pymod_name.upper()}\n',

pyccel/codegen/printing/fcode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,7 @@ def _print_Import(self, expr):
641641
return ''
642642

643643
if expr.source_module:
644-
source = expr.source_module.scope.get_expected_name(source)
644+
source = expr.source_module.name
645645

646646
if 'mpi4py' == str(getattr(expr.source,'name',expr.source)):
647647
return 'use mpi\n' + 'use mpiext\n'

pyccel/codegen/printing/pycode.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,10 @@ def _print_Return(self, expr):
389389
def _print_Program(self, expr):
390390
mod_scope = self.scope
391391
self.set_scope(expr.scope)
392-
imports = ''.join(self._print(i) for i in expr.imports)
392+
modules = expr.get_direct_user_nodes(lambda m: isinstance(m, Module))
393+
assert len(modules) == 1
394+
module = modules[0]
395+
imports = ''.join(self._print(i) for i in expr.imports if i.source_module is not module)
393396
body = self._print(expr.body)
394397
imports += ''.join(self._print(i) for i in self.get_additional_imports())
395398

pyccel/codegen/python_wrapper.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,6 @@ def create_shared_library(codegen,
170170
# Print code specific cwrapper
171171
#---------------------------------------
172172
module_old_name = codegen.ast.name
173-
codegen.ast.set_name(sharedlib_modname)
174173
wrapper_codegen = CWrapperCodePrinter(codegen.parser.filename, language)
175174
Scope.name_clash_checker = name_clash_checkers['c']
176175
wrapper = CToPythonWrapper(base_dirpath)
@@ -185,8 +184,6 @@ def create_shared_library(codegen,
185184
if errors.has_errors():
186185
return
187186

188-
codegen.ast.set_name(module_old_name)
189-
190187
with open(wrapper_filename, 'w', encoding="utf-8") as f:
191188
f.writelines(wrapper_code)
192189
timings['Wrapper printing'] = time.time() - start_print_cwrapper

pyccel/codegen/wrapper/c_to_python_wrapper.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,7 @@ def _build_module_init_function(self, expr, imports):
597597
The initialisation function.
598598
"""
599599

600-
mod_name = getattr(expr, 'original_module', expr).name
600+
mod_name = self.scope.get_python_name(getattr(expr, 'original_module', expr).name)
601601
# Initialise the scope
602602
func_name = self.scope.get_new_name(f'PyInit_{mod_name}')
603603
func_scope = self.scope.new_child_scope(func_name)
@@ -633,7 +633,7 @@ def _build_module_init_function(self, expr, imports):
633633
ok_code = LiteralInteger(0)
634634

635635
# Save Capsule describing types (needed for dependent modules)
636-
body.append(AliasAssign(capsule_obj, PyCapsule_New(API_var, self.scope.get_python_name(mod_name))))
636+
body.append(AliasAssign(capsule_obj, PyCapsule_New(API_var, mod_name)))
637637
body.extend(self._add_object_to_mod(module_var, capsule_obj, '_C_API', initialised))
638638

639639
body.append(FunctionCall(import_array, ()))
@@ -1052,9 +1052,10 @@ def _wrap_Module(self, expr):
10521052

10531053
imports += cwrapper_ndarray_imports if self._wrapping_arrays else []
10541054
if not isinstance(expr, BindCModule):
1055-
imports.append(Import(expr.name, expr))
1055+
imports.append(Import(mod_scope.get_python_name(expr.name), expr))
10561056
original_mod = getattr(expr, 'original_module', expr)
1057-
return PyModule(original_mod.name, [API_var], funcs, imports = imports,
1057+
original_mod_name = mod_scope.get_python_name(original_mod.name)
1058+
return PyModule(original_mod_name, [API_var], funcs, imports = imports,
10581059
interfaces = interfaces, classes = classes, scope = mod_scope,
10591060
init_func = init_func, import_func = import_func)
10601061

@@ -2121,7 +2122,7 @@ def _wrap_Import(self, expr):
21212122

21222123
if import_wrapper:
21232124
wrapper_name = f'{expr.source}_wrapper'
2124-
mod_spoof = PyModule(expr.source_module.name.name, (), (), scope = Scope())
2125+
mod_spoof = PyModule(expr.source_module.name, (), (), scope = Scope())
21252126
return Import(wrapper_name, AsName(mod_spoof, expr.source), mod = mod_spoof)
21262127
else:
21272128
return None

pyccel/codegen/wrapper/fortran_to_c_wrapper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,10 @@ def _wrap_Module(self, expr):
190190
classes = [self._wrap(f) for f in expr.classes]
191191
variables = [self._wrap(v) for v in expr.variables if not v.is_private]
192192
variable_getters = [v for v in variables if isinstance(v, BindCArrayVariable)]
193-
imports = [Import(expr.name, target = expr, mod=expr)]
193+
imports = [Import(self.scope.get_python_name(expr.name), target = expr, mod=expr)]
194194

195-
name = mod_scope.get_new_name(f'bind_c_{expr.name.target}')
196-
self._wrapper_names_dict[expr.name.target] = name
195+
name = mod_scope.get_new_name(f'bind_c_{expr.name}')
196+
self._wrapper_names_dict[expr.name] = name
197197

198198
self.exit_scope()
199199

0 commit comments

Comments
 (0)