nnx.Optimizer.update should return the updates #5182

@am001122

It would be nice if nnx.Optimizer.update (see https://github.com/google/flax/blob/main/flax/nnx/training/optimizer.py#L221) returned the underlying updates PyTree that optax computes. Existing callers could simply ignore the return value, while anyone who wants to log or diagnose the weight update could do so directly, without dropping down to pure optax.