You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
""" Function to estimate fraction of the samples where the model is abstaining in class i.
315
319
316
320
Parameters
@@ -357,7 +361,7 @@ class AbstentionAdapt_Callback(Callback):
357
361
The factor alpha is modified if the current abstention accuracy is less than the minimum accuracy set or if the current abstention fraction is greater than the maximum fraction set. Thresholds for minimum and maximum correction factors are computed and the correction over alpha is not allowed to be less or greater than them, respectively, to avoid huge swings in the abstention loss evolution.
""" This function modifies the last dense layer in the passed keras model. The modification includes adding units and optionally changing the activation function.
"""This function computes the r2 for the heteroscedastic model. The r2 is computed over the prediction of the mean and the standard deviation prediction is not taken into account.
572
576
573
577
Parameters
@@ -598,7 +602,7 @@ def metric(y_true, y_pred):
598
602
returnmetric
599
603
600
604
601
-
defmae_heteroscedastic_metric(nout):
605
+
defmae_heteroscedastic_metric(nout: int):
602
606
"""This function computes the mean absolute error (mae) for the heteroscedastic model. The mae is computed over the prediction of the mean and the standard deviation prediction is not taken into account.
603
607
604
608
Parameters
@@ -626,7 +630,7 @@ def metric(y_true, y_pred):
626
630
returnmetric
627
631
628
632
629
-
defmse_heteroscedastic_metric(nout):
633
+
defmse_heteroscedastic_metric(nout: int):
630
634
"""This function computes the mean squared error (mse) for the heteroscedastic model. The mse is computed over the prediction of the mean and the standard deviation prediction is not taken into account.
631
635
632
636
Parameters
@@ -654,7 +658,7 @@ def metric(y_true, y_pred):
654
658
returnmetric
655
659
656
660
657
-
defmeanS_heteroscedastic_metric(nout):
661
+
defmeanS_heteroscedastic_metric(nout: int):
658
662
"""This function computes the mean log of the variance (log S) for the heteroscedastic model. The mean log is computed over the standard deviation prediction and the mean prediction is not taken into account.
659
663
660
664
Parameters
@@ -682,7 +686,7 @@ def metric(y_true, y_pred):
682
686
returnmetric
683
687
684
688
685
-
defheteroscedastic_loss(nout):
689
+
defheteroscedastic_loss(nout: int):
686
690
"""This function computes the heteroscedastic loss for the heteroscedastic model. Both mean and standard deviation predictions are taken into account.
"""This function computes the quantile metric for a given quantile and corresponding output index. This is provided as a metric to track evolution while training.
777
781
778
782
Parameters
@@ -795,10 +799,10 @@ def metric(y_true, y_pred):
795
799
"""
796
800
y_shape=K.shape(y_true)
797
801
ifnout>1:
798
-
y_out=K.reshape(y_pred[:, index::nout], y_shape)
802
+
y_qtl=K.reshape(y_pred[:, index::3], y_shape)
799
803
else:
800
-
y_out=K.reshape(y_pred[:, index], y_shape)
801
-
returnquantile_loss(quantile, y_true, y_out)
804
+
y_qtl=K.reshape(y_pred[:, index], y_shape)
805
+
returnquantile_loss(quantile, y_true, y_qtl)
802
806
803
807
metric.__name__='quantile_{}'.format(quantile)
804
808
returnmetric
@@ -809,7 +813,7 @@ def metric(y_true, y_pred):
809
813
# For the Contamination Model
810
814
811
815
812
-
defadd_index_to_output(y_train):
816
+
defadd_index_to_output(y_train: Array) ->Array:
813
817
""" This function adds a column to the training output to store the indices of the corresponding samples in the training set.
defcontamination_loss(nout, T_k, a, sigmaSQ, gammaSQ):
835
+
defcontamination_loss(nout: int, T_k, a, sigmaSQ, gammaSQ):
828
836
""" Function to compute contamination loss. It is composed by two terms: (i) the loss with respect to the normal distribution that models the distribution of the training data samples, (ii) the loss with respect to the Cauchy distribution that models the distribution of the outlier samples. Note that the evaluation of this contamination loss function does not make sense for any data different to the training set. This is because latent variables are only defined for samples in the training set.
829
837
830
838
Parameters
@@ -884,6 +892,11 @@ def __init__(self, x, y, a_max=0.99):
884
892
Maximum value of a variable to allow
885
893
"""
886
894
super(Contamination_Callback, self).__init__()
895
+
ify.ndim>1:
896
+
ify.shape[1] >1:
897
+
raiseException(
898
+
'ERROR ! Contamination model can be applied to one-output regression, but provided training data has: '
899
+
+str(y.ndim) +'outpus... Exiting')
887
900
888
901
self.x=x# Features of training set
889
902
self.y=y# Output of training set
@@ -904,7 +917,7 @@ def __init__(self, x, y, a_max=0.99):
904
917
self.sigmaSQvalues= [] # array to store sigmaSQ evolution
905
918
self.gammaSQvalues= [] # array to store gammaSQ evolution
906
919
907
-
defon_epoch_end(self, epoch, logs={}):
920
+
defon_epoch_end(self, epoch: int, logs={}):
908
921
""" Updates the parameters of the distributions in the contamination model on epoch end. The parameters updated are: 'a' for the global weight of the membership to the normal distribution, 'sigmaSQ' for the variance of the normal distribution and 'gammaSQ' for the scale of the Cauchy distribution of outliers. The latent variables are updated as well: 'T_k' describing in the first column the probability of membership to normal distribution and in the second column probability of membership to the Cauchy distribution i.e. outlier. Stores evolution of global parameters (a, sigmaSQ and gammaSQ).
"""This function computes the mean squared error (mse) for the contamination model. The mse is computed over the prediction. Therefore, the augmentation for the index variable is ignored.
958
971
959
972
Parameters
@@ -977,7 +990,7 @@ def metric(y_true, y_pred):
977
990
returnmetric
978
991
979
992
980
-
defmae_contamination_metric(nout):
993
+
defmae_contamination_metric(nout: int):
981
994
"""This function computes the mean absolute error (mae) for the contamination model. The mae is computed over the prediction. Therefore, the augmentation for the index variable is ignored.
982
995
983
996
Parameters
@@ -1001,7 +1014,7 @@ def metric(y_true, y_pred):
1001
1014
returnmetric
1002
1015
1003
1016
1004
-
defr2_contamination_metric(nout):
1017
+
defr2_contamination_metric(nout: int):
1005
1018
"""This function computes the r2 for the contamination model. The r2 is computed over the prediction. Therefore, the augmentation for the index variable is ignored.
0 commit comments