@@ -72,7 +72,7 @@ def _average_best(fold_losses: np.ndarray, proportion: float = 0.05, axis: int =
7272 return _average (best_losses , axis = axis )
7373
7474
75- def _average (fold_losses : np .ndarray , axis : int = 0 ) -> float :
75+ def _average (fold_losses : np .ndarray , axis : int = 0 , ** kwargs ) -> float :
7676 """
7777 Compute the average of the input array along the specified axis.
7878
@@ -90,7 +90,7 @@ def _average(fold_losses: np.ndarray, axis: int = 0) -> float:
9090 return np .average (fold_losses , axis = axis ).item ()
9191
9292
93- def _best_worst (fold_losses : np .ndarray , axis : int = 0 ) -> float :
93+ def _best_worst (fold_losses : np .ndarray , axis : int = 0 , ** kwargs ) -> float :
9494 """
9595 Compute the maximum value of the input array along the specified axis.
9696
@@ -108,7 +108,7 @@ def _best_worst(fold_losses: np.ndarray, axis: int = 0) -> float:
108108 return np .max (fold_losses , axis = axis ).item ()
109109
110110
111- def _std (fold_losses : np .ndarray , axis : int = 0 ) -> float :
111+ def _std (fold_losses : np .ndarray , axis : int = 0 , ** kwargs ) -> float :
112112 """
113113 Compute the standard deviation of the input array along the specified axis.
114114
@@ -265,9 +265,14 @@ def compute_loss(
265265 if self .loss_type == "chi2" :
266266 # calculate statistics of chi2 over replicas for a given k-fold_statistic
267267
268- ### Experiment:
269- # Use the validation loss as the loss
270- # summed with how far from 2 are we for the kfold
268+ # Construct the final loss as a sum of
269+ # 1. The validation chi2
270+ # 2. The distance to 2 for the kfold chi2
271+ # If a proportion allow as a keyword argument, use 80% and 10%
272+ # as a proxy of
273+ # "80% of the replicas should be good, but only a small % has to cover the folds"
274+ # The values of 80% and 10% are completely empirical and should be investigated further
275+
271276 validation_loss_average = self .reduce_over_replicas (validation_loss , proportion = 0.8 )
272277 kfold_loss_average = self .reduce_over_replicas (kfold_loss , proportion = 0.1 )
273278 loss = validation_loss_average + (max (kfold_loss_average , 2.0 ) - 2.0 )
0 commit comments