fix:integerlookup one_hot shape inference for 2D inputs #22592
maitry63 wants to merge 2 commits into keras-team:master
Code Review
This pull request addresses a shape collapse issue in the IndexLookup layer when using one_hot output mode with 2D symbolic inputs. The changes update compute_output_shape and compute_output_spec to correctly preserve input dimensions and add regression tests for these scenarios. The review feedback identifies significant code duplication between these two methods regarding depth calculation and one_hot logic, recommending a refactor into a shared helper method to improve code maintainability.
    if self.output_mode == "one_hot":
        depth = (
            self.max_tokens
            if self.pad_to_max_tokens and self.max_tokens is not None
            else self.vocabulary_size()
        )
        output_shape = input_shape + (depth,)
    else:
        output_shape = self.compute_output_shape(input_shape)
There is some code duplication between compute_output_spec and compute_output_shape that could be refactored to improve maintainability.
- The calculation of depth is identical in both methods (here and in lines 561-564). This could be extracted to a private helper method to avoid redundancy.
- The logic for the one_hot output shape calculation is also duplicated. It seems compute_output_shape is now correct for all cases. If possible, compute_output_spec could be simplified to call compute_output_shape for all modes except 'int', which would remove the duplication.
If this duplication is intentional and necessary for the fix to work correctly in all execution modes, adding a comment explaining the reason would be very helpful for future maintenance.
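The suggested refactor could look roughly like the sketch below. It is a toy stand-in class in plain Python, not the Keras implementation: the class name `LookupShapeSketch` and the helper name `_one_hot_depth` are hypothetical, and only the `int` and `one_hot` modes shown in the diff are covered.

```python
class LookupShapeSketch:
    """Toy stand-in for the lookup layer; NOT the actual Keras code."""

    def __init__(self, vocabulary, max_tokens=None,
                 pad_to_max_tokens=False, output_mode="int"):
        self.vocabulary = list(vocabulary)
        self.max_tokens = max_tokens
        self.pad_to_max_tokens = pad_to_max_tokens
        self.output_mode = output_mode

    def vocabulary_size(self):
        return len(self.vocabulary)

    def _one_hot_depth(self):
        # Hypothetical helper: the single place where depth is computed,
        # mirroring the conditional expression in the diff.
        if self.pad_to_max_tokens and self.max_tokens is not None:
            return self.max_tokens
        return self.vocabulary_size()

    def compute_output_shape(self, input_shape):
        input_shape = tuple(input_shape)
        if self.output_mode == "int":
            return input_shape
        if self.output_mode == "one_hot":
            # Keep every input dimension and append the depth axis.
            return input_shape + (self._one_hot_depth(),)
        raise NotImplementedError("Sketch covers only 'int' and 'one_hot'.")


layer = LookupShapeSketch(["a", "b", "c"], output_mode="one_hot")
print(layer.compute_output_shape((None, 5)))
```

With the depth logic in one helper, a spec-computing method could delegate to `compute_output_shape` instead of repeating the branch, assuming no execution mode requires the two paths to differ.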
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #22592 +/- ##
=======================================
Coverage 83.30% 83.30%
=======================================
Files 596 596
Lines 67962 67960 -2
Branches 10580 10578 -2
=======================================
Hits 56615 56615
+ Misses 8600 8599 -1
+ Partials 2747 2746 -1
Hi @maitry63, can you check this once? It looks outdated. Thanks!
hertschuh left a comment
Most changes are spurious formatting changes. Please undo all of these to keep only the actual changes.
Please also rebase.
Thanks!
This PR fixes a regression in IntegerLookup with output_mode="one_hot" where 2D inputs (batch, sequence_length) produced incorrect output shapes in symbolic mode.
Fixes: #22520
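To make the expected shape concrete, here is a minimal plain-Python sketch (no Keras involved; `one_hot_2d` and `shape_of` are hypothetical helpers written for this illustration). It assumes the behavior shown in the diff, where the one-hot depth is appended as a new trailing axis, so a `(batch, sequence_length)` input becomes `(batch, sequence_length, depth)`.

```python
def one_hot_2d(indices, depth):
    """One-hot encode a 2D nested list of integer indices (hypothetical helper)."""
    return [
        [[1 if i == idx else 0 for i in range(depth)] for idx in row]
        for row in indices
    ]


def shape_of(nested):
    """Return the shape of a rectangular nested list as a tuple."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)


batch = [[0, 2, 1], [3, 3, 0]]        # shape (2, 3): (batch, sequence_length)
encoded = one_hot_2d(batch, depth=4)  # depth plays the role of the vocabulary size
print(shape_of(batch), "->", shape_of(encoded))
```

Symbolic shape inference should report the same rank as this eager result, with unknown dimensions kept as `None` rather than dropped.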