You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
""" Function to estimate fraction of the samples where the model is abstaining in class i.
319
315
320
316
Parameters
@@ -361,7 +357,7 @@ class AbstentionAdapt_Callback(Callback):
361
357
The factor alpha is modified if the current abstention accuracy is less than the minimum accuracy set or if the current abstention fraction is greater than the maximum fraction set. Thresholds for minimum and maximum correction factors are computed and the correction over alpha is not allowed to be less or greater than them, respectively, to avoid huge swings in the abstention loss evolution.
""" This function modifies the last dense layer in the passed keras model. The modification includes adding units and optionally changing the activation function.
"""This function computes the r2 for the heteroscedastic model. The r2 is computed over the prediction of the mean and the standard deviation prediction is not taken into account.
576
572
577
573
Parameters
@@ -602,7 +598,7 @@ def metric(y_true, y_pred):
602
598
returnmetric
603
599
604
600
605
-
defmae_heteroscedastic_metric(nout: int):
601
+
defmae_heteroscedastic_metric(nout):
606
602
"""This function computes the mean absolute error (mae) for the heteroscedastic model. The mae is computed over the prediction of the mean and the standard deviation prediction is not taken into account.
607
603
608
604
Parameters
@@ -630,7 +626,7 @@ def metric(y_true, y_pred):
630
626
returnmetric
631
627
632
628
633
-
defmse_heteroscedastic_metric(nout: int):
629
+
defmse_heteroscedastic_metric(nout):
634
630
"""This function computes the mean squared error (mse) for the heteroscedastic model. The mse is computed over the prediction of the mean and the standard deviation prediction is not taken into account.
635
631
636
632
Parameters
@@ -658,7 +654,7 @@ def metric(y_true, y_pred):
658
654
returnmetric
659
655
660
656
661
-
defmeanS_heteroscedastic_metric(nout: int):
657
+
defmeanS_heteroscedastic_metric(nout):
662
658
"""This function computes the mean log of the variance (log S) for the heteroscedastic model. The mean log is computed over the standard deviation prediction and the mean prediction is not taken into account.
663
659
664
660
Parameters
@@ -686,7 +682,7 @@ def metric(y_true, y_pred):
686
682
returnmetric
687
683
688
684
689
-
defheteroscedastic_loss(nout: int):
685
+
defheteroscedastic_loss(nout):
690
686
"""This function computes the heteroscedastic loss for the heteroscedastic model. Both mean and standard deviation predictions are taken into account.
"""This function computes the quantile metric for a given quantile and corresponding output index. This is provided as a metric to track evolution while training.
781
777
782
778
Parameters
@@ -813,7 +809,7 @@ def metric(y_true, y_pred):
813
809
# For the Contamination Model
814
810
815
811
816
-
defadd_index_to_output(y_train: Array) ->Array:
812
+
defadd_index_to_output(y_train):
817
813
""" This function adds a column to the training output to store the indices of the corresponding samples in the training set.
defcontamination_loss(nout: int, T_k, a, sigmaSQ, gammaSQ):
831
+
defcontamination_loss(nout, T_k, a, sigmaSQ, gammaSQ):
836
832
""" Function to compute contamination loss. It is composed by two terms: (i) the loss with respect to the normal distribution that models the distribution of the training data samples, (ii) the loss with respect to the Cauchy distribution that models the distribution of the outlier samples. Note that the evaluation of this contamination loss function does not make sense for any data different to the training set. This is because latent variables are only defined for samples in the training set.
837
833
838
834
Parameters
@@ -917,7 +913,7 @@ def __init__(self, x, y, a_max=0.99):
917
913
self.sigmaSQvalues= [] # array to store sigmaSQ evolution
918
914
self.gammaSQvalues= [] # array to store gammaSQ evolution
919
915
920
-
defon_epoch_end(self, epoch: int, logs={}):
916
+
defon_epoch_end(self, epoch, logs={}):
921
917
""" Updates the parameters of the distributions in the contamination model on epoch end. The parameters updated are: 'a' for the global weight of the membership to the normal distribution, 'sigmaSQ' for the variance of the normal distribution and 'gammaSQ' for the scale of the Cauchy distribution of outliers. The latent variables are updated as well: 'T_k' describing in the first column the probability of membership to normal distribution and in the second column probability of membership to the Cauchy distribution i.e. outlier. Stores evolution of global parameters (a, sigmaSQ and gammaSQ).
"""This function computes the mean squared error (mse) for the contamination model. The mse is computed over the prediction. Therefore, the augmentation for the index variable is ignored.
971
967
972
968
Parameters
@@ -990,7 +986,7 @@ def metric(y_true, y_pred):
990
986
returnmetric
991
987
992
988
993
-
defmae_contamination_metric(nout: int):
989
+
defmae_contamination_metric(nout):
994
990
"""This function computes the mean absolute error (mae) for the contamination model. The mae is computed over the prediction. Therefore, the augmentation for the index variable is ignored.
995
991
996
992
Parameters
@@ -1014,7 +1010,7 @@ def metric(y_true, y_pred):
1014
1010
returnmetric
1015
1011
1016
1012
1017
-
defr2_contamination_metric(nout: int):
1013
+
defr2_contamination_metric(nout):
1018
1014
"""This function computes the r2 for the contamination model. The r2 is computed over the prediction. Therefore, the augmentation for the index variable is ignored.
0 commit comments