Description
Bug description
The TF model is compiled in graph mode, and `model.evaluate()` is called internally by `model.fit()` when `validation_data` is set, running validation at the end of each epoch.
But when `model.evaluate()` is explicitly called by the user after `model.fit(validation_data=...)`, an error is raised.
Steps/Code to reproduce bug
- If this PR is not merged yet, use its branch `metrics_opt3` to reproduce the issue
- Uncomment these lines on `test_two_tower_retrieval_model_with_metrics` and run it with `run_eagerly=True`
- The explicit call to `model.evaluate()` at the end of the test will fail
Expected behavior
It should compute both validation metrics during training and also when model.evaluate() is called
Additional context
I believe the issue happens because when we call `model.fit()`, `ItemRetrievalTask._pre_eval_topk` is not defined yet. `_pre_eval_topk` is only set after `model.fit()`, when `model.load_topk_evaluation()` is called (like here) to build the index from the items dataset and the trained item tower.
The issue is that when `model.fit(validation_data=...)` is called first (which internally calls `model.evaluate()`), `_pre_eval_topk` is frozen into the graph as `None` instead of as a `TopKIndexBlock`, so it is not included in the evaluation graph and the subsequent explicit `model.evaluate()` fails.
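To make the freezing behavior concrete, here is a minimal, hypothetical pure-Python analogy (not Merlin or TF code): a `trace_once` decorator that, like a `tf.function` trace, resolves Python-level control flow on the first call and reuses that "graph" afterwards. The class and attribute names are illustrative only.

```python
def trace_once(method):
    """Toy stand-in for tf.function: trace on first call, then reuse."""
    traced = {}

    def wrapper(self, x):
        if "graph" not in traced:
            # Trace time: the Python `if` is evaluated NOW. If the attribute
            # is still None, the "no top-k" branch is frozen into the graph.
            if self.pre_eval_topk is None:
                traced["graph"] = lambda self, x: x  # plain pass-through
            else:
                traced["graph"] = lambda self, x: self.pre_eval_topk(x)
        return traced["graph"](self, x)

    return wrapper


class RetrievalTask:
    """Toy analogue of ItemRetrievalTask (illustrative, not the real class)."""

    def __init__(self):
        self.pre_eval_topk = None  # mirrors _pre_eval_topk before fit()

    @trace_once
    def evaluate(self, x):
        return x


task = RetrievalTask()
task.evaluate(1)                       # fit-time validation traces with None
task.pre_eval_topk = lambda x: x * 10  # "load_topk_evaluation" happens too late
print(task.evaluate(2))                # prints 2, not 20: None was frozen in
```

Setting the attribute after the first call has no effect, which mirrors why the explicit `model.evaluate()` still sees `_pre_eval_topk` as `None`.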
Potential solution
A potential solution could be to set `_pre_eval_topk` to a `TopKIndexBlock()` automatically in `ItemRetrievalTask`, with its `ids` and `values` variables declared with an undefined first dimension, and then load the data into the index after fit using something like `variable.assign()`.
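A minimal sketch of that idea in plain TensorFlow, assuming only standard `tf.Variable` behavior; the names `item_ids`, `item_embeddings`, `dim`, and `num_items` are illustrative, not Merlin API. Declaring the variables with an unknown first dimension lets them exist (and be traced into the graph) before the item catalog is available, and be resized later with `assign()`:

```python
import tensorflow as tf

dim = 8  # example embedding size

# Created BEFORE fit(), empty but with an undefined first dimension, so the
# top-k index can be part of the traced evaluation graph from the start.
item_ids = tf.Variable(
    tf.zeros([0], dtype=tf.int64),
    shape=tf.TensorShape([None]),
    trainable=False,
)
item_embeddings = tf.Variable(
    tf.zeros([0, dim]),
    shape=tf.TensorShape([None, dim]),
    trainable=False,
)

# AFTER fit(): load the catalog produced by the trained item tower into the
# pre-created variables; assign() accepts a new size in the None dimension.
num_items = 100
item_ids.assign(tf.range(num_items, dtype=tf.int64))
item_embeddings.assign(tf.random.normal([num_items, dim]))
```

This avoids creating the `TopKIndexBlock` after tracing, which is what leaves `_pre_eval_topk` frozen as `None`.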