Skip to content

Commit 06c297a

Browse files
committed
add custom symbol table and check/transform omp directive calls
1 parent 1d90896 commit 06c297a

File tree

5 files changed

+284
-120
lines changed

5 files changed

+284
-120
lines changed

omp4py/core/preprocessor/obj2ast.py

Lines changed: 1 addition & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -23,73 +23,7 @@ def get_indentation(lines: list[str]) -> int:
2323
return 0
2424

2525

26-
class NamespaceVisitor(ast.NodeVisitor):
27-
namespace: int
28-
line: int
29-
name: str
30-
result: tuple[ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef, int] | None
31-
32-
def __init__(self, name: str, line: int) -> None:
33-
self.namespace = 0
34-
self.name = name
35-
self.line: int = line
36-
self.result = None
37-
38-
def search(self, module: ast.Module) -> tuple[ast.stmt, int]:
39-
self.generic_visit(module)
40-
if self.result is None:
41-
msg: str = "source code not found"
42-
raise ValueError(msg)
43-
return self.result
44-
45-
def visit_namespace(self, node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef) -> None:
46-
self.namespace += 1
47-
lineno: int = node.lineno - 1
48-
if len(node.decorator_list) > 0:
49-
lineno = node.decorator_list[0].lineno - 1
50-
if lineno == self.line and node.name == self.name:
51-
self.result = (node, self.namespace)
52-
53-
visit_FunctionDef = visit_namespace # noqa: N815
54-
visit_ClassDef = visit_namespace # noqa: N815
55-
visit_AsyncFunctionDef = visit_namespace # noqa: N815
56-
57-
58-
def check_object(obj: Callable[..., Any] | type):
59-
if not (inspect.isclass(obj) or inspect.iscoroutinefunction(obj) or inspect.isfunction(obj)):
60-
msg: str = "Invalid object type: only classes, functions, or async functions are allowed"
61-
raise ValueError(msg)
62-
qname: str = getattr(obj, "__qualname__", "")
63-
64-
if "." in qname:
65-
msg: str = "Decorator must be applied to an outer function or class"
66-
filename: str = inspect.getfile(obj)
67-
lines: list[str]
68-
start: int
69-
lines, start = inspect.findsource(obj)
70-
offset: int = len(lines[start]) - len(lines[start].lstrip())
71-
full_source: str = "".join(lines)
72-
raise syntax_error(msg, Span(start + 1, offset), full_source, filename)
73-
74-
75-
def from_object(obj: Callable[..., Any] | type) -> tuple[str, str, ast.Module, int]:
76-
check_object(obj)
77-
lines: list[str]
78-
start: int
79-
lines, start = inspect.findsource(obj)
80-
full_source: str = "".join(lines)
81-
filename: str = inspect.getfile(obj)
82-
module: ast.Module = ast.parse(full_source, filename)
83-
name: str = getattr(obj, "__name__", "")
84-
obj_smt: ast.stmt
85-
namespace: int
86-
obj_smt, namespace = NamespaceVisitor(name, start).search(module)
87-
module.body = [obj_smt]
88-
89-
return filename, full_source, module, namespace
90-
91-
92-
def from_object2(obj: Callable[..., Any] | type) -> tuple[str, str, ast.Module]:
26+
def from_object(obj: Callable[..., Any] | type) -> tuple[str, str, ast.Module]:
9327
lines: list[str]
9428
start: int
9529
lines, start = inspect.findsource(obj)

omp4py/core/preprocessor/preprocessor.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@ def process_object[T: Callable[..., Any] | type](arg: T, params: Params) -> T:
2323
filename: str
2424
data: str
2525
module: ast.Module
26-
namespace: int
27-
filename, data, module,namespace = obj2ast.from_object(arg)
28-
module: ast.Module = process(module, namespace,data, filename, params)
26+
filename, data, module = obj2ast.from_object(arg)
27+
module: ast.Module = process(module, False,data, filename, params)
2928
globals: ModuleType = sys.modules[arg.__module__]
3029

3130
# TODO: use importlib to use pyc cache
@@ -41,7 +40,7 @@ def process_object[T: Callable[..., Any] | type](arg: T, params: Params) -> T:
4140

4241

4342
def process_source(data: str, filename: str, params: Params) -> ast.Module:
44-
return process(ast.parse(data, filename),0, data, filename, params)
43+
return process(ast.parse(data, filename),True, data, filename, params)
4544

4645

4746
def process_file(filename: str, params: Params) -> str:
@@ -58,7 +57,7 @@ def process_file(filename: str, params: Params) -> str:
5857
return filename
5958

6059

61-
def process(module: ast.Module, namespace: int, full_source: str, filename: str, params: Params) -> ast.Module:
62-
transformer: OmpTransformer = OmpTransformer(full_source, filename, module, namespace, params)
60+
def process(module: ast.Module, is_module: bool, full_source: str, filename: str, params: Params) -> ast.Module:
61+
transformer: OmpTransformer = OmpTransformer(full_source, filename, module, is_module, params)
6362

6463
return transformer.transform()

omp4py/core/preprocessor/transformers/context.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
1-
from omp4py.core.directive.schema import Directive
21
import ast
32
import os
4-
import symtable
53
from dataclasses import dataclass, field
4+
from symtable import symtable as native_symtable
65

7-
__all__ = ["Context", "Params"]
6+
from omp4py.core.parser import Directive
7+
from omp4py.core.preprocessor.transformers.symtable import SymbolTable
8+
9+
__all__ = ["Context", "Params", "SymbolTable", "global_symtable"]
10+
11+
12+
def global_symtable(data: str, filename: str) -> "SymbolTable":
13+
return SymbolTable(list(native_symtable(data, filename, "exec").get_identifiers()))
814

915

1016
def environ_bool(key: str, default: bool) -> bool:
@@ -24,19 +30,25 @@ class Context:
2430
full_source: str
2531
module: ast.Module
2632
namespace: int
27-
symtable_stack: list[symtable.SymbolTable]
33+
symtable: SymbolTable
2834
node_stack: list[ast.AST]
2935
decorator: ast.expr | None
3036
directive: Directive | None
3137

32-
def __init__(self, full_source: str, filename: str, module: ast.Module, namespace: int, params: Params) -> None:
38+
def __init__(
39+
self,
40+
full_source: str,
41+
filename: str,
42+
module: ast.Module,
43+
is_module: bool,
44+
params: Params,
45+
) -> None:
3346
self.params = params
3447
self.filename = filename
3548
self.module = module
36-
self.namespace = namespace
49+
self.is_module = is_module
3750
self.full_source = full_source
38-
self.symtable_stack = []
51+
self.symtable = SymbolTable()
3952
self.node_stack = [module]
4053
self.decorator = None
4154
self.directive = None
42-
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
import ast
2+
import re
3+
from collections.abc import Callable, KeysView
4+
from dataclasses import dataclass
5+
from typing import Optional
6+
7+
PREFIX: str = "_omp_"
8+
9+
var_check: re.Pattern[str] = re.compile(rf"^({PREFIX})?([0-9]+)?(.+)$")
10+
11+
12+
def is_variable(name: str) -> bool:
13+
return not name.startswith(PREFIX) or name[len(PREFIX) : len(PREFIX) + 1].isdigit()
14+
15+
16+
@dataclass
17+
class SymbolEntry:
18+
scope_name: str
19+
old_name: str
20+
used: bool = False
21+
assigned: bool = False
22+
annotation: ast.expr | None = None
23+
24+
25+
class SymbolTableVisitor(ast.NodeVisitor):
26+
symbols: dict[str, SymbolEntry]
27+
check_namespace: bool
28+
to_rename: set[str]
29+
30+
def __init__(self, global_vars: list[str] | None):
31+
self.to_rename = set()
32+
self.check_namespace = False
33+
self.symbols = {}
34+
if global_vars is not None:
35+
name: str
36+
for name in global_vars:
37+
self.update_symbol(name, lambda x: None).assigned = True
38+
39+
def update_symbol(self, name: str, rename: Callable[[str], None]) -> SymbolEntry:
40+
match: re.Match[str] | None = var_check.match(name)
41+
if not match:
42+
return SymbolEntry("", "")
43+
omp: str | None = match.group(1)
44+
n: str | None = match.group(2)
45+
real_name: str = match.group(3)
46+
if omp is not None and n is None:
47+
return SymbolEntry("", "") # internal _omp_ var
48+
49+
symbol: SymbolEntry
50+
if real_name in self.symbols:
51+
symbol = self.symbols[real_name]
52+
else:
53+
old_name: str = real_name if n is None else (PREFIX + str(int(n) - 1) + real_name)
54+
symbol = self.symbols[real_name] = SymbolEntry(name, old_name)
55+
56+
if real_name in self.to_rename:
57+
new_name: str = PREFIX + str(1 if n is None else (int(n) + 1)) + real_name
58+
symbol.old_name = symbol.scope_name
59+
symbol.scope_name = new_name
60+
rename(new_name)
61+
return symbol
62+
63+
def check(self, node: ast.AST, namespace: bool = False) -> None:
64+
self.check_namespace = False
65+
self.visit(node)
66+
if namespace:
67+
self.check_namespace = namespace
68+
super().generic_visit(node)
69+
70+
def rename(self, names: set[str], node: ast.AST) -> None:
71+
self.to_rename = names
72+
self.visit(node)
73+
74+
def visit(self, node: ast.AST) -> None:
75+
super().visit(node)
76+
if len(self.to_rename) > 0 or (
77+
self.check_namespace and not isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef))
78+
):
79+
super().generic_visit(node)
80+
81+
def generic_visit(self, node: ast.AST):
82+
pass
83+
84+
def visit_ClassDef(self, node: ast.ClassDef) -> None:
85+
self.update_symbol(node.name, lambda x: setattr(node, "name", x)).assigned = True
86+
87+
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
88+
self.update_symbol(node.name, lambda x: setattr(node, "name", x)).assigned = True
89+
90+
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
91+
self.update_symbol(node.name, lambda x: setattr(node, "name", x)).assigned = True
92+
93+
def visit_Import(self, node: ast.Import) -> None:
94+
alias: ast.alias
95+
for alias in node.names:
96+
name: str = alias.name if alias.asname is None else alias.asname
97+
self.update_symbol(name, lambda x: setattr(node, "asname", x)).assigned = True
98+
99+
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
100+
alias: ast.alias
101+
for alias in node.names:
102+
name: str = alias.name if alias.asname is None else alias.asname
103+
self.update_symbol(name, lambda x: setattr(alias, "asname", x)).assigned = True # noqa: B023
104+
105+
def visit_Global(self, node: ast.Global) -> None:
106+
i: int
107+
name: str
108+
for i, name in enumerate(node.names):
109+
self.update_symbol(name, lambda x: node.names.__setitem__(i, x)).assigned = True # noqa: B023
110+
111+
def visit_Nonlocal(self, node: ast.Nonlocal) -> None:
112+
i: int
113+
name: str
114+
for i, name in enumerate(node.names):
115+
self.update_symbol(name, lambda x: node.names.__setitem__(i, x)).assigned = True # noqa: B023
116+
117+
def visit_Name(self, node: ast.Name) -> None:
118+
if isinstance(node.ctx, ast.Load):
119+
self.update_symbol(node.id, lambda x: setattr(node, "id", x)).used = True
120+
else:
121+
self.update_symbol(node.id, lambda x: setattr(node, "id", x)).assigned = True
122+
123+
def visit_arg(self, node: ast.arg) -> None:
124+
s: SymbolEntry = self.update_symbol(node.arg, lambda x: setattr(node, "arg", x))
125+
s.assigned = True
126+
s.annotation = node.annotation
127+
128+
def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
129+
if isinstance(node.target, ast.Name):
130+
self.update_symbol(node.target.id, lambda x: setattr(node.target, "id", x)).annotation = node.annotation
131+
132+
133+
class SymbolTable:
134+
_parent: Optional["SymbolTable"]
135+
_visitor: SymbolTableVisitor
136+
137+
def __init__(self, global_vars: list[str] | None = None):
138+
self._visitor = SymbolTableVisitor(global_vars)
139+
140+
def check_namespace(self, node: ast.AST) -> "SymbolTable":
141+
child: SymbolTable = self.new_child()
142+
child._visitor.check(node, namespace=True)
143+
return child
144+
145+
def update(self, node: ast.AST) -> None:
146+
self._visitor.check(node)
147+
148+
def new_child(self) -> "SymbolTable":
149+
child: SymbolTable = SymbolTable()
150+
child._parent = self
151+
return child
152+
153+
def rename(self, names: set[str], node: ast.AST) -> None:
154+
self._visitor.rename(names, node)
155+
156+
def symbols(self) -> list[SymbolEntry]:
157+
return list(self._visitor.symbols.values())
158+
159+
def identifiers(self) -> KeysView[str]:
160+
return self._visitor.symbols.keys()
161+
162+
def __getattr__(self, name: str) -> SymbolEntry:
163+
return self._visitor.symbols[name]
164+
165+
def __contains__(self, name: str) -> bool:
166+
return name in self._visitor.symbols

0 commit comments

Comments
 (0)