Skip to content

Commit 04b7ad4

Browse files
committed
[cdd/shared/parse/utils/parser_utils.py] Correctly use merge_params and fix its impl to return ; [cdd/sqlalchemy/utils/shared_utils.py] Prepare for increased test coverage ; [cdd/tests/{test_parse/test_parser_utils.py,test_sqlalchemy/test_emit_sqlalchemy_utils.py}] Increase test coverage ; [cdd/__init__.py] Bump version
1 parent 217849e commit 04b7ad4

File tree

5 files changed

+119
-60
lines changed

5 files changed

+119
-60
lines changed

cdd/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from logging import getLogger as get_logger
1010

1111
__author__ = "Samuel Marks" # type: str
12-
__version__ = "0.0.99rc43" # type: str
12+
__version__ = "0.0.99rc44" # type: str
1313
__description__ = (
1414
"Open API to/fro routes, models, and tests. "
1515
"Convert between docstrings, classes, methods, argparse, pydantic, and SQLalchemy."

cdd/shared/parse/utils/parser_utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,7 @@ def ir_merge(target, other):
6262
target["params"] = other["params"]
6363
elif other["params"]:
6464
target_params, other_params = map(itemgetter("params"), (target, other))
65-
66-
merge_params(other_params, target_params)
67-
68-
target["params"] = target_params
65+
target["params"] = merge_params(other_params, target_params)
6966

7067
if "return_type" not in (target.get("returns") or iter(())):
7168
target["returns"] = other["returns"]
@@ -110,6 +107,7 @@ def merge_params(other_params, target_params):
110107
merge_present_params(other_params[name], target_params[name])
111108
for name in other_params.keys() - target_params.keys():
112109
target_params[name] = other_params[name]
110+
return target_params
113111

114112

115113
def merge_present_params(other_param, target_param):

cdd/sqlalchemy/utils/shared_utils.py

Lines changed: 50 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
import ast
6-
from ast import Call, Expr, Load, Name, Subscript, Tuple, keyword
6+
from ast import Call, Expr, Load, Name, Subscript, Tuple, expr, keyword
77
from operator import attrgetter
88
from typing import Optional, cast
99

@@ -82,13 +82,14 @@ def update_args_infer_typ_sqlalchemy(_param, args, name, nullable, x_typ_sql):
8282
return _param.get("default") == cdd.shared.ast_utils.NoneStr, None
8383
elif _param["typ"].startswith("Optional["):
8484
_param["typ"] = _param["typ"][len("Optional[") : -1]
85-
nullable = True
85+
nullable: bool = True
8686
if "Literal[" in _param["typ"]:
8787
parsed_typ: Call = cast(
8888
Call, cdd.shared.ast_utils.get_value(ast.parse(_param["typ"]).body[0])
8989
)
90-
if parsed_typ.value.id != "Literal":
91-
return nullable, parsed_typ.value
90+
assert parsed_typ.value.id == "Literal", "Expected `Literal` got: {!r}".format(
91+
parsed_typ.value.id
92+
)
9293
val = cdd.shared.ast_utils.get_value(parsed_typ.slice)
9394
(
9495
args.append(
@@ -112,7 +113,7 @@ def update_args_infer_typ_sqlalchemy(_param, args, name, nullable, x_typ_sql):
112113
else _update_args_infer_typ_sqlalchemy_for_scalar(_param, args, x_typ_sql)
113114
)
114115
elif _param["typ"].startswith("List["):
115-
after_generic = _param["typ"][len("List[") :]
116+
after_generic: str = _param["typ"][len("List[") :]
116117
if "struct" in after_generic: # "," in after_generic or
117118
name: Name = Name(id="JSON", ctx=Load(), lineno=None, col_offset=None)
118119
else:
@@ -175,42 +176,53 @@ def update_args_infer_typ_sqlalchemy(_param, args, name, nullable, x_typ_sql):
175176
)
176177
)
177178
elif _param.get("typ").startswith("Union["):
178-
# Hack to remove the union type. Enum parse seems to be incorrect?
179-
union_typ: Subscript = cast(Subscript, ast.parse(_param["typ"]).body[0])
180-
assert isinstance(
181-
union_typ.value, Subscript
182-
), "Expected `Subscript` got `{type_name}`".format(
183-
type_name=type(union_typ.value).__name__
184-
)
185-
union_typ_tuple = (
186-
union_typ.value.slice if PY_GTE_3_9 else union_typ.value.slice.value
187-
)
188-
assert isinstance(
189-
union_typ_tuple, Tuple
190-
), "Expected `Tuple` got `{type_name}`".format(
191-
type_name=type(union_typ_tuple).__name__
192-
)
193-
assert (
194-
len(union_typ_tuple.elts) == 2
195-
), "Expected length of 2 got `{tuple_len}`".format(
196-
tuple_len=len(union_typ_tuple.elts)
197-
)
198-
left, right = map(attrgetter("id"), union_typ_tuple.elts)
199-
args.append(
200-
Name(
201-
(
202-
cdd.sqlalchemy.utils.emit_utils.typ2column_type[right]
203-
if right in cdd.sqlalchemy.utils.emit_utils.typ2column_type
204-
else cdd.sqlalchemy.utils.emit_utils.typ2column_type.get(left, left)
205-
),
206-
Load(),
207-
lineno=None,
208-
col_offset=None,
209-
)
210-
)
179+
args.append(_handle_union_of_length_2(_param["typ"]))
211180
else:
212181
_update_args_infer_typ_sqlalchemy_for_scalar(_param, args, x_typ_sql)
213182
return nullable, None
214183

215184

185+
def _handle_union_of_length_2(typ):
186+
"""
187+
Internal function to turn `str` to `Name`
188+
189+
:param typ: `str` which evaluates to `ast.Subscript`
190+
:type typ: ```str```
191+
192+
:return: Parsed out name
193+
:rtype: ```Name```
194+
"""
195+
# Hack to remove the union type. Enum parse seems to be incorrect?
196+
union_typ: Subscript = cast(Subscript, ast.parse(typ).body[0])
197+
assert isinstance(
198+
union_typ.value, Subscript
199+
), "Expected `Subscript` got `{type_name}`".format(
200+
type_name=type(union_typ.value).__name__
201+
)
202+
union_typ_tuple: expr = (
203+
union_typ.value.slice if PY_GTE_3_9 else union_typ.value.slice.value
204+
)
205+
assert isinstance(
206+
union_typ_tuple, Tuple
207+
), "Expected `Tuple` got `{type_name}`".format(
208+
type_name=type(union_typ_tuple).__name__
209+
)
210+
assert (
211+
len(union_typ_tuple.elts) == 2
212+
), "Expected length of 2 got `{tuple_len}`".format(
213+
tuple_len=len(union_typ_tuple.elts)
214+
)
215+
left, right = map(attrgetter("id"), union_typ_tuple.elts)
216+
return Name(
217+
(
218+
cdd.sqlalchemy.utils.emit_utils.typ2column_type[right]
219+
if right in cdd.sqlalchemy.utils.emit_utils.typ2column_type
220+
else cdd.sqlalchemy.utils.emit_utils.typ2column_type.get(left, left)
221+
),
222+
Load(),
223+
lineno=None,
224+
col_offset=None,
225+
)
226+
227+
216228
__all__ = ["update_args_infer_typ_sqlalchemy"]

cdd/tests/test_parse/test_parser_utils.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,23 @@
2727
class TestParserUtils(TestCase):
2828
"""Test class for parser_utils"""
2929

30+
def test_get_source_raises(self) -> None:
31+
"""Tests that `get_source` raises an exception"""
32+
with self.assertRaises(TypeError):
33+
get_source(None)
34+
35+
def raise_os_error(_):
36+
"""raise_OSError"""
37+
raise OSError
38+
39+
with patch("inspect.getsourcelines", raise_os_error), self.assertRaises(
40+
OSError
41+
):
42+
get_source(min)
43+
44+
with patch("inspect.getsourcefile", lambda _: None):
45+
self.assertIsNone(get_source(raise_os_error))
46+
3047
def test_ir_merge_empty(self) -> None:
3148
"""Tests for `ir_merge` when both are empty"""
3249
target = {"params": OrderedDict(), "returns": None}
@@ -250,22 +267,14 @@ def test_infer_raise(self) -> None:
250267
with self.assertRaises(NotImplementedError):
251268
cdd.shared.parse.utils.parser_utils.infer(None)
252269

253-
def test_get_source_raises(self) -> None:
254-
"""Tests that `get_source` raises an exception"""
255-
with self.assertRaises(TypeError):
256-
get_source(None)
257-
258-
def raise_os_error(_):
259-
"""raise_OSError"""
260-
raise OSError
261-
262-
with patch("inspect.getsourcelines", raise_os_error), self.assertRaises(
263-
OSError
264-
):
265-
get_source(min)
266-
267-
with patch("inspect.getsourcefile", lambda _: None):
268-
self.assertIsNone(get_source(raise_os_error))
270+
def test_merge_params(self) -> None:
271+
"""Tests `merge_params` works"""
272+
d0 = {"foo": "bar"}
273+
d1 = {"can": "haz"}
274+
self.assertDictEqual(
275+
cdd.shared.parse.utils.parser_utils.merge_params(deepcopy(d0), d1),
276+
{"foo": "bar", "can": "haz"},
277+
)
269278

270279

271280
unittest_main()

cdd/tests/test_sqlalchemy/test_emit_sqlalchemy_utils.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import ast
66
import json
77
from ast import (
8+
AST,
89
Assign,
910
Call,
1011
ClassDef,
@@ -19,8 +20,10 @@
1920
)
2021
from collections import OrderedDict
2122
from copy import deepcopy
23+
from functools import partial
2224
from os import mkdir, path
2325
from tempfile import TemporaryDirectory
26+
from typing import Callable, List, Optional, Tuple, Union
2427
from unittest import TestCase
2528
from unittest.mock import patch
2629

@@ -29,7 +32,10 @@
2932
from cdd.shared.ast_utils import set_value
3033
from cdd.shared.source_transformer import to_code
3134
from cdd.shared.types import IntermediateRepr
32-
from cdd.sqlalchemy.utils.shared_utils import update_args_infer_typ_sqlalchemy
35+
from cdd.sqlalchemy.utils.shared_utils import (
36+
_handle_union_of_length_2,
37+
update_args_infer_typ_sqlalchemy,
38+
)
3339
from cdd.tests.mocks.ir import (
3440
intermediate_repr_empty,
3541
intermediate_repr_no_default_doc,
@@ -296,6 +302,27 @@ def test_update_args_infer_typ_sqlalchemy_when_simple_array_in_typ(self) -> None
296302
# gold=Name(id="Small", ctx=Load(), lineno=None, col_offset=None),
297303
# )
298304

305+
def test_update_args_infer_typ_sqlalchemy_early_exit(self) -> None:
306+
"""Tests that `update_args_infer_typ_sqlalchemy` exits early"""
307+
_update_args_infer_typ_sqlalchemy: Callable[
308+
[dict], Tuple[bool, Optional[Union[List[AST], Tuple[AST]]]]
309+
] = partial(
310+
update_args_infer_typ_sqlalchemy,
311+
args=[],
312+
name="",
313+
nullable=True,
314+
x_typ_sql={},
315+
)
316+
self.assertTupleEqual(
317+
_update_args_infer_typ_sqlalchemy({"typ": None}), (False, None)
318+
)
319+
self.assertTupleEqual(
320+
_update_args_infer_typ_sqlalchemy(
321+
{"typ": None, "default": cdd.shared.ast_utils.NoneStr},
322+
),
323+
(True, None),
324+
)
325+
299326
def test_update_with_imports_from_columns(self) -> None:
300327
"""
301328
Tests basic `cdd.sqlalchemy.utils.emit_utils.update_with_imports_from_columns` usage
@@ -573,5 +600,18 @@ def test_rewrite_fk(self) -> None:
573600
gold=column_fk_gold,
574601
)
575602

603+
def test__handle_union_of_length_2(self) -> None:
604+
"""Tests that `_handle_union_of_length_2` works"""
605+
run_ast_test(
606+
self,
607+
gen_ast=_handle_union_of_length_2("Union[int, float]"),
608+
gold=Name(
609+
"Float",
610+
Load(),
611+
lineno=None,
612+
col_offset=None,
613+
),
614+
)
615+
576616

577617
unittest_main()

0 commit comments

Comments
 (0)