Skip to content

Commit 866ab5b

Browse files
authored
35 Add tests for new optional (|) syntax (#36)
* add tests for new optional syntax (`|`) * ignore type warnings * ignore type error * check for `UnionType`
1 parent 4b473b9 commit 866ab5b

File tree

3 files changed

+157
-50
lines changed

3 files changed

+157
-50
lines changed

example_script.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,16 @@ def raise_error():
139139
raise ValueError("Something went wrong!")
140140

141141

142+
def optional_value(value: Optional[int] = None):
143+
"""
144+
A command which accepts an optional value.
145+
"""
146+
if value:
147+
print(value + 1)
148+
else:
149+
print("Unknown")
150+
151+
142152
if __name__ == "__main__":
143153
cli = CLI()
144154
cli.register(say_hello)
@@ -152,4 +162,5 @@ def raise_error():
152162
cli.register(add, command_name="sum")
153163
cli.register(print_address)
154164
cli.register(raise_error)
165+
cli.register(optional_value)
155166
cli.run()

targ/__init__.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,16 @@
1414

1515
from .format import Color, format_text, get_underline
1616

17+
# Only available in Python 3.10 and above:
18+
try:
19+
from types import NoneType, UnionType # type: ignore
20+
except ImportError:
21+
NoneType = type(None) # type: ignore
22+
23+
class UnionType: # type: ignore
24+
pass
25+
26+
1727
__VERSION__ = "0.6.0"
1828

1929

@@ -210,10 +220,10 @@ def call_with(self, arg_class: Arguments):
210220

211221
if annotation in CONVERTABLE_TYPES:
212222
value = annotation(value)
213-
elif get_origin(annotation) is Union: # type: ignore
223+
elif get_origin(annotation) in [Union, UnionType]: # type: ignore
214224
# Union is used to detect Optional
215225
inner_annotations = get_args(annotation)
216-
filtered = [i for i in inner_annotations if i is not None]
226+
filtered = [i for i in inner_annotations if i is not NoneType]
217227
if len(filtered) == 1:
218228
annotation = filtered[0]
219229
if annotation in CONVERTABLE_TYPES:

tests/test_command.py

Lines changed: 134 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import dataclasses
22
import decimal
33
import sys
4-
from typing import Any, Optional
4+
from typing import Any, Optional, Union
55
from unittest import TestCase
66
from unittest.mock import MagicMock, patch
77

@@ -206,64 +206,89 @@ def test_command(arg1: bool = False):
206206
@patch("targ.CLI._get_cleaned_args")
207207
def test_optional_bool_arg(self, _get_cleaned_args: MagicMock):
208208
"""
209-
Test command arguments which are of type Optional[bool].
209+
Test command arguments which are optional booleans.
210210
"""
211211

212-
def test_command(arg1: Optional[bool] = None):
213-
"""
214-
A command for testing optional boolean arguments.
215-
"""
216-
if arg1 is None:
212+
def print_arg(arg):
213+
if arg is None:
217214
print("arg1 is None")
218-
elif arg1 is True:
215+
elif arg is True:
219216
print("arg1 is True")
220-
elif arg1 is False:
217+
elif arg is False:
221218
print("arg1 is False")
222219
else:
223220
raise ValueError("arg1 is the wrong type")
224221

225-
cli = CLI()
226-
cli.register(test_command)
222+
def test_optional(arg1: Optional[bool] = None):
223+
"""
224+
A command for testing `Optional[bool]` arguments.
225+
"""
226+
print_arg(arg1)
227227

228-
with patch("builtins.print", side_effect=print_) as print_mock:
228+
def test_union(arg1: Union[bool, None] = None):
229+
"""
230+
A command for testing `Union[bool, None]` arguments.
231+
"""
232+
print_arg(arg1)
229233

230-
configs: list[Config] = [
231-
Config(
232-
params=["test_command", "--arg1"],
233-
output="arg1 is True",
234-
),
235-
Config(
236-
params=["test_command", "--arg1=True"],
237-
output="arg1 is True",
238-
),
239-
Config(
240-
params=["test_command", "--arg1=true"],
241-
output="arg1 is True",
242-
),
243-
Config(
244-
params=["test_command", "--arg1=t"],
245-
output="arg1 is True",
246-
),
247-
Config(
248-
params=["test_command", "--arg1=False"],
249-
output="arg1 is False",
250-
),
251-
Config(
252-
params=["test_command", "--arg1=false"],
253-
output="arg1 is False",
254-
),
255-
Config(
256-
params=["test_command", "--arg1=f"],
257-
output="arg1 is False",
258-
),
259-
Config(params=["test_command"], output="arg1 is None"),
260-
]
234+
commands = [test_optional, test_union]
261235

262-
for config in configs:
263-
_get_cleaned_args.return_value = config.params
264-
cli.run()
265-
print_mock.assert_called_with(config.output)
266-
print_mock.reset_mock()
236+
if sys.version_info.major == 3 and sys.version_info.minor >= 10:
237+
238+
def test_union_syntax(arg1: bool | None = None): # type: ignore
239+
"""
240+
A command for testing `bool | None` arguments.
241+
"""
242+
print_arg(arg1)
243+
244+
commands.append(test_union_syntax)
245+
246+
cli = CLI()
247+
248+
for command in commands:
249+
cli.register(command)
250+
251+
with patch("builtins.print", side_effect=print_) as print_mock:
252+
for command in commands:
253+
command_name = command.__name__
254+
255+
configs: list[Config] = [
256+
Config(
257+
params=[command_name, "--arg1"],
258+
output="arg1 is True",
259+
),
260+
Config(
261+
params=[command_name, "--arg1=True"],
262+
output="arg1 is True",
263+
),
264+
Config(
265+
params=[command_name, "--arg1=true"],
266+
output="arg1 is True",
267+
),
268+
Config(
269+
params=[command_name, "--arg1=t"],
270+
output="arg1 is True",
271+
),
272+
Config(
273+
params=[command_name, "--arg1=False"],
274+
output="arg1 is False",
275+
),
276+
Config(
277+
params=[command_name, "--arg1=false"],
278+
output="arg1 is False",
279+
),
280+
Config(
281+
params=[command_name, "--arg1=f"],
282+
output="arg1 is False",
283+
),
284+
Config(params=[command_name], output="arg1 is None"),
285+
]
286+
287+
for config in configs:
288+
_get_cleaned_args.return_value = config.params
289+
cli.run()
290+
print_mock.assert_called_with(config.output)
291+
print_mock.reset_mock()
267292

268293
@patch("targ.CLI._get_cleaned_args")
269294
def test_int_arg(self, _get_cleaned_args: MagicMock):
@@ -302,6 +327,67 @@ def test_command(arg1: decimal.Decimal):
302327
print_mock.assert_called_with(config.output)
303328
print_mock.reset_mock()
304329

330+
@patch("targ.CLI._get_cleaned_args")
331+
def test_optional_int_arg(self, _get_cleaned_args: MagicMock):
332+
"""
333+
Test command arguments which are optional int.
334+
"""
335+
336+
def print_arg(arg):
337+
if arg is None:
338+
print("arg1 is None")
339+
elif isinstance(arg, int):
340+
print("arg1 is an int")
341+
else:
342+
raise ValueError("arg1 is the wrong type")
343+
344+
def test_optional(arg1: Optional[int] = None):
345+
"""
346+
A command for testing `Optional[int]` arguments.
347+
"""
348+
print_arg(arg1)
349+
350+
def test_union(arg1: Union[int, None] = None):
351+
"""
352+
A command for testing `Union[int, None]` arguments.
353+
"""
354+
print_arg(arg1)
355+
356+
commands = [test_optional, test_union]
357+
358+
if sys.version_info.major == 3 and sys.version_info.minor >= 10:
359+
360+
def test_union_syntax(arg1: int | None = None): # type: ignore
361+
"""
362+
A command for testing `int | None` arguments.
363+
"""
364+
print_arg(arg1)
365+
366+
commands.append(test_union_syntax)
367+
368+
cli = CLI()
369+
370+
for command in commands:
371+
cli.register(command)
372+
373+
with patch("builtins.print", side_effect=print_) as print_mock:
374+
for command in commands:
375+
command_name = command.__name__
376+
377+
configs: list[Config] = [
378+
Config(
379+
params=[command_name, "--arg1=1"],
380+
output="arg1 is an int",
381+
),
382+
Config(params=[command_name], output="arg1 is None"),
383+
]
384+
385+
for config in configs:
386+
_get_cleaned_args.return_value = config.params
387+
cli.run()
388+
print_mock.assert_called_with(config.output)
389+
print_mock.reset_mock()
390+
305391
@patch("targ.CLI._get_cleaned_args")
306392
def test_decimal_arg(self, _get_cleaned_args: MagicMock):
307393
"""

0 commit comments

Comments
 (0)