Skip to content

Commit 35571b1

Browse files
committed
Remove type annotations
1 parent 5a3cc6e commit 35571b1

File tree

1 file changed

+27
-31
lines changed

1 file changed

+27
-31
lines changed

common/uq_keras_utils.py

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,16 @@
1313

1414
import numpy as np
1515

16-
from typing import List, Optional, Tuple, Type, Union
17-
1816
from scipy.stats import norm, cauchy
1917

20-
Array = Type[np.ndarray]
21-
2218
piSQ = np.pi**2
2319

2420
###################################################################
2521

2622
# For Abstention Model
2723

2824

29-
def abstention_loss(alpha, mask: Array):
25+
def abstention_loss(alpha, mask):
3026
""" Function to compute abstention loss.
3127
It is composed by two terms:
3228
(i) original loss of the multiclass classification problem,
@@ -66,7 +62,7 @@ def loss(y_true, y_pred):
6662
return loss
6763

6864

69-
def sparse_abstention_loss(alpha, mask: Array):
65+
def sparse_abstention_loss(alpha, mask):
7066
""" Function to compute abstention loss.
7167
It is composed by two terms:
7268
(i) original loss of the multiclass classification problem,
@@ -104,7 +100,7 @@ def loss(y_true, y_pred):
104100
return loss
105101

106102

107-
def abstention_acc_metric(nb_classes: Union[int, Array]):
103+
def abstention_acc_metric(nb_classes):
108104
""" Abstained accuracy:
109105
Function to estimate accuracy over the predicted samples
110106
after removing the samples where the model is abstaining.
@@ -142,7 +138,7 @@ def metric(y_true, y_pred):
142138
return metric
143139

144140

145-
def sparse_abstention_acc_metric(nb_classes: Union[int, Array]):
141+
def sparse_abstention_acc_metric(nb_classes):
146142
""" Abstained accuracy:
147143
Function to estimate accuracy over the predicted samples
148144
after removing the samples where the model is abstaining.
@@ -183,7 +179,7 @@ def metric(y_true, y_pred):
183179
return metric
184180

185181

186-
def abstention_metric(nb_classes: Union[int, Array]):
182+
def abstention_metric(nb_classes):
187183
""" Function to estimate fraction of the samples where the model is abstaining.
188184
189185
Parameters
@@ -213,7 +209,7 @@ def metric(y_true, y_pred):
213209
return metric
214210

215211

216-
def acc_class_i_metric(class_i: int):
212+
def acc_class_i_metric(class_i):
217213
""" Function to estimate accuracy over the ith class prediction.
218214
This estimation is global (i.e. abstaining samples are not removed)
219215
@@ -263,7 +259,7 @@ def metric(y_true, y_pred):
263259
return metric
264260

265261

266-
def abstention_acc_class_i_metric(nb_classes: Union[int, Array], class_i: int):
262+
def abstention_acc_class_i_metric(nb_classes, class_i):
267263
""" Function to estimate accuracy over the class i prediction after removing the samples where the model is abstaining.
268264
269265
Parameters
@@ -314,7 +310,7 @@ def metric(y_true, y_pred):
314310
return metric
315311

316312

317-
def abstention_class_i_metric(nb_classes: Union[int, Array], class_i: int):
313+
def abstention_class_i_metric(nb_classes, class_i):
318314
""" Function to estimate fraction of the samples where the model is abstaining in class i.
319315
320316
Parameters
@@ -361,7 +357,7 @@ class AbstentionAdapt_Callback(Callback):
361357
The factor alpha is modified if the current abstention accuracy is less than the minimum accuracy set or if the current abstention fraction is greater than the maximum fraction set. Thresholds for minimum and maximum correction factors are computed and the correction over alpha is not allowed to be less or greater than them, respectively, to avoid huge swings in the abstention loss evolution.
362358
"""
363359

364-
def __init__(self, acc_monitor, abs_monitor, alpha0: float, init_abs_epoch: int = 4, alpha_scale_factor: float = 0.8, min_abs_acc: float = 0.9, max_abs_frac: float = 0.4, acc_gain: float = 5.0, abs_gain: float = 1.0):
360+
def __init__(self, acc_monitor, abs_monitor, alpha0, init_abs_epoch=4, alpha_scale_factor=0.8, min_abs_acc=0.9, max_abs_frac=0.4, acc_gain=5.0, abs_gain=1.0):
365361
""" Initializer of the AbstentionAdapt_Callback.
366362
Parameters
367363
----------
@@ -395,9 +391,9 @@ def __init__(self, acc_monitor, abs_monitor, alpha0: float, init_abs_epoch: int
395391
self.max_abs_frac = max_abs_frac # maximum abstention fraction (value specified as parameter of the run)
396392
self.acc_gain = acc_gain # factor for adjusting alpha scale
397393
self.abs_gain = abs_gain # factor for adjusting alpha scale
398-
self.alphavalues: List[float] = [] # array to store alpha evolution
394+
self.alphavalues = [] # array to store alpha evolution
399395

400-
def on_epoch_end(self, epoch: int, logs=None):
396+
def on_epoch_end(self, epoch, logs=None):
401397
""" Updates the weight of abstention term on epoch end.
402398
Parameters
403399
----------
@@ -439,7 +435,7 @@ def on_epoch_end(self, epoch: int, logs=None):
439435
self.alphavalues.append(new_alpha_val)
440436

441437

442-
def modify_labels(numclasses_out: int, ytrain: Array, ytest: Array, yval: Optional[Array] = None) -> Tuple[Array, ...]:
438+
def modify_labels(numclasses_out, ytrain, ytest, yval=None):
443439
""" This function generates a categorical representation with a class added for indicating abstention.
444440
445441
Parameters
@@ -493,7 +489,7 @@ def modify_labels(numclasses_out: int, ytrain: Array, ytest: Array, yval: Option
493489
###################################################################
494490

495491

496-
def add_model_output(modelIn, mode: Optional[str] = None, num_add: Optional[int] = None, activation: Optional[str] = None):
492+
def add_model_output(modelIn, mode=None, num_add=None, activation=None):
497493
""" This function modifies the last dense layer in the passed keras model. The modification includes adding units and optionally changing the activation function.
498494
499495
Parameters
@@ -571,7 +567,7 @@ def add_model_output(modelIn, mode: Optional[str] = None, num_add: Optional[int]
571567
# UQ regression - utilities
572568

573569

574-
def r2_heteroscedastic_metric(nout: int):
570+
def r2_heteroscedastic_metric(nout):
575571
"""This function computes the r2 for the heteroscedastic model. The r2 is computed over the prediction of the mean and the standard deviation prediction is not taken into account.
576572
577573
Parameters
@@ -602,7 +598,7 @@ def metric(y_true, y_pred):
602598
return metric
603599

604600

605-
def mae_heteroscedastic_metric(nout: int):
601+
def mae_heteroscedastic_metric(nout):
606602
"""This function computes the mean absolute error (mae) for the heteroscedastic model. The mae is computed over the prediction of the mean and the standard deviation prediction is not taken into account.
607603
608604
Parameters
@@ -630,7 +626,7 @@ def metric(y_true, y_pred):
630626
return metric
631627

632628

633-
def mse_heteroscedastic_metric(nout: int):
629+
def mse_heteroscedastic_metric(nout):
634630
"""This function computes the mean squared error (mse) for the heteroscedastic model. The mse is computed over the prediction of the mean and the standard deviation prediction is not taken into account.
635631
636632
Parameters
@@ -658,7 +654,7 @@ def metric(y_true, y_pred):
658654
return metric
659655

660656

661-
def meanS_heteroscedastic_metric(nout: int):
657+
def meanS_heteroscedastic_metric(nout):
662658
"""This function computes the mean log of the variance (log S) for the heteroscedastic model. The mean log is computed over the standard deviation prediction and the mean prediction is not taken into account.
663659
664660
Parameters
@@ -686,7 +682,7 @@ def metric(y_true, y_pred):
686682
return metric
687683

688684

689-
def heteroscedastic_loss(nout: int):
685+
def heteroscedastic_loss(nout):
690686
"""This function computes the heteroscedastic loss for the heteroscedastic model. Both mean and standard deviation predictions are taken into account.
691687
692688
Parameters
@@ -721,7 +717,7 @@ def loss(y_true, y_pred):
721717
return loss
722718

723719

724-
def quantile_loss(quantile: float, y_true, y_pred):
720+
def quantile_loss(quantile, y_true, y_pred):
725721
"""This function computes the quantile loss for a given quantile fraction.
726722
727723
Parameters
@@ -738,7 +734,7 @@ def quantile_loss(quantile: float, y_true, y_pred):
738734
return K.mean(K.maximum(quantile * error, (quantile - 1) * error))
739735

740736

741-
def triple_quantile_loss(nout: int, lowquantile: float, highquantile: float):
737+
def triple_quantile_loss(nout, lowquantile, highquantile):
742738
"""This function computes the quantile loss for the median and low and high quantiles. The median is given twice the weight of the other components.
743739
744740
Parameters
@@ -776,7 +772,7 @@ def loss(y_true, y_pred):
776772
return loss
777773

778774

779-
def quantile_metric(nout: int, index: int, quantile: float):
775+
def quantile_metric(nout, index, quantile):
780776
"""This function computes the quantile metric for a given quantile and corresponding output index. This is provided as a metric to track evolution while training.
781777
782778
Parameters
@@ -813,7 +809,7 @@ def metric(y_true, y_pred):
813809
# For the Contamination Model
814810

815811

816-
def add_index_to_output(y_train: Array) -> Array:
812+
def add_index_to_output(y_train):
817813
""" This function adds a column to the training output to store the indices of the corresponding samples in the training set.
818814
819815
Parameters
@@ -832,7 +828,7 @@ def add_index_to_output(y_train: Array) -> Array:
832828
return y_train_augmented
833829

834830

835-
def contamination_loss(nout: int, T_k, a, sigmaSQ, gammaSQ):
831+
def contamination_loss(nout, T_k, a, sigmaSQ, gammaSQ):
836832
""" Function to compute contamination loss. It is composed by two terms: (i) the loss with respect to the normal distribution that models the distribution of the training data samples, (ii) the loss with respect to the Cauchy distribution that models the distribution of the outlier samples. Note that the evaluation of this contamination loss function does not make sense for any data different to the training set. This is because latent variables are only defined for samples in the training set.
837833
838834
Parameters
@@ -917,7 +913,7 @@ def __init__(self, x, y, a_max=0.99):
917913
self.sigmaSQvalues = [] # array to store sigmaSQ evolution
918914
self.gammaSQvalues = [] # array to store gammaSQ evolution
919915

920-
def on_epoch_end(self, epoch: int, logs={}):
916+
def on_epoch_end(self, epoch, logs={}):
921917
""" Updates the parameters of the distributions in the contamination model on epoch end. The parameters updated are: 'a' for the global weight of the membership to the normal distribution, 'sigmaSQ' for the variance of the normal distribution and 'gammaSQ' for the scale of the Cauchy distribution of outliers. The latent variables are updated as well: 'T_k' describing in the first column the probability of membership to normal distribution and in the second column probability of membership to the Cauchy distribution i.e. outlier. Stores evolution of global parameters (a, sigmaSQ and gammaSQ).
922918
923919
Parameters
@@ -966,7 +962,7 @@ def on_epoch_end(self, epoch: int, logs={}):
966962
self.gammaSQvalues.append(gammaSQ_eval)
967963

968964

969-
def mse_contamination_metric(nout: int):
965+
def mse_contamination_metric(nout):
970966
"""This function computes the mean squared error (mse) for the contamination model. The mse is computed over the prediction. Therefore, the augmentation for the index variable is ignored.
971967
972968
Parameters
@@ -990,7 +986,7 @@ def metric(y_true, y_pred):
990986
return metric
991987

992988

993-
def mae_contamination_metric(nout: int):
989+
def mae_contamination_metric(nout):
994990
"""This function computes the mean absolute error (mae) for the contamination model. The mae is computed over the prediction. Therefore, the augmentation for the index variable is ignored.
995991
996992
Parameters
@@ -1014,7 +1010,7 @@ def metric(y_true, y_pred):
10141010
return metric
10151011

10161012

1017-
def r2_contamination_metric(nout: int):
1013+
def r2_contamination_metric(nout):
10181014
"""This function computes the r2 for the contamination model. The r2 is computed over the prediction. Therefore, the augmentation for the index variable is ignored.
10191015
10201016
Parameters

0 commit comments

Comments
 (0)