@@ -1,5 +1,7 @@
 """All things related to training autoencoders."""
+from dataclasses import dataclass
 from typing import Callable
+from typing import Optional
 from typing import Tuple
 from typing import Union

@@ -48,28 +50,53 @@ def anomaly_diff(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
     return optimized_func


-def build_encode_dim_loss_function(
-    encode_dim: int,
-    regularization_factor: float = 0.001,
-    axis: Tuple[int, ...] = (1, 2, 3),
-) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
-    """Closure that sets up the custom encode dim loss function."""
-    # calculate the encoding dim loss
-    encode_dim_loss = encode_dim * regularization_factor
+@dataclass
+class build_encode_dim_loss_function:  # noqa
+    """Decorator factory class to build a penalized loss function."""

-    # create function
-    def penalize_encode_dimension(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
-        """Penalizes loss with additional encoding dimension value."""
-        # calculate the dynamic mean reconstruction error on training data
-        reconstruction_loss = tf.reduce_mean(tf.square(y_true - y_pred), axis=axis)
+    encode_dim: int
+    regularization_factor: float = 0.001
+    axis: Tuple[int, ...] = (1, 2, 3)

-        # calculate penalized loss
-        return reconstruction_loss + encode_dim_loss
+    def __call__(
+        self,
+        penalized_loss: Optional[Callable[[tf.Tensor, int, float], tf.Tensor]] = None,
+    ) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
+        """Call decorator to build custom penalized loss function."""
+        return self.decorate(penalized_loss)

-    # optimize with tf.function
-    optimized_func: Callable[[tf.Tensor, tf.Tensor], tf.Tensor] = tf.function(
-        penalize_encode_dimension
-    )
+    @staticmethod
+    def default_penalty(loss: tf.Tensor, encode: int, reg: float) -> tf.Tensor:
+        """Calculate the default penalty for the encoding dimension."""
+        return loss + (loss * encode * reg)

-    # get wrapped function
-    return optimized_func
+    def decorate(
+        self,
+        penalized_loss: Optional[Callable[[tf.Tensor, int, float], tf.Tensor]] = None,
+    ) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
+        """Decorator that builds the complete penalized loss function."""
+        # check for none
+        if penalized_loss is None:
+            # get default
+            penalized_loss = self.default_penalty
+
+        # create function
+        def custom_loss(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
+            """Calculate reconstruction error and apply penalty."""
+            # calculate the dynamic mean reconstruction error on training data
+            reconstruction_loss = tf.reduce_mean(
+                tf.square(y_true - y_pred), axis=self.axis
+            )
+
+            # calculate penalized loss
+            return penalized_loss(
+                reconstruction_loss, self.encode_dim, self.regularization_factor
+            )
+
+        # optimize with tf.function
+        optimized_func: Callable[[tf.Tensor, tf.Tensor], tf.Tensor] = tf.function(
+            custom_loss
+        )
+
+        # get wrapped function
+        return optimized_func
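
For reference, a minimal usage sketch of the new class (not part of the commit): it can be instantiated and called with no arguments to get the default penalty, or applied as a decorator to a custom penalty function with the (loss, encode_dim, regularization_factor) signature. The train_autoencoders import path, the toy model, and the additive_penalty helper below are illustrative assumptions.

import tensorflow as tf

# assumed import path for this module; adjust to the package layout
from train_autoencoders import build_encode_dim_loss_function

# a toy convolutional autoencoder over 28x28x1 images, purely illustrative;
# the default axis=(1, 2, 3) expects 4D (batch, height, width, channel) tensors
model = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(8, 3, padding="same", activation="relu"),
        tf.keras.layers.Conv2D(1, 3, padding="same", activation="sigmoid"),
    ]
)

# factory form: the default penalty scales the per-sample reconstruction
# loss by (1 + encode_dim * regularization_factor)
loss_fn = build_encode_dim_loss_function(encode_dim=32)()
model.compile(optimizer="adam", loss=loss_fn)

# decorator form: supply a custom penalty with the
# (loss, encode_dim, regularization_factor) signature instead
@build_encode_dim_loss_function(encode_dim=32, regularization_factor=0.01)
def additive_penalty(loss: tf.Tensor, encode: int, reg: float) -> tf.Tensor:
    """Hypothetical penalty that adds a flat term rather than scaling."""
    return loss + encode * reg

# after decoration, additive_penalty is the finished (y_true, y_pred) loss
model.compile(optimizer="adam", loss=additive_penalty)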