Status: Open
Labels: bug (Something isn't working), enhancement (New feature or request), help wanted (Extra attention is needed)
Description
After migrating to TF2, this method doesn't work.
Emgraph/emgraph/models/EmbeddingModel.py
Lines 1927 to 2135 in 3926ad7
```python
def _calibrate(self, X_pos, X_neg=None, positive_base_rate=None, batches_count=100, epochs=50):
    """Calibrate predictions. todo: un-underscore this method later

    The method implements the heuristics described in :cite:`calibration`,
    using Platt scaling :cite:`platt1999probabilistic`.

    The calibrated predictions can be obtained with :meth:`predict_proba`
    after calibration is done.

    Ideally, calibration should be performed on a validation set that was not used to train the embeddings.

    There are two modes of operation, depending on the availability of negative triples:

    #. Both positive and negative triples are provided via ``X_pos`` and ``X_neg`` respectively. \
       The optimization is done using a second-order method (limited-memory BFGS), \
       therefore no hyperparameter needs to be specified.

    #. Only positive triples are provided, and the negative triples are generated by corruptions \
       just like it is done in training or evaluation. The optimization is done using a first-order method (ADAM), \
       therefore ``batches_count`` and ``epochs`` must be specified.

    Calibration is highly dependent on the base rate of positive triples.
    Therefore, for mode (2) of operation, the user is required to provide the ``positive_base_rate`` argument.

    For mode (1), that can be inferred automatically by the relative sizes of the positive and negative sets,
    but the user can override that by providing a value to ``positive_base_rate``.

    Defining the positive base rate is the biggest challenge when calibrating without negatives. That depends on
    the user choice of which triples will be evaluated during test time.
    Let's take WN11 as an example: it has around 50% positives triples on both the validation set and test set,
    so naturally the positive base rate is 50%. However, should the user resample it to have 75% positives
    and 25% negatives, its previous calibration will be degraded. The user must recalibrate the model now with a
    75% positive base rate. Therefore, this parameter depends on how the user handles the dataset and
    cannot be determined automatically or a priori.

    .. Note ::
        Incompatible with large graph mode (i.e. if ``self.dealing_with_large_graphs=True``).

    :param X_pos: Numpy array of positive triples.
    :type X_pos: ndarray (shape [n, 3])
    :param X_neg: Numpy array of negative triples.
        If `None`, the negative triples are generated via corruptions
        and the user must provide a positive base rate instead.
    :type X_neg: ndarray (shape [n, 3])
    :param positive_base_rate: Base rate of positive statements.
        For example, if we assume there is a fifty-fifty chance of any query to be true, the base rate would be 50%.
        If ``X_neg`` is provided and this is `None`, the relative sizes of ``X_pos`` and ``X_neg`` will be used to
        determine the base rate. For example, if we have 50 positive triples and 200 negative triples,
        the positive base rate will be assumed to be 50/(50+200) = 1/5 = 0.2.
        This must be a value between 0 and 1.
    :type positive_base_rate: float
    :param batches_count: Number of batches to complete one epoch of the Platt scaling training.
        Only applies when ``X_neg`` is `None`.
    :type batches_count: int
    :param epochs: Number of epochs used to train the Platt scaling model.
        Only applies when ``X_neg`` is `None`.
    :type epochs: int
    :return:
    :rtype:

    Examples:

    >>> import numpy as np
    >>> from sklearn.metrics import brier_score_loss, log_loss
    >>> from scipy.special import expit
    >>>
    >>> from emgraph.datasets import load_wn11
    >>> from emgraph.models import TransE
    >>>
    >>> X = load_wn11()
    >>> X_valid_pos = X['valid'][X['valid_labels']]
    >>> X_valid_neg = X['valid'][~X['valid_labels']]
    >>>
    >>> model = TransE(batches_count=64, seed=0, epochs=500, k=100, eta=20,
    >>>                optimizer='adam', optimizer_params={'lr':0.0001},
    >>>                loss='pairwise', verbose=True)
    >>>
    >>> model.fit(X['train'])
    >>>
    >>> # Raw scores
    >>> scores = model.predict(X['test'])
    >>>
    >>> # Calibrate with positives and negatives
    >>> model.calibrate(X_valid_pos, X_valid_neg, positive_base_rate=None)
    >>> probas_pos_neg = model.predict_proba(X['test'])
    >>>
    >>> # Calibrate with just positives and base rate of 50%
    >>> model.calibrate(X_valid_pos, positive_base_rate=0.5)
    >>> probas_pos = model.predict_proba(X['test'])
    >>>
    >>> # Calibration evaluation with the Brier score loss (the smaller, the better)
    >>> print("Brier scores")
    >>> print("Raw scores:", brier_score_loss(X['test_labels'], expit(scores)))
    >>> print("Positive and negative calibration:", brier_score_loss(X['test_labels'], probas_pos_neg))
    >>> print("Positive only calibration:", brier_score_loss(X['test_labels'], probas_pos))
    Brier scores
    Raw scores: 0.4925058891371126
    Positive and negative calibration: 0.20434617882733366
    Positive only calibration: 0.22597599585144656
    """
    if not self.is_fitted:
        msg = 'Model has not been fitted.'
        logger.error(msg)
        raise RuntimeError(msg)

    if self.dealing_with_large_graphs:
        msg = "Calibration is incompatible with large graph mode."
        logger.error(msg)
        raise ValueError(msg)

    if positive_base_rate is not None and (positive_base_rate <= 0 or positive_base_rate >= 1):
        msg = "positive_base_rate must be a value between 0 and 1."
        logger.error(msg)
        raise ValueError(msg)

    dataset_handle = None

    try:
        # tf.reset_default_graph()
        self.rnd = check_random_state(self.seed)
        # tf.random.set_random_seed(self.seed)
        tf.random.set_seed(self.seed)

        self._load_model_from_trained_params()

        if X_neg is not None:
            if positive_base_rate is None:
                positive_base_rate = len(X_pos) / (len(X_pos) + len(X_neg))
            scores_pos, scores_neg = self._calibrate_with_negatives(X_pos, X_neg)
        else:
            if positive_base_rate is None:
                msg = "When calibrating with randomly generated negative corruptions, " \
                      "`positive_base_rate` must be set to a value between 0 and 1."
                logger.error(msg)
                raise ValueError(msg)
            scores_pos, scores_neg, dataset_handle = self._calibrate_with_corruptions(X_pos, batches_count)

        n_pos = len(X_pos)
        n_neg = len(X_neg) if X_neg is not None else n_pos

        scores_tf = tf.concat([scores_pos, scores_neg], axis=0)
        labels = tf.concat([tf.cast(tf.fill(tf.shape(scores_pos), (n_pos + 1.0) / (n_pos + 2.0)), tf.float32),
                            tf.cast(tf.fill(tf.shape(scores_neg), 1 / (n_neg + 2.0)), tf.float32)],
                           axis=0)

        # Platt scaling model
        # w = tf.get_variable('w', initializer=0.0, dtype=tf.float32)
        # b = tf.get_variable('b', initializer=np.log((n_neg + 1.0) / (n_pos + 1.0)).astype(np.float32),
        #                     dtype=tf.float32)
        # w = tf.Variable(tf.constant_initializer(0.0, shape=[scores_tf.shape]), name='w', dtype=tf.float32)
        w = self.make_variable(name='w',
                               shape=scores_tf.shape,
                               initializer=tf.zeros_initializer(),
                               dtype=tf.float32)
        # w = self._make_variable(name='w', shape=tf.TensorShape(None), initializer=tf.zeros_initializer(), dtype=tf.float32)
        print("np.log((n_neg + 1.0) / (n_pos + 1.0)).astype(np.float32): ",
              tf.constant_initializer(np.log((n_neg + 1.0) / (n_pos + 1.0)).astype(np.float32)))
        b = self.make_variable(name='b',
                               shape=scores_tf.shape,
                               initializer=tf.constant_initializer(np.log((n_neg + 1.0) / (n_pos + 1.0)).astype(np.float32)),
                               dtype=tf.float32)
        print(f"w: {w}\ntf.stop_gradient(scores_tf): {tf.stop_gradient(scores_tf)}\nb: {b}")
        # logits = -(w * tf.stop_gradient(scores_tf) + b)
        logits = -(w * scores_tf + b)

        # Sample weights make sure the given positive_base_rate will be achieved irrespective of batch sizes
        weigths_pos = tf.size(scores_neg) / tf.size(scores_pos)
        weights_neg = (1.0 - positive_base_rate) / positive_base_rate
        weights = tf.concat([tf.cast(tf.fill(tf.shape(scores_pos), weigths_pos), tf.float32),
                             tf.cast(tf.fill(tf.shape(scores_neg), weights_neg), tf.float32)], axis=0)
        print("w: ", w, "\nweights: ", weights)
        # loss = functools.partial(tf.compat.v1.losses.sigmoid_cross_entropy, labels, logits, weights=weights)
        loss = functools.partial(tf.nn.sigmoid_cross_entropy_with_logits, labels, logits)

        # optimizer = tf.train.AdamOptimizer()
        optimizer = tf.keras.optimizers.Adam()
        train = optimizer.minimize(loss, [logits, labels])

        # with tf.Session(config=self.tf_config) as sess:
        #     sess.run(tf.global_variables_initializer())

        epoch_iterator_with_progress = tqdm(range(1, epochs + 1), disable=(not self.verbose), unit='epoch')
        for _ in epoch_iterator_with_progress:
            losses = []
            for batch in range(batches_count):
                # loss_batch, _ = sess.run([loss, train])
                loss_batch, _ = [loss, train]
                losses.append(loss_batch)
            if self.verbose:
                msg = 'Calibration Loss: {:10f}'.format(sum(losses) / batches_count)
                logger.debug(msg)
                epoch_iterator_with_progress.set_description(msg)

        # self.calibration_parameters = sess.run([w, b])
        self.calibration_parameters = [w, b]
        self.is_calibrated = True
    finally:
        if dataset_handle is not None:
            dataset_handle.cleanup()
```
https://github.com/bi-graph/Emgraph/blob/master/emgraph/models/EmbeddingModel.py
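
For anyone picking this up: as far as I can tell, the loop above still follows the old TF1 `sess.run` pattern, so under TF2 the `optimizer.minimize(loss, [logits, labels])` call is not optimizing the Platt parameters (`[logits, labels]` are tensors, not variables, and `loss` does not depend on `w`/`b` through a tape), and `loss_batch, _ = [loss, train]` just unpacks two Python objects instead of running a training step. Below is a minimal sketch of how the same Platt scaling fit could look in TF2 eager mode with `tf.GradientTape`. It assumes `scores_pos`/`scores_neg` are already-computed raw score tensors, and the function name `calibrate_platt_tf2` plus the full-batch loop are hypothetical simplifications for illustration, not part of the Emgraph API.

```python
# Hypothetical sketch (not the Emgraph API): Platt scaling with TF2 eager execution.
import numpy as np
import tensorflow as tf


def calibrate_platt_tf2(scores_pos, scores_neg, positive_base_rate, epochs=50):
    """Fit scalar Platt parameters (w, b) on raw triple scores using tf.GradientTape."""
    n_pos = int(tf.size(scores_pos))
    n_neg = int(tf.size(scores_neg))

    scores = tf.cast(tf.concat([scores_pos, scores_neg], axis=0), tf.float32)

    # Platt's smoothed targets: (n_pos + 1)/(n_pos + 2) for positives, 1/(n_neg + 2) for negatives.
    labels = tf.cast(tf.concat([tf.fill([n_pos], (n_pos + 1.0) / (n_pos + 2.0)),
                                tf.fill([n_neg], 1.0 / (n_neg + 2.0))], axis=0), tf.float32)

    # Sample weights so the requested positive base rate holds regardless of the two set sizes,
    # mirroring weigths_pos / weights_neg in the quoted code.
    w_pos = n_neg / n_pos
    w_neg = (1.0 - positive_base_rate) / positive_base_rate
    weights = tf.cast(tf.concat([tf.fill([n_pos], w_pos), tf.fill([n_neg], w_neg)], axis=0), tf.float32)

    # Scalar Platt parameters, initialized as in the TF1 version.
    w = tf.Variable(0.0, dtype=tf.float32, name='w')
    b = tf.Variable(np.log((n_neg + 1.0) / (n_pos + 1.0)).astype(np.float32), name='b')

    optimizer = tf.keras.optimizers.Adam()
    for _ in range(epochs):
        with tf.GradientTape() as tape:
            # Scores are fixed inputs; only w and b are trained.
            logits = -(w * tf.stop_gradient(scores) + b)
            per_example = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
            loss = tf.reduce_mean(per_example * weights)
        grads = tape.gradient(loss, [w, b])
        optimizer.apply_gradients(zip(grads, [w, b]))

    return w.numpy(), b.numpy()
```

If `predict_proba` applies `expit(-(w * score + b))` to the raw scores, as the commented-out TF1 code suggests, the returned `w` and `b` should slot into `self.calibration_parameters` directly, but that is an assumption worth checking against `_calibrate_with_negatives` / `_calibrate_with_corruptions`.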