feat(nnx): add object-oriented pooling layers and GlobalAveragePool #5201
+160
−0
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
This PR adds native object-oriented pooling layers to
flax.nnx.I am currently migrating from Keras to Flax NNX and noticed that while functional pooling exists in
linen, NNX was missing the modular, object-oriented equivalents (e.g.,nnx.AvgPool,nnx.MaxPool). Currently, users have to import functional pooling fromlinenor wrap them manually, which breaks the object-oriented flow of NNX.Additionally, I noticed that
GlobalAveragePoolwas missing entirely, so I have implemented it as a standard NNX module to simplify workflows for those coming from other frameworks like Keras or PyTorch.Changes Made
flax/nnx/nn/pooling.py: ImplementedMaxPool,MinPool,AvgPool, andGlobalAveragePoolasnnx.Modulesubclasses.flax/nnx/__init__.py: Exposed the new layers in the top-levelnnxnamespace.pooling.rstto the API reference and updated the index.Usage Example
Before (Linen Functional Pooling)
After (NNX Object-Oriented Pooling)
**Closes #5202 **