Skip to content

Commit 1abce70

Browse files
Refactored to decorator factory class
1 parent 52db713 commit 1abce70

File tree

1 file changed

+48
-21
lines changed

1 file changed

+48
-21
lines changed

src/autoencoder/training.py

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""All things related to training autoencoders."""
2+
from dataclasses import dataclass
23
from typing import Callable
4+
from typing import Optional
35
from typing import Tuple
46
from typing import Union
57

@@ -48,28 +50,53 @@ def anomaly_diff(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
4850
return optimized_func
4951

5052

@dataclass
class build_encode_dim_loss_function:  # noqa
    """Decorator factory that assembles a penalized autoencoder loss.

    Instances are callable and usable as decorators: invoking one (optionally
    with a custom penalty function) returns a ``tf.function``-wrapped loss
    with the standard Keras ``(y_true, y_pred) -> tf.Tensor`` signature.
    """

    # size of the encoding dimension being penalized
    encode_dim: int
    # scaling factor applied to the encoding-dimension penalty
    regularization_factor: float = 0.001
    # axes over which the mean reconstruction error is reduced
    axis: Tuple[int, ...] = (1, 2, 3)

    def __call__(
        self,
        penalized_loss: Optional[Callable[[tf.Tensor, int, float], tf.Tensor]] = None,
    ) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
        """Call decorator to build custom penalized loss function."""
        # delegate straight to decorate so both spellings behave identically
        return self.decorate(penalized_loss)

    @staticmethod
    def default_penalty(loss: tf.Tensor, encode: int, reg: float) -> tf.Tensor:
        """Calculate the default penalty for the encoding dimension."""
        # add a penalty proportional to the loss itself: loss * encode * reg
        return loss + (loss * encode * reg)

    def decorate(
        self,
        penalized_loss: Optional[Callable[[tf.Tensor, int, float], tf.Tensor]] = None,
    ) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
        """Decorator that builds the complete penalized loss function."""
        # fall back to the built-in penalty when no custom one is supplied
        penalty = self.default_penalty if penalized_loss is None else penalized_loss

        def loss_with_penalty(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
            """Mean squared reconstruction error plus the encoding penalty."""
            # dynamic mean reconstruction error over the configured axes
            reconstruction = tf.reduce_mean(
                tf.square(y_true - y_pred), axis=self.axis
            )
            # apply the (possibly user-supplied) penalty transformation
            return penalty(reconstruction, self.encode_dim, self.regularization_factor)

        # compile the closure into a tf graph function before handing it back
        wrapped: Callable[[tf.Tensor, tf.Tensor], tf.Tensor] = tf.function(
            loss_with_penalty
        )
        return wrapped

0 commit comments

Comments
 (0)