
Commit 5a3cc6e

Fix step, docs in qtl and contamination
1 parent 2e5436f commit 5a3cc6e

1 file changed: common/uq_keras_utils.py (+51 −38)
@@ -13,16 +13,20 @@
 
 import numpy as np
 
+from typing import List, Optional, Tuple, Type, Union
+
 from scipy.stats import norm, cauchy
 
+Array = Type[np.ndarray]
+
 piSQ = np.pi**2
 
 ###################################################################
 
 # For Abstention Model
 
 
-def abstention_loss(alpha, mask):
+def abstention_loss(alpha, mask: Array):
     """ Function to compute abstention loss.
         It is composed by two terms:
         (i) original loss of the multiclass classification problem,
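
Note: the new Array = Type[np.ndarray] alias is reused by the annotations throughout this diff. A minimal usage sketch for the typed abstention_loss; the model, class count, and optimizer here are assumptions for illustration, not part of the commit:

    import numpy as np
    from tensorflow.keras import backend as K

    nb_classes = 10
    # The mask selects the extra abstention unit appended by add_model_output:
    # zeros over the real classes, one in the last position.
    mask = np.zeros(nb_classes + 1)
    mask[-1] = 1.0

    alpha = K.variable(0.5)  # weight of the abstention term, adapted while training
    # model.compile(loss=abstention_loss(alpha, mask), optimizer='sgd')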
@@ -62,7 +66,7 @@ def loss(y_true, y_pred):
     return loss
 
 
-def sparse_abstention_loss(alpha, mask):
+def sparse_abstention_loss(alpha, mask: Array):
     """ Function to compute abstention loss.
         It is composed by two terms:
         (i) original loss of the multiclass classification problem,
@@ -100,7 +104,7 @@ def loss(y_true, y_pred):
     return loss
 
 
-def abstention_acc_metric(nb_classes):
+def abstention_acc_metric(nb_classes: Union[int, Array]):
     """ Abstained accuracy:
         Function to estimate accuracy over the predicted samples
        after removing the samples where the model is abstaining.
@@ -138,7 +142,7 @@ def metric(y_true, y_pred):
     return metric
 
 
-def sparse_abstention_acc_metric(nb_classes):
+def sparse_abstention_acc_metric(nb_classes: Union[int, Array]):
    """ Abstained accuracy:
        Function to estimate accuracy over the predicted samples
        after removing the samples where the model is abstaining.
@@ -179,7 +183,7 @@ def metric(y_true, y_pred):
     return metric
 
 
-def abstention_metric(nb_classes):
+def abstention_metric(nb_classes: Union[int, Array]):
     """ Function to estimate fraction of the samples where the model is abstaining.
 
     Parameters
@@ -209,7 +213,7 @@ def metric(y_true, y_pred):
     return metric
 
 
-def acc_class_i_metric(class_i):
+def acc_class_i_metric(class_i: int):
     """ Function to estimate accuracy over the ith class prediction.
         This estimation is global (i.e. abstaining samples are not removed)
 
@@ -259,7 +263,7 @@ def metric(y_true, y_pred):
     return metric
 
 
-def abstention_acc_class_i_metric(nb_classes, class_i):
+def abstention_acc_class_i_metric(nb_classes: Union[int, Array], class_i: int):
     """ Function to estimate accuracy over the class i prediction after removing the samples where the model is abstaining.
 
     Parameters
@@ -310,7 +314,7 @@ def metric(y_true, y_pred):
     return metric
 
 
-def abstention_class_i_metric(nb_classes, class_i):
+def abstention_class_i_metric(nb_classes: Union[int, Array], class_i: int):
     """ Function to estimate fraction of the samples where the model is abstaining in class i.
 
     Parameters
@@ -357,7 +361,7 @@ class AbstentionAdapt_Callback(Callback):
         The factor alpha is modified if the current abstention accuracy is less than the minimum accuracy set or if the current abstention fraction is greater than the maximum fraction set. Thresholds for minimum and maximum correction factors are computed and the correction over alpha is not allowed to be less or greater than them, respectively, to avoid huge swings in the abstention loss evolution.
     """
 
-    def __init__(self, acc_monitor, abs_monitor, alpha0, init_abs_epoch=4, alpha_scale_factor=0.8, min_abs_acc=0.9, max_abs_frac=0.4, acc_gain=5.0, abs_gain=1.0):
+    def __init__(self, acc_monitor, abs_monitor, alpha0: float, init_abs_epoch: int = 4, alpha_scale_factor: float = 0.8, min_abs_acc: float = 0.9, max_abs_frac: float = 0.4, acc_gain: float = 5.0, abs_gain: float = 1.0):
         """ Initializer of the AbstentionAdapt_Callback.
         Parameters
         ----------
@@ -391,9 +395,9 @@ def __init__(self, acc_monitor, abs_monitor, alpha0, init_abs_epoch=4, alpha_scale_factor=0.8, min_abs_acc=0.9, max_abs_frac=0.4, acc_gain=5.0, abs_gain=1.0):
         self.max_abs_frac = max_abs_frac  # maximum abstention fraction (value specified as parameter of the run)
         self.acc_gain = acc_gain  # factor for adjusting alpha scale
         self.abs_gain = abs_gain  # factor for adjusting alpha scale
-        self.alphavalues = []  # array to store alpha evolution
+        self.alphavalues: List[float] = []  # array to store alpha evolution
 
-    def on_epoch_end(self, epoch, logs=None):
+    def on_epoch_end(self, epoch: int, logs=None):
         """ Updates the weight of abstention term on epoch end.
         Parameters
         ----------
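
A sketch of wiring the newly annotated callback into training; the monitor strings are assumptions and must match metric names actually registered on the model:

    abs_cb = AbstentionAdapt_Callback(
        acc_monitor='val_abstention_acc',  # assumed metric name
        abs_monitor='val_abstention',      # assumed metric name
        alpha0=0.5,          # initial weight of the abstention term
        init_abs_epoch=4,    # epochs to wait before adapting alpha
        min_abs_acc=0.9,     # target accuracy over non-abstained samples
        max_abs_frac=0.4)    # tolerated fraction of abstained samples
    # model.fit(x, y, callbacks=[abs_cb])
    # abs_cb.alphavalues then holds the per-epoch evolution of alpha (List[float])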
@@ -435,7 +439,7 @@ def on_epoch_end(self, epoch, logs=None):
         self.alphavalues.append(new_alpha_val)
 
 
-def modify_labels(numclasses_out, ytrain, ytest, yval=None):
+def modify_labels(numclasses_out: int, ytrain: Array, ytest: Array, yval: Optional[Array] = None) -> Tuple[Array, ...]:
     """ This function generates a categorical representation with a class added for indicating abstention.
 
     Parameters
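
The Tuple[Array, ...] return annotation reflects that modify_labels hands back one augmented categorical array per split it receives. A usage sketch with assumed names and shapes:

    # ytrain / ytest hold integer labels in [0, nb_classes); the added class
    # indexed nb_classes is the abstention class, initially unpopulated.
    ytrain_abs, ytest_abs = modify_labels(nb_classes + 1, ytrain, ytest)
    # Passing yval as well yields a third augmented array.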
@@ -489,7 +493,7 @@ def modify_labels(numclasses_out, ytrain, ytest, yval=None):
 ###################################################################
 
 
-def add_model_output(modelIn, mode=None, num_add=None, activation=None):
+def add_model_output(modelIn, mode: Optional[str] = None, num_add: Optional[int] = None, activation: Optional[str] = None):
     """ This function modifies the last dense layer in the passed keras model. The modification includes adding units and optionally changing the activation function.
 
     Parameters
@@ -567,7 +571,7 @@ def add_model_output(modelIn, mode=None, num_add=None, activation=None):
 # UQ regression - utilities
 
 
-def r2_heteroscedastic_metric(nout):
+def r2_heteroscedastic_metric(nout: int):
     """This function computes the r2 for the heteroscedastic model. The r2 is computed over the prediction of the mean and the standard deviation prediction is not taken into account.
 
     Parameters
@@ -598,7 +602,7 @@ def metric(y_true, y_pred):
     return metric
 
 
-def mae_heteroscedastic_metric(nout):
+def mae_heteroscedastic_metric(nout: int):
     """This function computes the mean absolute error (mae) for the heteroscedastic model. The mae is computed over the prediction of the mean and the standard deviation prediction is not taken into account.
 
     Parameters
@@ -626,7 +630,7 @@ def metric(y_true, y_pred):
     return metric
 
 
-def mse_heteroscedastic_metric(nout):
+def mse_heteroscedastic_metric(nout: int):
     """This function computes the mean squared error (mse) for the heteroscedastic model. The mse is computed over the prediction of the mean and the standard deviation prediction is not taken into account.
 
     Parameters
@@ -654,7 +658,7 @@ def metric(y_true, y_pred):
     return metric
 
 
-def meanS_heteroscedastic_metric(nout):
+def meanS_heteroscedastic_metric(nout: int):
     """This function computes the mean log of the variance (log S) for the heteroscedastic model. The mean log is computed over the standard deviation prediction and the mean prediction is not taken into account.
 
     Parameters
@@ -682,7 +686,7 @@ def metric(y_true, y_pred):
     return metric
 
 
-def heteroscedastic_loss(nout):
+def heteroscedastic_loss(nout: int):
     """This function computes the heteroscedastic loss for the heteroscedastic model. Both mean and standard deviation predictions are taken into account.
 
     Parameters
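
For orientation, the loss this docstring describes is typically the Gaussian negative log-likelihood with a per-sample predicted variance; a NumPy sketch of that standard form (the exact reduction and tensor layout inside heteroscedastic_loss may differ):

    import numpy as np

    def heteroscedastic_nll(y, mu, log_sig2):
        # exp(-log sigma^2) * (y - mu)^2 + log sigma^2, averaged over samples
        return np.mean(np.exp(-log_sig2) * (y - mu) ** 2 + log_sig2)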
@@ -717,7 +721,7 @@ def loss(y_true, y_pred):
     return loss
 
 
-def quantile_loss(quantile, y_true, y_pred):
+def quantile_loss(quantile: float, y_true, y_pred):
     """This function computes the quantile loss for a given quantile fraction.
 
     Parameters
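
A quick numeric check of the pinball loss computed by quantile_loss, assuming error = y_true - y_pred as in the standard form and an eager backend so values can be evaluated directly:

    from tensorflow.keras import backend as K

    y_true = K.constant([1.0])
    y_pred = K.constant([0.0])  # under-prediction: error = 1.0
    # quantile 0.9 penalizes this under-prediction at 0.9 * error
    print(K.eval(quantile_loss(0.9, y_true, y_pred)))  # 0.9
    # quantile 0.1 penalizes the same error at only 0.1 * error
    print(K.eval(quantile_loss(0.1, y_true, y_pred)))  # 0.1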
@@ -734,7 +738,7 @@ def quantile_loss(quantile, y_true, y_pred):
     return K.mean(K.maximum(quantile * error, (quantile - 1) * error))
 
 
-def triple_quantile_loss(nout, lowquantile, highquantile):
+def triple_quantile_loss(nout: int, lowquantile: float, highquantile: float):
     """This function computes the quantile loss for the median and low and high quantiles. The median is given twice the weight of the other components.
 
     Parameters
@@ -759,20 +763,20 @@ def loss(y_true, y_pred):
 
         y_shape = K.shape(y_true)
         if nout > 1:
-            y_out0 = K.reshape(y_pred[:, 0::nout], y_shape)
-            y_out1 = K.reshape(y_pred[:, 1::nout], y_shape)
-            y_out2 = K.reshape(y_pred[:, 2::nout], y_shape)
+            y_qtl0 = K.reshape(y_pred[:, 0::3], y_shape)
+            y_qtl1 = K.reshape(y_pred[:, 1::3], y_shape)
+            y_qtl2 = K.reshape(y_pred[:, 2::3], y_shape)
         else:
-            y_out0 = K.reshape(y_pred[:, 0], y_shape)
-            y_out1 = K.reshape(y_pred[:, 1], y_shape)
-            y_out2 = K.reshape(y_pred[:, 2], y_shape)
+            y_qtl0 = K.reshape(y_pred[:, 0], y_shape)
+            y_qtl1 = K.reshape(y_pred[:, 1], y_shape)
+            y_qtl2 = K.reshape(y_pred[:, 2], y_shape)
 
-        return quantile_loss(lowquantile, y_true, y_out1) + quantile_loss(highquantile, y_true, y_out2) + 2. * quantile_loss(0.5, y_true, y_out0)
+        return quantile_loss(lowquantile, y_true, y_qtl1) + quantile_loss(highquantile, y_true, y_qtl2) + 2. * quantile_loss(0.5, y_true, y_qtl0)
 
     return loss
 
 
-def quantile_metric(nout, index, quantile):
+def quantile_metric(nout: int, index: int, quantile: float):
     """This function computes the quantile metric for a given quantile and corresponding output index. This is provided as a metric to track evolution while training.
 
     Parameters
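
This hunk is the "fix step" of the commit message: with three quantile heads per output, the stride separating the same quantile across outputs is 3 (the number of quantiles), not nout, so the old indexing was only correct when nout happened to equal 3. A NumPy illustration, assuming the heads are laid out as consecutive (median, low, high) triplets per output:

    import numpy as np

    nout = 2
    # columns: [med_a, low_a, high_a, med_b, low_b, high_b]
    y_pred = np.array([[0.5, 0.1, 0.9, 1.5, 1.1, 1.9]])
    print(y_pred[:, 0::3])     # [[0.5 1.5]]      medians of both outputs (correct)
    print(y_pred[:, 0::nout])  # [[0.5 0.9 1.5]]  old stride: wrong columns and shape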
@@ -795,10 +799,10 @@ def metric(y_true, y_pred):
         """
         y_shape = K.shape(y_true)
         if nout > 1:
-            y_out = K.reshape(y_pred[:, index::nout], y_shape)
+            y_qtl = K.reshape(y_pred[:, index::3], y_shape)
         else:
-            y_out = K.reshape(y_pred[:, index], y_shape)
-        return quantile_loss(quantile, y_true, y_out)
+            y_qtl = K.reshape(y_pred[:, index], y_shape)
+        return quantile_loss(quantile, y_true, y_qtl)
 
     metric.__name__ = 'quantile_{}'.format(quantile)
     return metric
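
A sketch of registering the quantile metrics alongside the triple quantile loss; the index-to-quantile pairing follows the (median, low, high) ordering used in triple_quantile_loss, and the compile call is illustrative:

    # model.compile(
    #     loss=triple_quantile_loss(nout, 0.1, 0.9),
    #     metrics=[quantile_metric(nout, 0, 0.5),   # median head
    #              quantile_metric(nout, 1, 0.1),   # low-quantile head
    #              quantile_metric(nout, 2, 0.9)])  # high-quantile head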
@@ -809,7 +813,7 @@ def metric(y_true, y_pred):
 # For the Contamination Model
 
 
-def add_index_to_output(y_train):
+def add_index_to_output(y_train: Array) -> Array:
     """ This function adds a column to the training output to store the indices of the corresponding samples in the training set.
 
     Parameters
@@ -819,12 +823,16 @@ def add_index_to_output(y_train):
     """
     # Add indices to y
     y_train_index = range(y_train.shape[0])
-    y_train_augmented = np.vstack([y_train, y_train_index]).T
+    if y_train.ndim > 1:
+        shp = (y_train.shape[0], 1)
+        y_train_augmented = np.hstack([y_train, np.reshape(y_train_index, shp)])
+    else:
+        y_train_augmented = np.vstack([y_train, y_train_index]).T
 
     return y_train_augmented
 
 
-def contamination_loss(nout, T_k, a, sigmaSQ, gammaSQ):
+def contamination_loss(nout: int, T_k, a, sigmaSQ, gammaSQ):
     """ Function to compute contamination loss. It is composed by two terms: (i) the loss with respect to the normal distribution that models the distribution of the training data samples, (ii) the loss with respect to the Cauchy distribution that models the distribution of the outlier samples. Note that the evaluation of this contamination loss function does not make sense for any data different to the training set. This is because latent variables are only defined for samples in the training set.
 
     Parameters
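
A behavior sketch of the fixed add_index_to_output for 1D and 2D targets (import path assumed):

    import numpy as np
    # from uq_keras_utils import add_index_to_output  # assuming common/ is on sys.path

    y1 = np.array([3.0, 1.0, 4.0])        # 1D target, shape (3,)
    # vstack branch -> [[3. 0.], [1. 1.], [4. 2.]], shape (3, 2)

    y2 = np.array([[3.0], [1.0], [4.0]])  # 2D target, shape (3, 1)
    # hstack branch appends the index as a (3, 1) column -> shape (3, 2);
    # the previous unconditional vstack raised a shape mismatch for 2D input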
@@ -884,6 +892,11 @@ def __init__(self, x, y, a_max=0.99):
             Maximum value of a variable to allow
         """
         super(Contamination_Callback, self).__init__()
+        if y.ndim > 1:
+            if y.shape[1] > 1:
+                raise Exception(
+                    'ERROR ! Contamination model can be applied to one-output regression, but provided training data has: '
+                    + str(y.ndim) + 'outpus... Exiting')
 
         self.x = x  # Features of training set
         self.y = y  # Output of training set
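
The new guard makes the single-output restriction explicit at construction time; a sketch of it firing, with assumed shapes:

    import numpy as np

    x = np.random.rand(100, 5)
    y_bad = np.zeros((100, 2))   # two regression outputs
    # Contamination_Callback(x, y_bad) now raises the Exception above, since
    # the latent variables T_k are defined per single-output training samples.
    # A (100,) or (100, 1) target passes the check.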
@@ -904,7 +917,7 @@ def __init__(self, x, y, a_max=0.99):
         self.sigmaSQvalues = []  # array to store sigmaSQ evolution
         self.gammaSQvalues = []  # array to store gammaSQ evolution
 
-    def on_epoch_end(self, epoch, logs={}):
+    def on_epoch_end(self, epoch: int, logs={}):
         """ Updates the parameters of the distributions in the contamination model on epoch end. The parameters updated are: 'a' for the global weight of the membership to the normal distribution, 'sigmaSQ' for the variance of the normal distribution and 'gammaSQ' for the scale of the Cauchy distribution of outliers. The latent variables are updated as well: 'T_k' describing in the first column the probability of membership to normal distribution and in the second column probability of membership to the Cauchy distribution i.e. outlier. Stores evolution of global parameters (a, sigmaSQ and gammaSQ).
 
         Parameters
@@ -953,7 +966,7 @@ def on_epoch_end(self, epoch, logs={}):
         self.gammaSQvalues.append(gammaSQ_eval)
 
 
-def mse_contamination_metric(nout):
+def mse_contamination_metric(nout: int):
     """This function computes the mean squared error (mse) for the contamination model. The mse is computed over the prediction. Therefore, the augmentation for the index variable is ignored.
 
     Parameters
@@ -977,7 +990,7 @@ def metric(y_true, y_pred):
     return metric
 
 
-def mae_contamination_metric(nout):
+def mae_contamination_metric(nout: int):
     """This function computes the mean absolute error (mae) for the contamination model. The mae is computed over the prediction. Therefore, the augmentation for the index variable is ignored.
 
     Parameters
@@ -1001,7 +1014,7 @@ def metric(y_true, y_pred):
     return metric
 
 
-def r2_contamination_metric(nout):
+def r2_contamination_metric(nout: int):
     """This function computes the r2 for the contamination model. The r2 is computed over the prediction. Therefore, the augmentation for the index variable is ignored.
 
     Parameters
