|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from typing import List |
| 4 | + |
| 5 | +import pytest |
| 6 | + |
| 7 | +from jsonargparse import Namespace |
| 8 | +from jsonargparse._optionals import attrs_support |
| 9 | +from jsonargparse_tests.conftest import get_parser_help |
| 10 | + |
| 11 | +if attrs_support: |
| 12 | + import attrs |
| 13 | + |
| 14 | + @attrs.define |
| 15 | + class AttrsData: |
| 16 | + p1: float |
| 17 | + p2: str = "-" |
| 18 | + |
| 19 | + @attrs.define |
| 20 | + class AttrsSubData(AttrsData): |
| 21 | + p3: int = 3 |
| 22 | + |
| 23 | + @attrs.define |
| 24 | + class AttrsFieldFactory: |
| 25 | + p1: List[str] = attrs.field(factory=lambda: ["one", "two"]) |
| 26 | + |
| 27 | + @attrs.define |
| 28 | + class AttrsFieldInitFalse: |
| 29 | + p1: dict = attrs.field(init=False) |
| 30 | + |
| 31 | + def __attrs_post_init__(self): |
| 32 | + self.p1 = {} |
| 33 | + |
| 34 | + @attrs.define |
| 35 | + class AttrsSubField: |
| 36 | + p1: str = "-" |
| 37 | + p2: int = 0 |
| 38 | + |
| 39 | + @attrs.define |
| 40 | + class AttrsWithNestedDefaultDataclass: |
| 41 | + p1: float |
| 42 | + subfield: AttrsSubField = attrs.field(factory=AttrsSubField) |
| 43 | + |
| 44 | + @attrs.define |
| 45 | + class AttrsWithNestedDataclassNoDefault: |
| 46 | + p1: float |
| 47 | + subfield: AttrsSubField |
| 48 | + |
| 49 | + |
| 50 | +@pytest.mark.skipif(not attrs_support, reason="attrs package is required") |
| 51 | +class TestAttrs: |
| 52 | + def test_define(self, parser): |
| 53 | + parser.add_argument("--data", type=AttrsData) |
| 54 | + defaults = parser.get_defaults() |
| 55 | + assert Namespace(p1=None, p2="-") == defaults.data |
| 56 | + cfg = parser.parse_args(["--data.p1=0.2", "--data.p2=x"]) |
| 57 | + assert Namespace(p1=0.2, p2="x") == cfg.data |
| 58 | + |
| 59 | + def test_subclass(self, parser): |
| 60 | + parser.add_argument("--data", type=AttrsSubData) |
| 61 | + defaults = parser.get_defaults() |
| 62 | + assert Namespace(p1=None, p2="-", p3=3) == defaults.data |
| 63 | + |
| 64 | + def test_field_factory(self, parser): |
| 65 | + parser.add_argument("--data", type=AttrsFieldFactory) |
| 66 | + cfg1 = parser.parse_args([]) |
| 67 | + cfg2 = parser.parse_args([]) |
| 68 | + assert cfg1.data.p1 == ["one", "two"] |
| 69 | + assert cfg1.data.p1 == cfg2.data.p1 |
| 70 | + assert cfg1.data.p1 is not cfg2.data.p1 |
| 71 | + |
| 72 | + def test_field_init_false(self, parser): |
| 73 | + parser.add_argument("--data", type=AttrsFieldInitFalse) |
| 74 | + cfg = parser.parse_args([]) |
| 75 | + help_str = get_parser_help(parser) |
| 76 | + assert "--data.p1" not in help_str |
| 77 | + assert cfg == Namespace() |
| 78 | + init = parser.instantiate_classes(cfg) |
| 79 | + assert init.data.p1 == {} |
| 80 | + |
| 81 | + def test_nested_with_default(self, parser): |
| 82 | + parser.add_argument("--data", type=AttrsWithNestedDefaultDataclass) |
| 83 | + cfg = parser.parse_args(["--data.p1=1.23"]) |
| 84 | + assert cfg.data == Namespace(p1=1.23, subfield=Namespace(p1="-", p2=0)) |
| 85 | + |
| 86 | + def test_nested_without_default(self, parser): |
| 87 | + parser.add_argument("--data", type=AttrsWithNestedDataclassNoDefault) |
| 88 | + cfg = parser.parse_args(["--data.p1=1.23"]) |
| 89 | + assert cfg.data == Namespace(p1=1.23, subfield=Namespace(p1="-", p2=0)) |
0 commit comments