Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit 05b0450

Browse files
aghinsapdxjohnny
authored andcommitted
model: tensorflow: Use auto args and config
1 parent db65fb9 commit 05b0450

File tree

2 files changed

+40
-187
lines changed

2 files changed

+40
-187
lines changed

model/tensorflow/dffml_model_tensorflow/dnnc.py

Lines changed: 21 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import pydoc
88
import hashlib
99
import inspect
10-
from dataclasses import dataclass
1110
from typing import List, Dict, Any, AsyncIterator, Tuple, Optional, Type
1211

1312
import numpy as np
@@ -23,6 +22,7 @@
2322
from dffml.util.cli.arg import Arg
2423
from dffml.feature.feature import Feature, Features
2524
from dffml.util.cli.parser import list_action
25+
from dffml.base import config, field
2626

2727

2828
class TensorflowModelContext(ModelContext):
@@ -137,16 +137,26 @@ def model(self):
137137
"""
138138

139139

140-
@dataclass(init=True, eq=True)
140+
@config
141141
class DNNClassifierModelConfig:
142-
directory: str
143-
steps: int
144-
epochs: int
145-
hidden: List[int]
146-
classification: str
147-
classifications: List[str]
148-
clstype: Type
149-
features: Features
142+
classification: str = field("Feature name holding classification value")
143+
classifications: List[str] = field("Options for value of classification")
144+
features: Features = field("Features to train on")
145+
clstype: Type = field("Data type of classifications values", default=str)
146+
steps: int = field("Number of steps to train the model", default=3000)
147+
epochs: int = field(
148+
"Number of iterations to pass over all repos in a source", default=30
149+
)
150+
directory: str = field(
151+
"Directory where state should be saved",
152+
default=os.path.join(
153+
os.path.expanduser("~"), ".cache", "dffml", "tensorflow"
154+
),
155+
)
156+
hidden: List[int] = field(
157+
"List length is the number of hidden layers in the network. Each entry in the list is the number of nodes in that hidden layer",
158+
default_factory=lambda: [12, 40, 15],
159+
)
150160

151161
def __post_init__(self):
152162
self.classifications = list(map(self.clstype, self.classifications))
@@ -421,96 +431,4 @@ class DNNClassifierModel(Model):
421431
"""
422432

423433
CONTEXT = DNNClassifierModelContext
424-
425-
@classmethod
426-
def args(cls, args, *above) -> Dict[str, Arg]:
427-
cls.config_set(
428-
args,
429-
above,
430-
"directory",
431-
Arg(
432-
default=os.path.join(
433-
os.path.expanduser("~"), ".cache", "dffml", "tensorflow"
434-
),
435-
help="Directory where state should be saved",
436-
),
437-
)
438-
cls.config_set(
439-
args,
440-
above,
441-
"steps",
442-
Arg(
443-
type=int,
444-
default=3000,
445-
help="Number of steps to train the model",
446-
),
447-
)
448-
cls.config_set(
449-
args,
450-
above,
451-
"epochs",
452-
Arg(
453-
type=int,
454-
default=30,
455-
help="Number of iterations to pass over all repos in a source",
456-
),
457-
)
458-
cls.config_set(
459-
args,
460-
above,
461-
"hidden",
462-
Arg(
463-
type=int,
464-
nargs="+",
465-
default=[12, 40, 15],
466-
help="List length is the number of hidden layers in the network. Each entry in the list is the number of nodes in that hidden layer",
467-
),
468-
)
469-
cls.config_set(
470-
args,
471-
above,
472-
"classification",
473-
Arg(help="Feature name holding classification value"),
474-
)
475-
cls.config_set(
476-
args,
477-
above,
478-
"classifications",
479-
Arg(nargs="+", help="Options for value of classification"),
480-
)
481-
cls.config_set(
482-
args,
483-
above,
484-
"clstype",
485-
Arg(
486-
type=pydoc.locate,
487-
default=str,
488-
help="Data type of classifications values (default: str)",
489-
),
490-
)
491-
cls.config_set(
492-
args,
493-
above,
494-
"features",
495-
Arg(
496-
nargs="+",
497-
required=True,
498-
type=Feature.load,
499-
action=list_action(Features),
500-
help="Features to train on",
501-
),
502-
)
503-
return args
504-
505-
@classmethod
506-
def config(cls, config, *above) -> BaseConfig:
507-
return DNNClassifierModelConfig(
508-
directory=cls.config_get(config, above, "directory"),
509-
steps=cls.config_get(config, above, "steps"),
510-
epochs=cls.config_get(config, above, "epochs"),
511-
hidden=cls.config_get(config, above, "hidden"),
512-
classification=cls.config_get(config, above, "classification"),
513-
classifications=cls.config_get(config, above, "classifications"),
514-
clstype=cls.config_get(config, above, "clstype"),
515-
features=cls.config_get(config, above, "features"),
516-
)
434+
CONFIG = DNNClassifierModelConfig

model/tensorflow/dffml_model_tensorflow/dnnr.py

Lines changed: 19 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
repo.
44
"""
55
import os
6-
from dataclasses import dataclass
76
from typing import List, Dict, Any, AsyncIterator
87

98
import numpy as np
@@ -14,22 +13,32 @@
1413
from dffml.model.model import Model
1514
from dffml.accuracy import Accuracy
1615
from dffml.util.entrypoint import entry_point
17-
from dffml.base import BaseConfig
16+
from dffml.base import BaseConfig, config, field
1817
from dffml.util.cli.arg import Arg
1918
from dffml.feature.feature import Feature, Features
2019
from dffml.util.cli.parser import list_action
2120

2221
from dffml_model_tensorflow.dnnc import TensorflowModelContext
2322

2423

25-
@dataclass(init=True, eq=True)
24+
@config
2625
class DNNRegressionModelConfig:
27-
directory: str
28-
steps: int
29-
epochs: int
30-
hidden: List[int]
31-
predict: str # feature_name holding target values
32-
features: Features
26+
predict: str = field("Feature name holding target values")
27+
features: Features = field("Features to train on")
28+
steps: int = field("Number of steps to train the model", default=3000)
29+
epochs: int = field(
30+
"Number of iterations to pass over all repos in a source", default=30
31+
)
32+
directory: str = field(
33+
"Directory where state should be saved",
34+
default=os.path.join(
35+
os.path.expanduser("~"), ".cache", "dffml", "tensorflow"
36+
),
37+
)
38+
hidden: List[int] = field(
39+
"List length is the number of hidden layers in the network. Each entry in the list is the number of nodes in that hidden layer",
40+
default_factory=lambda: [12, 40, 15],
41+
)
3342

3443

3544
class DNNRegressionModelContext(TensorflowModelContext):
@@ -264,78 +273,4 @@ class DNNRegressionModel(Model):
264273
"""
265274

266275
CONTEXT = DNNRegressionModelContext
267-
268-
@classmethod
269-
def args(cls, args, *above) -> Dict[str, Arg]:
270-
cls.config_set(
271-
args,
272-
above,
273-
"directory",
274-
Arg(
275-
default=os.path.join(
276-
os.path.expanduser("~"), ".cache", "dffml", "tensorflow"
277-
),
278-
help="Directory where state should be saved",
279-
),
280-
)
281-
cls.config_set(
282-
args,
283-
above,
284-
"steps",
285-
Arg(
286-
type=int,
287-
default=3000,
288-
help="Number of steps to train the model",
289-
),
290-
)
291-
cls.config_set(
292-
args,
293-
above,
294-
"epochs",
295-
Arg(
296-
type=int,
297-
default=30,
298-
help="Number of iterations to pass over all repos in a source",
299-
),
300-
)
301-
cls.config_set(
302-
args,
303-
above,
304-
"hidden",
305-
Arg(
306-
type=int,
307-
nargs="+",
308-
default=[12, 40, 15],
309-
help="List length is the number of hidden layers in the network. Each entry in the list is the number of nodes in that hidden layer",
310-
),
311-
)
312-
cls.config_set(
313-
args,
314-
above,
315-
"predict",
316-
Arg(help="Feature name holding truth value"),
317-
)
318-
cls.config_set(
319-
args,
320-
above,
321-
"features",
322-
Arg(
323-
nargs="+",
324-
required=True,
325-
type=Feature.load,
326-
action=list_action(Features),
327-
help="Features to train on",
328-
),
329-
)
330-
return args
331-
332-
@classmethod
333-
def config(cls, config, *above) -> BaseConfig:
334-
return DNNRegressionModelConfig(
335-
directory=cls.config_get(config, above, "directory"),
336-
steps=cls.config_get(config, above, "steps"),
337-
epochs=cls.config_get(config, above, "epochs"),
338-
hidden=cls.config_get(config, above, "hidden"),
339-
predict=cls.config_get(config, above, "predict"),
340-
features=cls.config_get(config, above, "features"),
341-
)
276+
CONFIG = DNNRegressionModelConfig

0 commit comments

Comments
 (0)