|
7 | 7 | import pydoc |
8 | 8 | import hashlib |
9 | 9 | import inspect |
10 | | -from dataclasses import dataclass |
11 | 10 | from typing import List, Dict, Any, AsyncIterator, Tuple, Optional, Type |
12 | 11 |
|
13 | 12 | import numpy as np |
|
23 | 22 | from dffml.util.cli.arg import Arg |
24 | 23 | from dffml.feature.feature import Feature, Features |
25 | 24 | from dffml.util.cli.parser import list_action |
| 25 | +from dffml.base import config, field |
26 | 26 |
|
27 | 27 |
|
28 | 28 | class TensorflowModelContext(ModelContext): |
@@ -137,16 +137,26 @@ def model(self): |
137 | 137 | """ |
138 | 138 |
|
139 | 139 |
|
140 | | -@dataclass(init=True, eq=True) |
| 140 | +@config |
141 | 141 | class DNNClassifierModelConfig: |
142 | | - directory: str |
143 | | - steps: int |
144 | | - epochs: int |
145 | | - hidden: List[int] |
146 | | - classification: str |
147 | | - classifications: List[str] |
148 | | - clstype: Type |
149 | | - features: Features |
| 142 | + classification: str = field("Feature name holding classification value") |
| 143 | + classifications: List[str] = field("Options for value of classification") |
| 144 | + features: Features = field("Features to train on") |
| 145 | + clstype: Type = field("Data type of classifications values", default=str) |
| 146 | + steps: int = field("Number of steps to train the model", default=3000) |
| 147 | + epochs: int = field( |
| 148 | + "Number of iterations to pass over all repos in a source", default=30 |
| 149 | + ) |
| 150 | + directory: str = field( |
| 151 | + "Directory where state should be saved", |
| 152 | + default=os.path.join( |
| 153 | + os.path.expanduser("~"), ".cache", "dffml", "tensorflow" |
| 154 | + ), |
| 155 | + ) |
| 156 | + hidden: List[int] = field( |
| 157 | + "List length is the number of hidden layers in the network. Each entry in the list is the number of nodes in that hidden layer", |
| 158 | + default_factory=lambda: [12, 40, 15], |
| 159 | + ) |
150 | 160 |
|
151 | 161 | def __post_init__(self): |
152 | 162 | self.classifications = list(map(self.clstype, self.classifications)) |
@@ -421,96 +431,4 @@ class DNNClassifierModel(Model): |
421 | 431 | """ |
422 | 432 |
|
423 | 433 | CONTEXT = DNNClassifierModelContext |
424 | | - |
425 | | - @classmethod |
426 | | - def args(cls, args, *above) -> Dict[str, Arg]: |
427 | | - cls.config_set( |
428 | | - args, |
429 | | - above, |
430 | | - "directory", |
431 | | - Arg( |
432 | | - default=os.path.join( |
433 | | - os.path.expanduser("~"), ".cache", "dffml", "tensorflow" |
434 | | - ), |
435 | | - help="Directory where state should be saved", |
436 | | - ), |
437 | | - ) |
438 | | - cls.config_set( |
439 | | - args, |
440 | | - above, |
441 | | - "steps", |
442 | | - Arg( |
443 | | - type=int, |
444 | | - default=3000, |
445 | | - help="Number of steps to train the model", |
446 | | - ), |
447 | | - ) |
448 | | - cls.config_set( |
449 | | - args, |
450 | | - above, |
451 | | - "epochs", |
452 | | - Arg( |
453 | | - type=int, |
454 | | - default=30, |
455 | | - help="Number of iterations to pass over all repos in a source", |
456 | | - ), |
457 | | - ) |
458 | | - cls.config_set( |
459 | | - args, |
460 | | - above, |
461 | | - "hidden", |
462 | | - Arg( |
463 | | - type=int, |
464 | | - nargs="+", |
465 | | - default=[12, 40, 15], |
466 | | - help="List length is the number of hidden layers in the network. Each entry in the list is the number of nodes in that hidden layer", |
467 | | - ), |
468 | | - ) |
469 | | - cls.config_set( |
470 | | - args, |
471 | | - above, |
472 | | - "classification", |
473 | | - Arg(help="Feature name holding classification value"), |
474 | | - ) |
475 | | - cls.config_set( |
476 | | - args, |
477 | | - above, |
478 | | - "classifications", |
479 | | - Arg(nargs="+", help="Options for value of classification"), |
480 | | - ) |
481 | | - cls.config_set( |
482 | | - args, |
483 | | - above, |
484 | | - "clstype", |
485 | | - Arg( |
486 | | - type=pydoc.locate, |
487 | | - default=str, |
488 | | - help="Data type of classifications values (default: str)", |
489 | | - ), |
490 | | - ) |
491 | | - cls.config_set( |
492 | | - args, |
493 | | - above, |
494 | | - "features", |
495 | | - Arg( |
496 | | - nargs="+", |
497 | | - required=True, |
498 | | - type=Feature.load, |
499 | | - action=list_action(Features), |
500 | | - help="Features to train on", |
501 | | - ), |
502 | | - ) |
503 | | - return args |
504 | | - |
505 | | - @classmethod |
506 | | - def config(cls, config, *above) -> BaseConfig: |
507 | | - return DNNClassifierModelConfig( |
508 | | - directory=cls.config_get(config, above, "directory"), |
509 | | - steps=cls.config_get(config, above, "steps"), |
510 | | - epochs=cls.config_get(config, above, "epochs"), |
511 | | - hidden=cls.config_get(config, above, "hidden"), |
512 | | - classification=cls.config_get(config, above, "classification"), |
513 | | - classifications=cls.config_get(config, above, "classifications"), |
514 | | - clstype=cls.config_get(config, above, "clstype"), |
515 | | - features=cls.config_get(config, above, "features"), |
516 | | - ) |
| 434 | + CONFIG = DNNClassifierModelConfig |
0 commit comments