Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit db65fb9

Browse files
aghinsapdxjohnny
authored andcommitted
model: Use auto args and config
1 parent a7deb79 commit db65fb9

File tree

3 files changed

+26
-41
lines changed

3 files changed

+26
-41
lines changed

dffml/base.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
they follow a similar API for instantiation and usage.
44
"""
55
import abc
6+
import copy
67
import inspect
78
import argparse
89
import contextlib
@@ -22,7 +23,7 @@ def get_args(t):
2223

2324

2425
from .util.cli.arg import Arg
25-
from .util.data import traverse_config_set, traverse_config_get
26+
from .util.data import traverse_config_set, traverse_config_get, type_lookup
2627

2728
from .util.entrypoint import Entrypoint
2829

@@ -114,6 +115,8 @@ def mkarg(field):
114115
# HACK For detecting dataclasses._MISSING_TYPE
115116
if "dataclasses._MISSING_TYPE" not in repr(field.default):
116117
arg["default"] = field.default
118+
if "dataclasses._MISSING_TYPE" not in repr(field.default_factory):
119+
arg["default"] = field.default_factory()
117120
if field.type == bool:
118121
arg["action"] = "store_true"
119122
elif inspect.isclass(field.type):
@@ -140,18 +143,20 @@ def convert_value(arg, value):
140143
if value is None:
141144
# Return default if not found and available
142145
if "default" in arg:
143-
return arg["default"]
146+
return copy.deepcopy(arg["default"])
144147
raise MissingConfig
145148

146-
# TODO This is a oversimplification of argparse's nargs
147149
if not "nargs" in arg:
148150
value = value[0]
149151
if "type" in arg:
152+
type_cls = arg["type"]
153+
if type_cls == Type:
154+
type_cls = type_lookup
150155
# TODO This is a oversimplification of argparse's nargs
151156
if "nargs" in arg:
152-
value = list(map(arg["type"], value))
157+
value = list(map(type_cls, value))
153158
else:
154-
value = arg["type"](value)
159+
value = type_cls(value)
155160
if "action" in arg:
156161
if isinstance(arg["action"], str):
157162
# HACK This accesses _pop_action_class from ArgumentParser
@@ -205,7 +210,7 @@ def config(cls):
205210
"""
206211
Decorator to create a dataclass
207212
"""
208-
datacls = dataclasses.dataclass(eq=True, init=True, frozen=True)(cls)
213+
datacls = dataclasses.dataclass(eq=True, init=True)(cls)
209214
datacls._fromdict = classmethod(_fromdict)
210215
datacls._replace = lambda self, *args, **kwargs: dataclasses.replace(
211216
self, *args, **kwargs

dffml/model/model.py

Lines changed: 5 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import AsyncIterator, Tuple, Any, List, Optional, NamedTuple, Dict
1111

1212
from ..base import (
13+
config,
1314
BaseConfig,
1415
BaseDataFlowFacilitatorObjectContext,
1516
BaseDataFlowFacilitatorObject,
@@ -27,7 +28,8 @@ class ModelNotTrained(Exception):
2728
pass
2829

2930

30-
class ModelConfig(BaseConfig, NamedTuple):
31+
@config
32+
class ModelConfig:
3133
directory: str
3234
features: Features
3335

@@ -72,43 +74,12 @@ class Model(BaseDataFlowFacilitatorObject):
7274
various machine learning frameworks or concepts.
7375
"""
7476

77+
CONFIG = ModelConfig
78+
7579
def __call__(self) -> ModelContext:
7680
# If the config object for this model contains the directory property
7781
# then create it if it does not exist
7882
directory = getattr(self.config, "directory", None)
7983
if directory is not None and not os.path.isdir(directory):
8084
os.makedirs(directory)
8185
return self.CONTEXT(self)
82-
83-
@classmethod
84-
def args(cls, args, *above) -> Dict[str, Arg]:
85-
cls.config_set(
86-
args,
87-
above,
88-
"directory",
89-
Arg(
90-
default=os.path.join(
91-
os.path.expanduser("~"), ".cache", "dffml"
92-
)
93-
),
94-
)
95-
cls.config_set(
96-
args,
97-
above,
98-
"features",
99-
Arg(
100-
nargs="+",
101-
required=True,
102-
type=Feature.load,
103-
action=list_action(Features),
104-
),
105-
)
106-
107-
return args
108-
109-
@classmethod
110-
def config(cls, config, *above) -> BaseConfig:
111-
return ModelConfig(
112-
directory=cls.config_get(config, above, "directory"),
113-
features=cls.config_get(config, above, "features"),
114-
)

tests/test_cli.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from dffml.util.entrypoint import entry_point
3434
from dffml.util.asynctestcase import AsyncTestCase
3535
from dffml.util.cli.cmds import ModelCMD
36-
36+
from dffml.base import config
3737
from dffml.cli import Merge, Dataflow, Train, Accuracy, Predict, List
3838

3939
from .test_df import OPERATIONS, OPIMPS
@@ -94,6 +94,14 @@ def mktempfile(self):
9494
return self._stack.enter_context(non_existant_tempfile())
9595

9696

97+
@config
98+
class FakeConfig:
99+
features: Features
100+
directory: str = os.path.join(
101+
os.path.expanduser("~"), ".cache", "dffml", "test_cli", "fake"
102+
)
103+
104+
97105
class FakeFeature(Feature):
98106

99107
NAME: str = "fake"
@@ -134,6 +142,7 @@ async def predict(self, repos: AsyncIterator[Repo]) -> AsyncIterator[Repo]:
134142
class FakeModel(Model):
135143

136144
CONTEXT = FakeModelContext
145+
CONFIG = FakeConfig
137146

138147

139148
def feature_load(loading=None):

0 commit comments

Comments
 (0)