Problem Description
Currently, `flax.nnx` lacks native object-oriented pooling modules such as `MaxPool`, `AvgPool`, and `GlobalAveragePool`. Users migrating from frameworks like Keras or PyTorch, or even from `flax.linen`, are forced to mix functional API calls into the object-oriented NNX structure. This creates an inconsistent developer experience and requires manual boilerplate for common operations such as global average pooling.
Proposed Feature
Introduce a dedicated pooling module suite within nnx that mirrors the ergonomic design of other NNX layers. This includes:
- Subsampling Modules: `MaxPool`, `AvgPool`, and `MinPool`.
- Global Pooling: A dedicated `GlobalAveragePool` module to replace manual `jnp.mean` calls.
Implementation Status
I have already implemented these modules and exposed them in the nnx namespace.
See Pull Request: #5201
Justification & Benefits
- API Consistency: Preserves the object-oriented flow of NNX without jumping back into `linen.functional`.
- Framework Parity: Lowers the barrier for users migrating from Keras/PyTorch.
- Readability: Simplifies model definitions, especially for standard CNN architectures.