feat(nnx): Missing object-oriented pooling layers in NNX #5202

@starryendymion

Description

Problem Description

Currently, flax.nnx lacks native object-oriented pooling modules such as MaxPool, AvgPool, and GlobalAveragePool. Users migrating from frameworks like Keras or PyTorch, or even from flax.linen, are forced to mix functional API calls into the otherwise object-oriented NNX structure. This creates an inconsistent developer experience and requires manual boilerplate for common operations such as global average pooling.

Proposed Feature

Introduce a dedicated pooling module suite within nnx that mirrors the ergonomic design of other NNX layers. This includes:

  • Subsampling Modules: MaxPool, AvgPool, and MinPool.
  • Global Pooling: A dedicated GlobalAveragePool module to replace manual jnp.mean calls.

Implementation Status

I have already implemented these modules and exposed them in the nnx namespace.
See Pull Request: #5201

Justification & Benefits

  1. API Consistency: Keeps model code object-oriented, in line with the rest of NNX, without jumping back into linen.functional.
  2. Framework Parity: Lowers the barrier for users migrating from Keras/PyTorch.
  3. Readability: Simplifies model definitions, especially for standard CNN architectures.
