Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit 457b604

Browse files
aghinsapdxjohnny
authored andcommitted
model: scratch: Use auto args and config
1 parent 05b0450 commit 457b604

File tree

1 file changed

+12
-47
lines changed
  • model/scratch/dffml_model_scratch

1 file changed

+12
-47
lines changed

model/scratch/dffml_model_scratch/slr.py

Lines changed: 12 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import numpy as np
1313

1414
from dffml.repo import Repo
15+
from dffml.base import config, field
1516
from dffml.source.source import Sources
1617
from dffml.feature import Features
1718
from dffml.accuracy import Accuracy
@@ -22,10 +23,16 @@
2223
from dffml.util.cli.parser import list_action
2324

2425

25-
class SLRConfig(ModelConfig, NamedTuple):
26-
predict: str
27-
directory: str
28-
features: Features
26+
@config
27+
class SLRConfig:
28+
predict: str = field("Label or the value to be predicted")
29+
features: Features = field("Features to train on")
30+
directory: str = field(
31+
"Directory where state should be saved",
32+
default=os.path.join(
33+
os.path.expanduser("~"), ".cache", "dffml", "scratch"
34+
),
35+
)
2936

3037

3138
class SLRContext(ModelContext):
@@ -193,6 +200,7 @@ class SLR(Model):
193200
"""
194201

195202
CONTEXT = SLRContext
203+
CONFIG = SLRConfig
196204

197205
def __init__(self, config: SLRConfig) -> None:
198206
super().__init__(config)
@@ -215,46 +223,3 @@ async def __aexit__(self, exc_type, exc_value, traceback):
215223
filename = self._filename()
216224
with open(filename, "w") as write:
217225
json.dump(self.saved, write)
218-
219-
@classmethod
220-
def args(cls, args, *above) -> Dict[str, Arg]:
221-
cls.config_set(
222-
args,
223-
above,
224-
"directory",
225-
Arg(
226-
default=os.path.join(
227-
os.path.expanduser("~"), ".cache", "dffml", "scratch"
228-
),
229-
help="Directory where state should be saved",
230-
),
231-
)
232-
233-
cls.config_set(
234-
args,
235-
above,
236-
"predict",
237-
Arg(type=str, help="Label or the value to be predicted"),
238-
)
239-
240-
cls.config_set(
241-
args,
242-
above,
243-
"features",
244-
Arg(
245-
nargs="+",
246-
required=True,
247-
type=Feature.load,
248-
action=list_action(Features),
249-
help="Features to train on",
250-
),
251-
)
252-
return args
253-
254-
@classmethod
255-
def config(cls, config, *above) -> "SLRConfig":
256-
return SLRConfig(
257-
directory=cls.config_get(config, above, "directory"),
258-
predict=cls.config_get(config, above, "predict"),
259-
features=cls.config_get(config, above, "features"),
260-
)

0 commit comments

Comments
 (0)