Skip to content

Conversation

@starryendymion
Copy link

@starryendymion starryendymion commented Jan 26, 2026

Description

This PR adds native object-oriented pooling layers to flax.nnx.

I am currently migrating from Keras to Flax NNX and noticed that while functional pooling exists in linen, NNX was missing the modular, object-oriented equivalents (e.g., nnx.AvgPool, nnx.MaxPool). Currently, users have to import functional pooling from linen or wrap them manually, which breaks the object-oriented flow of NNX.

Additionally, I noticed that GlobalAveragePool was missing entirely, so I have implemented it as a standard NNX module to simplify workflows for those coming from other frameworks like Keras or PyTorch.

Changes Made

  • Added flax/nnx/nn/pooling.py: Implemented MaxPool, MinPool, AvgPool, and GlobalAveragePool as nnx.Module subclasses.
  • Updated flax/nnx/__init__.py: Exposed the new layers in the top-level nnx namespace.
  • Documentation: Added pooling.rst to the API reference and updated the index.

Usage Example

Before (Linen Functional Pooling)

from flax import linen as nn

# Max Pool
x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

# Average Pool
x = nn.avg_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME")

# Global Average Pool (manual)
x = x.mean(axis=(1, 2))

After (NNX Object-Oriented Pooling)

from flax import nnx

# Max Pool
x = nnx.MaxPool(window_shape=(2, 2), strides=(2, 2))(x)

# Average Pool
x = nnx.AvgPool(window_shape=(3, 3), strides=(2, 2), padding="SAME")(x)

# Global Average Pool
x = nnx.GlobalAveragePool()(x)

**Closes #5202 **

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

feat(nnx): Missing object-oriented pooling layers in NNX

1 participant