Pull request overview
This is a major refactoring PR that renames the project from "linax" to "discretax" and removes the configuration-based architecture in favor of direct instantiation with Cfg annotations for config system compatibility.
Key Changes:
- Project renamed from "linax" to "discretax" across all files
- Removed all `*Config` classes and their `build()` methods
- Introduced `Cfg` type annotation for parameters that can be initialized via config systems (e.g., Hydra)
- Models now compose with `eqx.nn.Sequential` instead of being monolithic
- Tests updated to reflect new direct instantiation pattern
Reviewed changes
Copilot reviewed 80 out of 84 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| uv.lock | Package name changed from "linax" to "discretax" |
| pyproject.toml | Project metadata updated with new name and URLs |
| README.md | Documentation updated with new project name and examples |
| mkdocs.yml | Documentation configuration updated, strict mode set to false |
| src/discretax/utils/config_mixin.py | New config system with Cfg annotation and PartialLoaderMixin |
| src/discretax/models/*.py | Models refactored to direct instantiation, compose with Sequential |
| src/discretax/**/base.py | Base classes updated, removed config classes, added PartialLoaderMixin |
| src/discretax/**/*.py | All components updated to new pattern without configs |
| tests/*.py | Tests updated to use direct instantiation instead of configs |
| docs/ | All API documentation updated to reflect new structure |
| docs/examples/ | Example notebook updated to show new usage patterns |
Comments suppressed due to low confidence (10)
src/discretax/heads/classification.py:14 - This class does not call Head.__init__ during initialization. (ClassificationHead.__init__ may be missing a call to a base class __init__)
src/discretax/channel_mixers/glu.py:17 - This class does not call ChannelMixer.__init__ during initialization. (GLU.__init__ may be missing a call to a base class __init__)
src/discretax/sequence_mixers/linoss.py:23 - This class does not call SequenceMixer.__init__ during initialization. (LinOSSSequenceMixer.__init__ may be missing a call to a base class __init__)
src/discretax/sequence_mixers/lru.py:36 - This class does not call SequenceMixer.__init__ during initialization. (LRUSequenceMixer.__init__ may be missing a call to a base class __init__)
src/discretax/channel_mixers/mlp.py:53 - This class does not call ChannelMixer.__init__ during initialization. (MLPChannelMixer.__init__ may be missing a call to a base class __init__)
src/discretax/heads/regression.py:13 - This class does not call Head.__init__ during initialization. (RegressionHead.__init__ may be missing a call to a base class __init__)
src/discretax/sequence_mixers/s4d.py:19 - This class does not call SequenceMixer.__init__ during initialization. (S4DSequenceMixer.__init__ may be missing a call to a base class __init__)
src/discretax/sequence_mixers/s5.py:24 - This class does not call SequenceMixer.__init__ during initialization. (S5SequenceMixer.__init__ may be missing a call to a base class __init__)
src/discretax/blocks/standard.py:16 - This class does not call Block.__init__ during initialization. (StandardBlock.__init__ may be missing a call to a base class __init__)
src/discretax/channel_mixers/swi_glu.py:21 - This class does not call ChannelMixer.__init__ during initialization. (SwiGLU.__init__ may be missing a call to a base class __init__)
src/discretax/encoder/embedding.py
Outdated
| from discretax.utils.config_mixin import Cfg | ||
|
|
||
|
|
||
| class EmbeddingEncoder(Encoder): |
There was a problem hiding this comment.
This class does not call Encoder.__init__ during initialization. (EmbeddingEncoder.__init__ may be missing a call to a base class __init__)
| from discretax.sequence_mixers.base import SequenceMixer | ||
|
|
||
|
|
||
| class IdentitySequenceMixer(SequenceMixer): |
There was a problem hiding this comment.
This class does not call SequenceMixer.__init__ during initialization. (IdentitySequenceMixer.__init__ may be missing a call to a base class __init__)
| from discretax.channel_mixers.base import ChannelMixer | ||
|
|
||
|
|
||
| class IdentityChannelMixer(ChannelMixer): |
There was a problem hiding this comment.
This class does not call ChannelMixer.__init__ during initialization. (IdentityChannelMixer.__init__ may be missing a call to a base class __init__)
There was a problem hiding this comment.
@benedict-armstrong does this comment make sense?
src/discretax/encoder/linear.py
Outdated
| from discretax.utils.config_mixin import Cfg | ||
|
|
||
|
|
||
| class LinearEncoder(Encoder): |
There was a problem hiding this comment.
This class does not call Encoder.__init__ during initialization. (LinearEncoder.__init__ may be missing a call to a base class __init__)
francescoshox
left a comment
There was a problem hiding this comment.
Just checked the docs for now, sorry.
assets/logo.png
Outdated
There was a problem hiding this comment.
Sure! I will do that once I am back in Tüb
src/discretax/utils/config_mixin.py
Outdated
|
|
||
|
|
||
| def register(name: str): | ||
| """Register a class in the registry.""" |
There was a problem hiding this comment.
I know this is a nitpick, but we should say what the args are.
src/discretax/utils/config_mixin.py
Outdated
|
|
||
|
|
||
| def build_from_config(cfg: Any) -> Any: | ||
| """Recursively transforms a config dict into a generic Class Factory (partial).""" |
There was a problem hiding this comment.
Same for the args here. Given this is a library, I think it is always good practice to document them, e.g.
Args:
    cfg: general something config
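A minimal sketch of what documented versions of these two helpers could look like. The `_REGISTRY` dict, the `"name"` lookup key, and the `Linear` example class are assumptions made for illustration; the actual body of config_mixin.py is not shown in this thread.

```python
import functools
from typing import Any

# Hypothetical registry mapping a short name to a class.
_REGISTRY: dict[str, type] = {}


def register(name: str):
    """Register a class in the registry.

    Args:
        name: Key under which the decorated class is stored.

    Returns:
        A decorator that records the class and returns it unchanged.
    """

    def decorator(cls: type) -> type:
        _REGISTRY[name] = cls
        return cls

    return decorator


def build_from_config(cfg: Any) -> Any:
    """Recursively transform a config dict into a class factory (partial).

    Args:
        cfg: A dict (optionally carrying a "name" key that selects a
            registered class), a list of such values, or a leaf value.

    Returns:
        A functools.partial for dicts with a "name" key; otherwise the
        input with all nested entries converted the same way.
    """
    if isinstance(cfg, dict):
        built = {k: build_from_config(v) for k, v in cfg.items() if k != "name"}
        if "name" in cfg:
            return functools.partial(_REGISTRY[cfg["name"]], **built)
        return built
    if isinstance(cfg, (list, tuple)):
        return [build_from_config(v) for v in cfg]
    return cfg


@register("linear")
class Linear:  # hypothetical component, not part of discretax
    def __init__(self, in_dim: int, out_dim: int = 8):
        self.in_dim = in_dim
        self.out_dim = out_dim


factory = build_from_config({"name": "linear", "out_dim": 4})
layer = factory(in_dim=3)  # remaining arguments are supplied at call time
```

With an `Args:` and `Returns:` section in place, the expected shape of `cfg` is visible without reading the function body.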
src/discretax/utils/config_mixin.py
Outdated
| """Mixin to load a class from a config.""" | ||
|
|
||
| @classmethod | ||
| def from_config(cls, cfg: dict | Any) -> functools.partial: |
There was a problem hiding this comment.
Is the output type correct?
Usually partial returns a specific object, partially instantiated.
There was a problem hiding this comment.
The output type can be a callable returning an instance of a class.
src/discretax/utils/config_mixin.py
Outdated
| """Mixin to load a class from a config.""" | ||
|
|
||
| @classmethod | ||
| def from_config(cls, cfg: dict | Any) -> functools.partial: |
There was a problem hiding this comment.
I don't like dict | Any too much. Any already includes dict.
src/discretax/utils/config_mixin.py
Outdated
|
|
||
|
|
||
| class PartialLoaderMixin: | ||
| """Mixin to load a class from a config.""" |
There was a problem hiding this comment.
Maybe we can add a small example here of what this does. To me it is OK, but I am not sure it is easy for everyone to understand. I would guess Claude can write an example in 1 sec.
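Something along these lines could serve as that example. It is only a sketch under assumptions: `TinyModel` is a made-up class, the mixin body is a guess at the simplest possible behavior, and the `Mapping[str, Any]` parameter and `Callable[..., T]` return type are one way to address the typing questions raised above, not necessarily what the PR ends up using.

```python
import functools
from typing import Any, Callable, Mapping, TypeVar

T = TypeVar("T")


class PartialLoaderMixin:
    """Mixin that turns a config mapping into a partially applied constructor."""

    @classmethod
    def from_config(cls: type[T], cfg: Mapping[str, Any]) -> Callable[..., T]:
        # Bind the config values now; runtime-only arguments
        # (for example a PRNG key) are passed when the factory is called.
        return functools.partial(cls, **cfg)


class TinyModel(PartialLoaderMixin):  # hypothetical model for illustration
    def __init__(self, key: int, *, state_dim: int = 64):
        self.key = key
        self.state_dim = state_dim


make_model = TinyModel.from_config({"state_dim": 16})
model = make_model(key=0)  # the key is only supplied here
```

The factory is a callable that returns a `TinyModel`, which is what the `Callable[..., T]` annotation expresses.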
francescoshox
left a comment
There was a problem hiding this comment.
Generally OK, once we address the comments.
Where and how do you change the name of the codebase? I hope this is also doable. It would be bad to have discretax code inside a codebase called linax.
The tests need more work; as they are, they are not good.
src/discretax/blocks/standard.py
Outdated
|
|
||
|
|
||
| class StandardBlock[ConfigType: StandardBlockConfig](Block): | ||
| class StandardBlock(Block): |
There was a problem hiding this comment.
Do we have any block other than StandardBlock? Should we add a different one?
Maybe S4D does not use a standard block.
@phnazari @benedict-armstrong
There was a problem hiding this comment.
It does indeed not have the standard block. You can find the block here: https://github.com/state-spaces/s4/blob/main/models/s4/s4d.py
src/discretax/channel_mixers/glu.py
Outdated
| key: PRNGKeyArray, | ||
| *, | ||
| out_features: int | None = None, | ||
| out_features: Cfg[int | None] = None, |
There was a problem hiding this comment.
Is this correct? Or is it Cfg[int] | None?
|
|
||
| Args: | ||
| key: JAX random key for initialization. | ||
| out_features: output dimensionality (embedding dimension). |
There was a problem hiding this comment.
In LinearEncoder we call this hidden_dim, which I would also prefer here.
| cfg: ConfigType, | ||
| key: PRNGKeyArray, | ||
| *, | ||
| reduce: Cfg[bool] = True, |
There was a problem hiding this comment.
Do we need reduce to be Cfg type such that the mixin can initialize correctly, or for any other reason?
| from discretax.models.lru import LRU | ||
| from discretax.models.s5 import S5 | ||
|
|
||
| __all__ = [ |
There was a problem hiding this comment.
Hi, S4D is missing here. Why is that?
What do we need to introduce S4D?
@phnazari @benedict-armstrong
src/discretax/models/lru.py
Outdated
| state_dim: Cfg[int] = 64, | ||
| r_min: Cfg[float] = 0.0, | ||
| r_max: Cfg[float] = 1.0, | ||
| max_phase: Cfg[float] = 6.28, |
There was a problem hiding this comment.
Do we need to hardcode 2 pi?
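For reference, a short sketch of how far the literal is from the exact value and how the default could be written instead. The function below is a hypothetical stand-in for the constructor, and `math.pi` is used only to keep the snippet dependency-free; `jnp.pi` would work the same way.

```python
import math

HARDCODED_MAX_PHASE = 6.28     # literal from the diff above
EXACT_MAX_PHASE = 2 * math.pi  # full circle in radians

gap = EXACT_MAX_PHASE - HARDCODED_MAX_PHASE
print(f"{gap:.6f}")  # 0.003185


def lru_defaults(max_phase: float = 2 * math.pi) -> float:
    """Hypothetical stand-in for the constructor: returns the phase bound used."""
    return max_phase
```

The gap is small, but writing the default as `2 * pi` states the intent and avoids truncating the phase range.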
src/discretax/models/s5.py
Outdated
| from discretax.utils.config_mixin import Cfg, PartialLoaderMixin | ||
|
|
||
|
|
||
| class S5(eqx.nn.StatefulLayer, PartialLoaderMixin): |
There was a problem hiding this comment.
If S4D is a subclass of S5, we should probably also define an S4D child class.
Moreover, S4D probably has a different block or something.
| @@ -77,56 +42,68 @@ class LinOSSSequenceMixer[ConfigType: LinOSSSequenceMixerConfig](eqx.Module): | |||
| C: jax.Array | |||
| D: jax.Array | |||
| steps: jax.Array | |||
src/discretax/sequence_mixers/s5.py
Outdated
| key: PRNGKeyArray, | ||
| *, | ||
| state_dim: Cfg[int] = 64, | ||
| ssm_blocks: Cfg[int] = 1, |
There was a problem hiding this comment.
What does ssm_blocks refer to?
phnazari
left a comment
There was a problem hiding this comment.
Just reviewed some parts. Will approve, but the comments still have to be resolved.
| # Contributing to Linax | ||
| # Contributing to Discretax | ||
|
|
||
| Thank you for your interest in contributing to Linax! 🎉 |
There was a problem hiding this comment.
We should definitely remove most emojis and un-slop the code.
Pull request overview
Copilot reviewed 85 out of 91 changed files in this pull request and generated no new comments.
Comments suppressed due to low confidence (2)
src/discretax/models/linoss.py:1 - The default value uses `jnp.pi` (3.14159265359) but in the old code it was also `jnp.pi`. However, this differs from the hardcoded value of `3.14159265359` shown in line 59 of the new file. While functionally equivalent, using `jnp.pi` is clearer and more maintainable than a hardcoded float value.
mkdocs.yml:1 - The strict mode has been changed from `true` to `false`. This setting controls whether MkDocs treats warnings as errors during the build process. Disabling strict mode may allow documentation build issues to go unnoticed. Consider keeping strict mode enabled to catch documentation problems early, or document why it needs to be disabled.
francescoshox
left a comment
There was a problem hiding this comment.
Hi, I didn't check models and sequence_mixer again, as I think there are no relevant changes from what I reviewed before. Please tell me if I should also take a look at them.
Just some general comments:
- bias is usually False except in one place, is that correct?
- the printing logic should be connected to the model directly.
Everything else is good
There was a problem hiding this comment.
examples already exists in docs. Having two copies is difficult to maintain.
As far as I understand from @phnazari, we need it in docs, but I am happy to discuss.
| _RESET = "\033[0m" | ||
|
|
||
|
|
||
| def _colorize(text: str, color: str | None) -> str: |
There was a problem hiding this comment.
There is no docstring.
| return f"{_COLORS[color]}{text}{_RESET}" | ||
|
|
||
|
|
||
| def _print_tree( |
There was a problem hiding this comment.
Both type hints and a docstring are missing. Why is ruff not catching this?
| print(f"\nTotal: {total:,}") | ||
|
|
||
|
|
||
| if __name__ == "__main__": |
There was a problem hiding this comment.
Shouldn't we add print_param_tree to the __repr__ or something?
src/discretax/utils/param_count.py
Outdated
| import equinox as eqx | ||
| import jax | ||
|
|
||
| from discretax.models import LinOSS |
There was a problem hiding this comment.
This should be imported inside the
if __name__ == "__main__":
block, otherwise we will have a circular import the moment we put the printing in __repr__.
src/discretax/channel_mixers/base.py
Outdated
| **kwargs, | ||
| ): | ||
| """Initialize the channel mixer.""" | ||
| raise NotImplementedError("Subclasses must implement __init__") |
There was a problem hiding this comment.
I think this will never be triggered.
If someone instantiates a child class without an __init__, Python will just throw an error like "Can't instantiate class with abstract method."
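To make the cases concrete, here is a small sketch using plain `abc`. It assumes the base `__init__` is decorated with `@abc.abstractmethod`, which the diff excerpt does not show, and the real base classes are Equinox modules, so the behavior there may differ.

```python
import abc


class AbstractChannelMixer(abc.ABC):
    @abc.abstractmethod
    def __init__(self, in_features: int):
        raise NotImplementedError("Subclasses must implement __init__")


class NoInit(AbstractChannelMixer):
    """Does not override __init__, so the class stays abstract."""


class CallsSuper(AbstractChannelMixer):
    def __init__(self, in_features: int):
        # Delegating to the base is the one path that reaches the raise.
        super().__init__(in_features)


class OwnInit(AbstractChannelMixer):
    def __init__(self, in_features: int):
        self.in_features = in_features


def outcome(cls) -> str:
    """Instantiate cls and report which error, if any, was raised."""
    try:
        cls(4)
    except (TypeError, NotImplementedError) as err:
        return type(err).__name__
    return "ok"


print(outcome(NoInit))      # TypeError
print(outcome(CallsSuper))  # NotImplementedError
print(outcome(OwnInit))     # ok
```

So under this assumption the `raise` is only reachable when a subclass calls `super().__init__()`, which is also the call the automated review comments say is missing.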
| """Initialize the channel mixer.""" | ||
| raise NotImplementedError("Subclasses must implement __init__") | ||
|
|
||
| # TODO: right now we are not using this lambda. But we should! Also return is_inexact_array. |
There was a problem hiding this comment.
What do we do with this TODO?
| *, | ||
| out_features: int | None = None, | ||
| non_linearity: activation = "gelu", | ||
| use_bias: bool = False, |
There was a problem hiding this comment.
Here the bias is False, while in GLU it is True. Is everything correct?
Maybe we should decide on a convention, like True everywhere.
| out_features: int | None = None, | ||
| hidden_ratio: int | float | None = None, | ||
| intermediate_dim: int | None = None, | ||
| use_bias: bool = False, |
There was a problem hiding this comment.
OK sorry, maybe it makes sense that it is sometimes True and sometimes False. Since I usually don't run the code, I am not too sure about it.
src/discretax/encoder/base.py
Outdated
| **kwargs, | ||
| ): | ||
| """Initialize the encoder.""" | ||
| raise NotImplementedError("Subclasses must implement __init__") |
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 90 out of 94 changed files in this pull request and generated 5 comments.
Comments suppressed due to low confidence (7)
src/discretax/utils/param_count.py:1 - `class_names` uses a mutable default (`{}`), which can leak state between calls. Use `None` as the default and initialize an empty dict inside the function when needed.
src/discretax/utils/from_config.py:1 - `_resolve_target` imports every `discretax.*` submodule to find a class when `target` has no module path. This can be slow and can trigger side effects from imports. Consider requiring fully-qualified targets, or adding a caching layer (e.g., memoize resolved targets), or maintaining an explicit registry of exposed types.
src/discretax/utils/from_config.py:1 - The return annotation says `PartialModule`, but the function can return `Partial`, arbitrary nested `dict`s, or leaf values. This mismatch will confuse type checkers and API consumers. Either adjust the return type (e.g., `Any` / a `Union[...]`) or change behavior so it always returns a `Partial`/`PartialModule` wrapper.
src/discretax/utils/from_config.py:1 - New config-driven construction (`build_from_dict` + `_resolve_target`) is core API surface and currently isn’t covered by tests. Add pytest coverage for: (1) fully-qualified targets, (2) short-name targets, (3) nested dict composition, and (4) failure mode/error message when a class can’t be resolved.
mkdocs.yml:1 - Disabling `strict` will hide broken mkdocstrings targets / missing pages and allow docs to build with silent errors. Once the API doc directives are updated to the new `Abstract*` symbols, re-enable `strict: true` to keep docs correctness enforced in CI.
src/discretax/blocks/s4d.py:1 - A placeholder module in a shipped package can confuse users and API docs tooling (it may be imported/discovered). Consider either implementing the S4D block, removing this module until ready, or adding a proper docstring explaining that it’s intentionally stubbed and not part of the public API.
docs/examples:1 - This appears to introduce a symlink-like entry from `docs/examples` to `../examples`. Symlinks can be brittle across platforms (notably Windows) and some tooling/packagers. If the intent is to reuse examples in docs, consider configuring MkDocs to include the external directory, or copying/duplicating only the necessary example artifacts during docs build.
| --- | ||
|
|
||
| ::: linax.sequence_mixers.base.SequenceMixer | ||
| ::: discretax.sequence_mixers.base.SequenceMixer |
There was a problem hiding this comment.
This mkdocstrings target appears to reference a non-existent symbol after the refactor: the base class is AbstractSequenceMixer (per discretax.sequence_mixers.base). Update the directive to ::: discretax.sequence_mixers.base.AbstractSequenceMixer so docs generation doesn’t fail.
| ::: discretax.sequence_mixers.base.SequenceMixer | |
| ::: discretax.sequence_mixers.base.AbstractSequenceMixer |
| --- | ||
|
|
||
| ::: linax.channel_mixers.base.ChannelMixer | ||
| ::: discretax.channel_mixers.base.ChannelMixer |
There was a problem hiding this comment.
This mkdocstrings target appears to reference a non-existent symbol after the refactor: the base class is AbstractChannelMixer. Update the directive to ::: discretax.channel_mixers.base.AbstractChannelMixer.
| ::: discretax.channel_mixers.base.ChannelMixer | |
| ::: discretax.channel_mixers.base.AbstractChannelMixer |
| --- | ||
|
|
||
| ::: linax.blocks.base.Block | ||
| ::: discretax.blocks.base.Block |
There was a problem hiding this comment.
This mkdocstrings target appears to reference a non-existent symbol after the refactor: the base class is AbstractBlock. Update the directive to ::: discretax.blocks.base.AbstractBlock.
| ::: discretax.blocks.base.Block | |
| ::: discretax.blocks.base.AbstractBlock |
| --- | ||
|
|
||
| ::: linax.encoder.base.Encoder | ||
| ::: discretax.encoder.base.Encoder |
There was a problem hiding this comment.
This mkdocstrings target appears to reference a non-existent symbol after the refactor: the base class is AbstractEncoder. Update the directive to ::: discretax.encoder.base.AbstractEncoder.
| ::: discretax.encoder.base.Encoder | |
| ::: discretax.encoder.base.AbstractEncoder |
| --- | ||
|
|
||
| ::: linax.heads.base.Head | ||
| ::: discretax.heads.base.Head |
There was a problem hiding this comment.
This mkdocstrings target appears to reference a non-existent symbol after the refactor: the base class is AbstractHead. Update the directive to ::: discretax.heads.base.AbstractHead.
| ::: discretax.heads.base.Head | |
| ::: discretax.heads.base.AbstractHead |
Description
Major refactoring. Removed configs to make models simpler to use.
Type of Change
Changes Made
eqx.nn.sequential __init__ can be marked with Cfg, which makes them initializable through Hydra or a config system