Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions semantikon/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import ast
import inspect
import keyword
import re
import sys
import textwrap
Expand Down Expand Up @@ -157,10 +158,17 @@ def _resolve_annotation(annotation, func_globals=None):
return Annotated[undefined_name, args[1]]


def _is_valid_variable_name(name: str) -> bool:
return name.isidentifier() and not keyword.iskeyword(name)


def _to_tag(item: Any, count=None, must_be_named: bool = False) -> str:
if isinstance(item, ast.Name):
return item.id
elif must_be_named:
elif isinstance(item, ast.Constant):
if isinstance(item.value, str) and _is_valid_variable_name(item.value):
return item.value
if must_be_named:
raise NotAstNameError(
"With `must_be_named=True`, item must be captured in an `ast.Name` "
"variables, i.e only simple variable(-s) not containing any operation or "
Expand All @@ -172,11 +180,10 @@ def _to_tag(item: Any, count=None, must_be_named: bool = False) -> str:
return f"output_{count}"


def get_return_expressions(
func: Callable, separate_tuple: bool = True, strict: bool = False
) -> str | tuple[str, ...] | None:
source = inspect.getsource(func)
source = textwrap.dedent(source)
def _get_return_list(
func: Callable, strict: bool = False
) -> list[str | tuple[str, ...]]:
source = textwrap.dedent(inspect.getsource(func))
parsed = ast.parse(source)

func_node = next(n for n in parsed.body if isinstance(n, ast.FunctionDef))
Expand All @@ -197,9 +204,24 @@ def get_return_expressions(
]
)
)
elif isinstance(value, ast.Dict):
ret_list.append(
tuple(
[
_to_tag(k, ii, must_be_named=strict)
for ii, k in enumerate(value.keys)
]
)
)
else:
ret_list.append(_to_tag(value, must_be_named=strict))
return ret_list


def get_return_expressions(
func: Callable, separate_tuple: bool = True, strict: bool = False
) -> str | tuple[str, ...] | None:
ret_list = _get_return_list(func, strict=strict)
if len(ret_list) == 0 and not strict:
return None
elif len(set(ret_list)) == 1 and (
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/test_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,16 @@ def f(x):
self.assertEqual(get_return_expressions(f, separate_tuple=False), "None")
self.assertEqual(get_return_expressions(f, strict=True), "None")

def f(x):
return {"z": x}

self.assertEqual(get_return_expressions(f), ("z",))

def f(x):
return {"z + 1": x}

self.assertEqual(get_return_expressions(f), ("output_0",))

def test_get_return_labels(self):

def f(x):
Expand Down
Loading