
[BUG] When using model.fit(... validation_data=...) explicit calls to model.evaluate() fails after (in graph mode) #233

@gabrielspmoreira

Description


Bug description

When the TF model is compiled in graph mode, model.evaluate() is called internally by model.fit() when validation_data is set for fit(), running validation at the end of each epoch.
But when model.evaluate() is explicitly called by the user after model.fit(validation_data=...), an error is raised.

Steps/Code to reproduce bug

  1. If this PR is not merged yet, use its branch metrics_opt3 to reproduce the issue
  2. Uncomment these lines on test_two_tower_retrieval_model_with_metrics and run it with run_eagerly=True
  3. The explicit call to model.evaluate() at the end of the test will fail

Expected behavior

It should compute validation metrics both during training and when model.evaluate() is called explicitly.

Additional context

I believe the issue happens because when we call model.fit() the ItemRetrievalTask._pre_eval_topk is not defined yet. The _pre_eval_topk is only set after model.fit(), when model.load_topk_evaluation() is called (like here) to use the items dataset and the trained item tower.

The issue is that when model.fit(validation_data=...) is called first (and internally model.evaluate()), _pre_eval_topk is frozen in the graph as None instead of being a TopKIndexBlock, so it is not included in the evaluation graph, and the later explicit model.evaluate() fails.
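A toy illustration of the freezing mechanism (the names here are assumptions, not the actual merlin-models code): a Python-level `None` check inside `call()` is resolved once when Keras traces and caches its test function during `fit()` validation, so setting the attribute afterwards has no effect on a later `evaluate()`. The toy version silently keeps the old graph rather than raising, but the trace-time freezing is the same.

```python
import numpy as np
import tensorflow as tf

class ToyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)
        self.post_block = None  # plays the role of ItemRetrievalTask._pre_eval_topk

    def call(self, inputs):
        out = self.dense(inputs)
        if self.post_block is not None:  # Python branch, resolved at trace time
            out = self.post_block(out)
        return out

x = np.random.rand(16, 4).astype("float32")
y = np.random.rand(16, 1).astype("float32")

model = ToyModel()
model.compile(optimizer="adam", loss="mse", run_eagerly=False)
# fit() with validation_data traces and caches the test function with post_block=None
model.fit(x, y, validation_data=(x, y), epochs=1, verbose=0)
baseline_loss = model.evaluate(x, y, verbose=0)

# Setting the block afterwards (like load_topk_evaluation()) does not retrace:
model.post_block = tf.keras.layers.Lambda(lambda t: t * 2.0)
frozen_loss = model.evaluate(x, y, verbose=0)  # cached graph still ignores post_block
```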

Potential solution

A potential solution could be setting _pre_eval_topk=TopKIndexBlock() automatically for the ItemRetrievalTask, with dynamic variables ids and values whose first dim is undefined, and then loading the data into the index after fit using something like variable.assign().
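A rough sketch of that idea (variable names and shapes here are assumptions, not the actual TopKIndexBlock internals): create the index variables up front with an undefined first dimension, so they already exist when the evaluation graph is traced, and fill them after fit() with assign().

```python
import tensorflow as tf

dim = 8  # embedding size (assumed for illustration)

# Index variables created up front with an unknown first dimension, so any
# graph that reads them can be traced before items are loaded.
item_ids = tf.Variable(tf.zeros([0], tf.int64),
                       shape=tf.TensorShape([None]), trainable=False)
item_embeddings = tf.Variable(tf.zeros([0, dim]),
                              shape=tf.TensorShape([None, dim]), trainable=False)

@tf.function
def top_k_items(query, k=2):
    # Score the query against whatever is currently loaded in the index.
    scores = tf.matmul(query, item_embeddings, transpose_b=True)
    values, positions = tf.math.top_k(scores, k=k)
    return values, tf.gather(item_ids, positions)

# After fit(), load the trained item tower outputs into the index in place:
item_ids.assign(tf.range(5, dtype=tf.int64))
item_embeddings.assign(tf.random.normal([5, dim]))

values, ids = top_k_items(tf.random.normal([3, dim]))
```

Because the variables have a dynamic first dimension, assign() can later grow them to the real item catalog size without retracing.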
