Skip to content

refactor: simplifications in ModuleMeta __call__#999

Closed
nstarman wants to merge 13 commits intopatrick-kidger:devfrom
nstarman:module-call
Closed

refactor: simplifications in ModuleMeta __call__#999
nstarman wants to merge 13 commits intopatrick-kidger:devfrom
nstarman:module-call

Conversation

@nstarman
Copy link
Copy Markdown
Contributor

No description provided.

@nstarman nstarman force-pushed the module-call branch 3 times, most recently from 40d6df7 to 966b589 Compare April 15, 2025 20:28
@nstarman
Copy link
Copy Markdown
Contributor Author

nstarman commented Apr 15, 2025

@patrick-kidger I rebased on dev but the target is still main, which I don't believe I can switch without closing this PR and opening a new one. Hence the pickup of other commits.
Cool. Didn't know you could switch targets post PR creation.

@nstarman nstarman changed the base branch from main to dev April 23, 2025 16:53
@nstarman nstarman marked this pull request as ready for review April 30, 2025 03:32
@nstarman
Copy link
Copy Markdown
Contributor Author

@patrick-kidger I think this is ready for review.

nstarman and others added 12 commits May 16, 2025 17:30
* feat: converter = None as the sentinel

1. Easier for users to access instead of a private sentinel.
2. simplifies later performance-related logic changes.
3. Broadens support to allow `dataclasses.field(metadata=dict(converter=None))`, not just `eqx.field`.

Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>

* refactor: only set annotations if datclass init

Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>

---------

Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>
* docs: add diffraxtra

Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>

* docs: link to list

Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>

---------

Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>
…ra. This is a bit of a technical nit, but the interface isn't what I would describe as being OOP: which to me refers specifically to mutating state via methods. (If the state is immutable it's a curried function and is FP.)
- Now tracking only the running statistics, not the zero-debiased statistics. These are handled at inference time instead.
- Standarised bibtex formatting.
- Moved `mode` argument to the end for backward compatibility.
Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>
Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>
@patrick-kidger
Copy link
Copy Markdown
Owner

So for this one I'm actually planning on doing my Module refactor fairly shortly. (As my private testing with it has so far gone smoothly.) For that reason I'm inclined not to change the internals here, and instead tag you on the PR when I write it? It'd be great to have your feedback on whether the new version looks up to scratch.

Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>
@nstarman nstarman marked this pull request as draft May 16, 2025 16:19
@nstarman
Copy link
Copy Markdown
Contributor Author

nstarman commented May 16, 2025

Some possible ideas, inspired from dataclassish

  1. Always define a __dataclass_init__ method that is constructed by Module and which is the init method created by dataclass if init=True. This is very useful when defining a custom init method to be able to do
class MyClass
    def __init__(self, ...):
        ... # Whatever manual processing
        self.__dataclass_init__(...)  # don't have to worry about setting in frozen vs not, descriptors vs not, etc.

Of course it's totally optional to use, it's just very convenient!
One of my common gripes when writing an __init__ with a dataclass is that I just needed to do a bit of manual processing but now I have to write all the attribute assignments.

  1. Always define a __converter_init__ method that is constructed by Module and which does
def __converter_init__(self, ...):
    ba = run_converters(...)
    self.__dataclass_init__(**ba)

When dataclass_init=True this is swapped in for __init__ so that input conversion happens early, which should be more efficient (e.g. saving 2x assignment).
When dataclass_init=False then users can simplify their __init__ methods by doing

class MyClass
    def __init__(self, ...):
        ... # Whatever manual processing
        self. __converter_init__(...)  # handles conversion, and all complexities of setting.

Of course it's totally optional to use, it's just very convenient!

  1. If dataclass_init is True and the run_converters skip args that are Tracers then the __init__ method satisfies JAX's assumptions about dataclasses and we can use the faster pytree registration.

@patrick-kidger
Copy link
Copy Markdown
Owner

These are an interesting set of ideas! I can totally see why these are attractive for power users.

Points (1) and (2) I think probably impose too great a learning curve for new users: we already have converter, __post_init__, __init__ and __check_init__, which is already rather a lot. These methods might save a few lines but I think they increase the complexity too much.

Point (3) sounds like an internal optimisation we could probably perform, however. I like the sound of that one.

@nstarman
Copy link
Copy Markdown
Contributor Author

SGTM. I think we can also do the converter loop as part of the __init__ trick when dataclass_init=True as an internal optimization. That should reduce setting from 2x to 1x.

@patrick-kidger patrick-kidger deleted the branch patrick-kidger:dev July 7, 2025 20:36
@nstarman nstarman deleted the module-call branch July 7, 2025 22:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants