Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit 5b58db2

Browse files
committed
base: Configurable implements args, config methods
Signed-off-by: John Andersen <[email protected]>
1 parent 4cdd4a3 commit 5b58db2

File tree

4 files changed

+287
-19
lines changed

4 files changed

+287
-19
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99
- Features were moved from ModelContext to ModelConfig
1010
- CI is now run via GitHub Actions
1111
- CI testing script is now verbose
12+
- args and config methods of all classes no longer require implementation.
13+
BaseConfigurable handles exporting of arguments and creation of config objects
14+
for each class based off of the CONFIG property of that class. The CONFIG
15+
property is a class which has been decorated with dffml.base.config to make it
16+
a dataclass.
1217
### Fixed
1318
- DataFlows with multiple possibilities for a source for an input, now correctly
1419
look through all possible sources instead of just the first one.

dffml/base.py

Lines changed: 128 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,58 @@
44
"""
55
import abc
66
import inspect
7+
import argparse
8+
import contextlib
9+
import dataclasses
710
from argparse import ArgumentParser
8-
from typing import Dict, Any, Tuple, NamedTuple
11+
from typing import Dict, Any, Tuple, NamedTuple, Type
912

13+
try:
14+
from typing import get_origin, get_args
15+
except ImportError:
16+
# Added in Python 3.8
17+
def get_origin(t):
18+
return getattr(t, "__origin__", None)
19+
20+
def get_args(t):
21+
return getattr(t, "__args__", None)
22+
23+
24+
from .util.cli.arg import Arg
1025
from .util.data import traverse_config_set, traverse_config_get
1126

1227
from .util.entrypoint import Entrypoint
1328

1429
from .log import LOGGER
1530

1631

32+
class ParseExpandAction(argparse.Action):
33+
def __call__(self, parser, namespace, values, option_string=None):
34+
if not isinstance(values, list):
35+
values = [values]
36+
setattr(namespace, self.dest, self.LIST_CLS(*values))
37+
38+
39+
# Maps classes to their ParseClassNameAction
40+
LIST_ACTIONS: Dict[Type, Type] = {}
41+
42+
43+
def list_action(list_cls):
44+
"""
45+
Action to take a list of values and make them values in the list of type
46+
list_class. Which will be a class descendent from AsyncContextManagerList.
47+
"""
48+
LIST_ACTIONS.setdefault(
49+
list_cls,
50+
type(
51+
f"Parse{list_cls.__qualname__}Action",
52+
(ParseExpandAction,),
53+
{"LIST_CLS": list_cls},
54+
),
55+
)
56+
return LIST_ACTIONS[list_cls]
57+
58+
1759
class MissingArg(Exception):
1860
"""
1961
Raised when a BaseConfigurable is missing an argument from the args dict it
@@ -64,6 +106,13 @@ def __str__(self):
64106
return repr(self)
65107

66108

109+
def config(cls):
110+
"""
111+
Decorator to create a dataclass
112+
"""
113+
return dataclasses.dataclass(eq=True, frozen=True)(cls)
114+
115+
67116
class ConfigurableParsingNamespace(object):
68117
def __init__(self):
69118
self.dest = None
@@ -144,9 +193,14 @@ def config_get(cls, config, above, *path) -> BaseConfig:
144193
args_above = cls.add_orig_label() + list(path)
145194
label_above = cls.add_label(*above) + list(path)
146195
no_label_above = cls.add_label(*above)[:-1] + list(path)
196+
197+
arg = None
147198
try:
148199
arg = traverse_config_get(args, *args_above)
149200
except KeyError as error:
201+
pass
202+
203+
if arg is None:
150204
raise MissingArg(
151205
"Arg %r missing from %s%s%s"
152206
% (
@@ -155,23 +209,30 @@ def config_get(cls, config, above, *path) -> BaseConfig:
155209
"." if args_above[:-1] else "",
156210
".".join(args_above[:-1]),
157211
)
158-
) from error
159-
try:
212+
)
213+
214+
value = None
215+
# Try to get the value specific to this label
216+
with contextlib.suppress(KeyError):
160217
value = traverse_config_get(config, *label_above)
161-
except KeyError as error:
162-
try:
218+
219+
# Try to get the value specific to this plugin
220+
if value is None:
221+
with contextlib.suppress(KeyError):
163222
value = traverse_config_get(config, *no_label_above)
164-
except KeyError as error:
165-
if "default" in arg:
166-
return arg["default"]
167-
raise MissingConfig(
168-
"%s missing %r from %s"
169-
% (
170-
cls.__qualname__,
171-
label_above[-1],
172-
".".join(label_above[:-1]),
173-
)
174-
) from error
223+
224+
if value is None:
225+
# Return default if not found and available
226+
if "default" in arg:
227+
return arg["default"]
228+
raise MissingConfig(
229+
"%s missing %r from %s"
230+
% (
231+
cls.__qualname__,
232+
label_above[-1],
233+
".".join(label_above[:-1]),
234+
)
235+
)
175236

176237
if value is None and "default" in arg:
177238
return arg["default"]
@@ -197,19 +258,67 @@ def config_get(cls, config, above, *path) -> BaseConfig:
197258
return value
198259

199260
@classmethod
200-
@abc.abstractmethod
201-
def args(cls, *above) -> Dict[str, Any]:
261+
def args(cls, args, *above) -> Dict[str, Arg]:
202262
"""
203263
Return a dict containing arguments required for this class
204264
"""
265+
if getattr(cls, "CONFIG", None) is None:
266+
raise AttributeError(
267+
f"{cls.__qualname__} requires CONFIG property or implementation of args() classmethod"
268+
)
269+
for field in dataclasses.fields(cls.CONFIG):
270+
arg = Arg(type=field.type)
271+
# HACK For detecting dataclasses._MISSING_TYPE
272+
if "dataclasses._MISSING_TYPE" not in repr(field.default):
273+
arg["default"] = field.default
274+
if field.type == bool:
275+
arg["action"] = "store_true"
276+
elif inspect.isclass(field.type):
277+
if issubclass(field.type, list):
278+
arg["nargs"] = "+"
279+
if not hasattr(field.type, "SINGLETON"):
280+
raise AttributeError(
281+
f"{field.type.__qualname__} missing attribute SINGLETON"
282+
)
283+
arg["action"] = list_action(field.type)
284+
arg["type"] = field.type.SINGLETON
285+
if hasattr(arg["type"], "load"):
286+
# TODO (python3.8) Use Protocol
287+
arg["type"] = arg["type"].load
288+
elif get_origin(field.type) is list:
289+
arg["type"] = get_args(field.type)[0]
290+
arg["nargs"] = "+"
291+
if "help" in field.metadata:
292+
arg["help"] = field.metadata["help"]
293+
cls.config_set(args, above, field.name, arg)
294+
return args
205295

206296
@classmethod
207-
@abc.abstractmethod
208297
def config(cls, config, *above):
209298
"""
210299
Create the BaseConfig required to instantiate this class by parsing the
211300
config dict.
212301
"""
302+
if getattr(cls, "CONFIG", None) is None:
303+
raise AttributeError(
304+
f"{cls.__qualname__} requires CONFIG property or implementation of config() classmethod"
305+
)
306+
# Build the arguments to the CONFIG class
307+
kwargs: Dict[str, Any] = {}
308+
for field in dataclasses.fields(cls.CONFIG):
309+
kwargs[field.name] = got = cls.config_get(
310+
config, above, field.name
311+
)
312+
if inspect.isclass(got) and issubclass(got, BaseConfigurable):
313+
try:
314+
kwargs[field.name] = got.withconfig(
315+
config, *above, *cls.add_label()
316+
)
317+
except MissingConfig:
318+
kwargs[field.name] = got.withconfig(
319+
config, *above, *cls.add_label()[:-1]
320+
)
321+
return cls.CONFIG(**kwargs)
213322

214323
@classmethod
215324
def withconfig(cls, config, *above):

dffml/feature/feature.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ def length(self) -> int:
301301
class Features(list):
302302

303303
TIMEOUT: int = 60 * 2
304+
SINGLETON = Feature
304305

305306
LOGGER = LOGGER.getChild("Features")
306307

tests/test_base.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
import unittest
2+
from dataclasses import field
3+
from typing import List
4+
5+
from dffml.base import BaseDataFlowFacilitatorObject, config, list_action
6+
from dffml.feature.feature import DefFeature, Feature, Features
7+
from dffml.source.source import BaseSource
8+
from dffml.source.csv import CSVSource
9+
from dffml.source.json import JSONSource
10+
from dffml.util.entrypoint import entry_point, base_entry_point
11+
from dffml.util.cli.arg import Arg
12+
from dffml.util.cli.cmd import parse_unknown
13+
14+
15+
@config
16+
class FakeTestingConfig:
17+
files: List[str]
18+
features: Features
19+
name: str = field(metadata={"help": "Name of FakeTesting"})
20+
label: str = "unlabeled"
21+
readonly: bool = False
22+
source: BaseSource = JSONSource
23+
24+
25+
@base_entry_point("dffml.test", "test")
26+
class BaseTesting(BaseDataFlowFacilitatorObject):
27+
pass # pragma: no cov
28+
29+
30+
@entry_point("fake")
31+
class FakeTesting(BaseTesting):
32+
33+
CONFIG = FakeTestingConfig
34+
35+
36+
class TestAutoArgsConfig(unittest.TestCase):
37+
def test_00_args(self):
38+
self.maxDiff = 99999
39+
self.assertEqual(
40+
FakeTesting.args({}),
41+
{
42+
"test": {
43+
"arg": None,
44+
"config": {
45+
"fake": {
46+
"arg": None,
47+
"config": {
48+
"files": {
49+
"arg": Arg(type=str, nargs="+"),
50+
"config": {},
51+
},
52+
"features": {
53+
"arg": Arg(
54+
type=Feature.load,
55+
nargs="+",
56+
action=list_action(Features),
57+
),
58+
"config": {},
59+
},
60+
"name": {
61+
"arg": Arg(
62+
type=str, help="Name of FakeTesting"
63+
),
64+
"config": {},
65+
},
66+
"readonly": {
67+
"arg": Arg(
68+
type=bool,
69+
action="store_true",
70+
default=False,
71+
),
72+
"config": {},
73+
},
74+
"label": {
75+
"arg": Arg(type=str, default="unlabeled"),
76+
"config": {},
77+
},
78+
"source": {
79+
"arg": Arg(
80+
type=BaseSource.load,
81+
default=JSONSource,
82+
),
83+
"config": {},
84+
},
85+
},
86+
}
87+
},
88+
}
89+
},
90+
)
91+
92+
def test_config_defaults(self):
93+
config = FakeTesting.config(
94+
parse_unknown(
95+
"--test-fake-name",
96+
"feedface",
97+
"--test-files",
98+
"a",
99+
"b",
100+
"c",
101+
"--test-source-filename",
102+
"file.json",
103+
"--test-features",
104+
"def:Year:int:1",
105+
"def:Commits:int:10",
106+
)
107+
)
108+
self.assertEqual(config.files, ["a", "b", "c"])
109+
self.assertEqual(config.name, "feedface")
110+
self.assertEqual(config.label, "unlabeled")
111+
self.assertFalse(config.readonly)
112+
self.assertTrue(isinstance(config.source, JSONSource))
113+
self.assertEqual(config.source.config.filename, "file.json")
114+
self.assertEqual(
115+
config.features,
116+
Features(
117+
DefFeature("Year", int, 1), DefFeature("Commits", int, 10)
118+
),
119+
)
120+
121+
def test_config_set(self):
122+
config = FakeTesting.config(
123+
parse_unknown(
124+
"--test-fake-name",
125+
"feedface",
126+
"--test-fake-label",
127+
"default-label",
128+
"--test-fake-readonly",
129+
"--test-files",
130+
"a",
131+
"b",
132+
"c",
133+
"--test-fake-source",
134+
"csv",
135+
"--test-source-filename",
136+
"file.csv",
137+
"--test-features",
138+
"def:Year:int:1",
139+
"def:Commits:int:10",
140+
)
141+
)
142+
self.assertEqual(config.files, ["a", "b", "c"])
143+
self.assertEqual(config.name, "feedface")
144+
self.assertEqual(config.label, "default-label")
145+
self.assertTrue(config.readonly)
146+
self.assertTrue(isinstance(config.source, CSVSource))
147+
self.assertEqual(config.source.config.filename, "file.csv")
148+
self.assertEqual(
149+
config.features,
150+
Features(
151+
DefFeature("Year", int, 1), DefFeature("Commits", int, 10)
152+
),
153+
)

0 commit comments

Comments
 (0)