-
Notifications
You must be signed in to change notification settings - Fork 54
Description
Bug description
When using model.evaluate() after model.fit(..., validation_data=...), the model.evaluate() accuracy metrics are much higher than expected.
That happens only when the model is compiled with run_eagerly=False.
How to reproduce
1 - Download the sample dataset
2 - Run this code and you will get pretty good ranking metrics from model.evaluate(), something like
EVALUATION METRICS: auc: 0.9199 - ndcg_10: 0.7043 - recall_at_10: 0.7192 - recall_at_1000: 0.7871 - loss: 0.0000e+00 - regularization_loss: 0.0805 - total_loss: 0.0805
3 - Run this code again with either of the following changes. The eval accuracy is going to drop now as follows
- Comment out `validation_data=test_dl` within `model.fit()`, OR
- Set `run_eagerly=True`.
EVALUATION METRICS: auc: 0.6208 - ndcg_10: 0.0023 - recall_at_10: 0.0049 - recall_at_1000: 0.1804 - loss: 0.0000e+00 - regularization_loss: 0.0800 - total_loss: 0.0800
In order to verify the model.evaluate() results, I use topk_rec = model.to_top_k_recommender() and topk_rec.predict() to generate topk predictions and use numpy to compute Recall@10 and Recall@1000 metrics.
The results are much smaller than metrics output from model.evaluate()
PREDICT TOP-K - RECALL@10 = 0.00013831947372206638 - RECALL@1000 = 0.016554074615056903
Code to reproduce
import tensorflow as tf
tf.keras.utils.set_random_seed(52)
import cupy
import merlin.models.tf as mm
import merlin.models.tf.dataset as tf_dataloader
import numpy as np
from merlin.io.dataset import Dataset
from merlin.models.utils.dataset import unique_rows_by_features
from merlin.schema import Schema
from merlin.schema.io.tensorflow_metadata import TensorflowMetadata
from merlin.schema.tags import Tags
from nvtabular.ops import *
from tensorflow.keras import regularizers
# Root of the preprocessed dataset; expected to contain train/ and test/ splits.
data_path = "./data/processed"
# Load the dataset schema from the proto-text metadata stored with the training
# split and convert it into a Merlin schema object.
schema = TensorflowMetadata.from_proto_text_file(
    data_path + "/train/"
).to_merlin_schema()
# Restrict the schema to user/item features and their id columns, which is all
# the two-tower model below consumes.
schema = schema.select_by_tag([Tags.ITEM_ID, Tags.USER_ID, Tags.ITEM, Tags.USER])
# Two-tower retrieval model built from the schema, with an MLP query tower and
# in-batch negative sampling.
model = mm.TwoTowerModel(
    schema,
    query_tower=mm.MLPBlock(
        [128, 64],
        no_activation_last_layer=True,  # linear final layer for the embedding output
        kernel_regularizer=regularizers.l2(1e-2),
        bias_regularizer=regularizers.l2(1e-2),
    ),
    samplers=[mm.InBatchSampler()],  # negatives come from other items in the batch
    embedding_options=mm.EmbeddingOptions(infer_embedding_sizes=True),
    # Ranking metrics reported by fit()/evaluate(); these are the values that
    # differ between run_eagerly=True and run_eagerly=False in this bug report.
    metrics=[
        tf.keras.metrics.AUC(from_logits=True),
        mm.NDCGAt(10),
        mm.RecallAt(10),
        mm.RecallAt(1000),
    ],
    loss="categorical_crossentropy",
    logits_temperature=0.1,  # scales logits before the loss/softmax
)
batch_size = 4096
# Training loader: shuffled, incomplete final batch dropped.
train_dl = tf_dataloader.BatchedDataset(
    Dataset(data_path + "/train/*.parquet", part_size="500MB", schema=schema),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
# Test loader: deterministic order, all rows kept so metrics cover the
# full test set.
test_dl = tf_dataloader.BatchedDataset(
    Dataset(data_path + "/test/*.parquet", part_size="500MB", schema=schema),
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
)
from datetime import datetime
# Timestamped TensorBoard log directory for this run.
logdir = "logs/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=logdir, update_freq=10, write_steps_per_second=True
)
# NOTE(review): registers the training items as the candidate set used when
# evaluating retrieval metrics — presumably needed for full-catalog evaluation;
# confirm against the merlin-models API docs.
model.set_retrieval_candidates_for_evaluation(train_dl.data)
opt = tf.keras.optimizers.Adam(learning_rate=0.001)
# run_eagerly=False is the setting that reproduces the inflated evaluate()
# metrics described in this issue.
model.compile(opt, run_eagerly=False)
model.fit(
    train_dl,
    epochs=10,
    validation_data=test_dl,  # commenting this out changes evaluate() results (the bug)
    validation_steps=20,
    callbacks=[tensorboard_callback],
)
############ EVALUATION ############
# Metrics reported here are the ones suspected of being inflated.
eval_metrics = model.evaluate(test_dl)
print("EVALUATION METRICS: ", eval_metrics)
############ CHECKING EVALUATION USING TOP-K PREDICTIONS ############
# Build the candidate item catalog from the unique item rows of the training data.
item_dataset = unique_rows_by_features(train_dl.data, Tags.ITEM, Tags.ITEM_ID).to_ddf()
# Wrap the trained model as a top-1000 recommender over that catalog.
topk_recommender = model.to_top_k_recommender(Dataset(item_dataset), k=1000)
topk_output = topk_recommender.predict(test_dl)
# presumably (scores, item_ids) per test example — verify against merlin-models API
topk_predictions, topk_items = topk_output
test_df = test_dl.data.to_ddf()
# Ground-truth item ids for each test row, copied from GPU (cupy) to host numpy.
positive_item_ids = cupy.asnumpy(test_df["item_id"].compute().values)
def numpy_recall(labels, top_item_ids, k):
    """Recall@k: fraction of examples whose true item appears in the top-k ids.

    labels: 1-D array-like of ground-truth item ids, one per example.
    top_item_ids: 2-D array of shape (n_examples, >=k) with ranked item ids.
    k: cutoff rank.
    """
    # Compare each label (broadcast as a column) against its top-k candidates,
    # mark rows with at least one hit, and average over all examples.
    candidates = top_item_ids[:, :k]
    hits = np.equal(np.expand_dims(labels, -1), candidates).any(axis=-1)
    return hits.mean()
# Recompute recall directly from the top-k predictions as an independent
# sanity check against the metrics model.evaluate() reported above.
recall_at_10 = numpy_recall(positive_item_ids, topk_items, k=10)
recall_at_1000 = numpy_recall(positive_item_ids, topk_items, k=1000)
print(f"PREDICT TOP-K - RECALL@10 = {recall_at_10} - RECALL@1000 = {recall_at_1000}")