Skip to content

Commit 65e8089

Browse files
authored
refactor: refactor scaling module (queens-py#220)
1 parent 5493741 commit 65e8089

File tree

3 files changed

+14
-14
lines changed

3 files changed

+14
-14
lines changed

src/queens/models/surrogates/gaussian_neural_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from queens.models.surrogates._surrogate import Surrogate
2525
from queens.utils.configure_tensorflow import configure_keras, configure_tensorflow
2626
from queens.utils.logger_settings import log_init_args
27-
from queens.utils.random_process_scaler import VALID_SCALER
27+
from queens.utils.scaling import VALID_SCALER
2828
from queens.utils.valid_options import get_option
2929
from queens.visualization.gaussian_neural_network_vis import plot_loss
3030

src/queens/models/surrogates/jitted_gaussian_process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import queens.models.surrogates.utils.kernel_jitted as utils_jitted
2323
from queens.models.surrogates._surrogate import Surrogate
2424
from queens.utils.logger_settings import log_init_args
25-
from queens.utils.random_process_scaler import VALID_SCALER
25+
from queens.utils.scaling import VALID_SCALER
2626
from queens.utils.valid_options import get_option
2727
from queens.visualization.gnuplot_vis import gnuplot_gp_convergence
2828

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,11 @@
2323
class Scaler(metaclass=abc.ABCMeta):
2424
"""Base class for general scaling classes.
2525
26-
The purpose of these classes is the scaling of training data.
27-
28-
Attributes:
29-
mean: Mean-values of the data-matrix (column-wise).
30-
standard_deviation: Standard deviation of the data-matrix (per column).
26+
The purpose of these classes is the scaling of data.
3127
"""
3228

3329
def __init__(self) -> None:
34-
"""Initialise scaler.
35-
36-
Returns:
37-
Instance of the Scaler Class (obj)
38-
"""
39-
self.mean: np.ndarray | None = None
40-
self.standard_deviation: np.ndarray | None = None
30+
"""Initialize scaler."""
4131

4232
@abc.abstractmethod
4333
def fit(self, x_mat: np.ndarray) -> None:
@@ -78,8 +68,18 @@ class StandardScaler(Scaler):
7868
In case a stochastic process is trained on the scaled data, inverse
7969
rescaling is implemented to recover the correct mean and standard
8070
deviation prediction for the posterior process.
71+
72+
Attributes:
73+
mean: Mean-values of the data-matrix (column-wise).
74+
standard_deviation: Standard deviation of the data-matrix (per column).
8175
"""
8276

77+
def __init__(self) -> None:
78+
"""Initialize standard scaler."""
79+
super().__init__()
80+
self.mean: np.ndarray | None = None
81+
self.standard_deviation: np.ndarray | None = None
82+
8383
def fit(self, x_mat: np.ndarray) -> None:
8484
"""Fit/calculate the scaling based on the input samples.
8585

0 commit comments

Comments
 (0)