Skip to content

Commit f30633e

Browse files
committed
improve: replace class configs
1 parent aff0288 commit f30633e

File tree

3 files changed

+11
-27
lines changed

3 files changed

+11
-27
lines changed

lantern/early_stopping.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
from typing import Optional
2+
23
import torch.utils.tensorboard
34

45
from lantern import FunctionalBase
56

67

7-
class EarlyStopping(FunctionalBase):
8+
class EarlyStopping(FunctionalBase, arbitrary_types_allowed=True):
89
"""Keeps track of the best score and how long ago it was calculated."""
910

1011
tensorboard_logger: torch.utils.tensorboard.SummaryWriter
1112
best_score: Optional[float] = None
1213
scores_since_improvement: int = -1
1314

14-
class Config:
15-
arbitrary_types_allowed = True
16-
1715
def score(self, value):
1816
if self.best_score is None or value > self.best_score:
1917
return self.replace(

lantern/metric.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
1-
import numpy as np
21
import functools
2+
from typing import Any, Callable, List, Optional, Union
3+
4+
import numpy as np
5+
36
from lantern import FunctionalBase, star
4-
from typing import Callable, Any, Optional, List, Union
57

68

79
class MapMetric(FunctionalBase):
810
map_fn: Optional[Callable[..., Any]]
911
state: List[Any]
1012

11-
class Config:
12-
arbitrary_types_allowed = True
13-
allow_mutation = True
14-
1513
def __init__(self, state=list(), map_fn=None):
1614
super().__init__(
1715
state=state,
@@ -101,14 +99,10 @@ def __iter__(self):
10199
Metric = MapMetric
102100

103101

104-
class ReduceMetric(FunctionalBase):
102+
class ReduceMetric(FunctionalBase, arbitrary_types_allowed=True):
105103
reduce_fn: Callable[..., Any]
106104
state: Any
107105

108-
class Config:
109-
arbitrary_types_allowed = True
110-
allow_mutation = True
111-
112106
def update_(self, *args, **kwargs):
113107
self.state = self.reduce_fn(self.state, *args, **kwargs)
114108
return self
@@ -141,10 +135,6 @@ class AggregateMetric(FunctionalBase):
141135
metric: Union[MapMetric, ReduceMetric]
142136
aggregate_fn: Callable
143137

144-
class Config:
145-
arbitrary_types_allowed = True
146-
allow_mutation = True
147-
148138
def map(self, fn):
149139
return self.replace(aggregate_fn=lambda state: fn(self.aggregate_fn(state)))
150140

lantern/metric_table.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
11
import textwrap
2+
from typing import Any, Dict, Union
3+
24
import pandas as pd
3-
from lantern import FunctionalBase
4-
from typing import Dict, Union, Any
55

6-
# from wire_damage.tools import MapMetric, ReduceMetric, AggregateMetric
6+
from lantern import FunctionalBase
77

88

9-
class MetricTable(FunctionalBase):
9+
class MetricTable(FunctionalBase, arbitrary_types_allowed=True):
1010
name: str
1111
metrics: Dict[str, Any]
12-
# metrics: Dict[str, Union[MapMetric, ReduceMetric, AggregateMetric]]
13-
14-
class Config:
15-
arbitrary_types_allowed = True
1612

1713
def __init__(self, name, metrics):
1814
super().__init__(

0 commit comments

Comments
 (0)