Ben/refactor #64

Closed
benedict-armstrong wants to merge 8 commits into main from ben/refactor

@benedict-armstrong benedict-armstrong commented Dec 5, 2025

Description

Major refactoring. Removed configs to make models simpler to use.

Type of Change

  • 🐛 Bug fix
  • ✨ New feature
  • 💥 Breaking change
  • 📚 Documentation
  • Other

Changes Made

  • Removed all configs
  • Models are designed to be composed with eqx.nn.Sequential
  • Arguments to __init__ can be marked with Cfg, which makes them initializable through Hydra or another config system
  • Adapted tests and example notebook

Copilot AI review requested due to automatic review settings December 5, 2025 11:02

Copilot AI left a comment

Pull request overview

This is a major refactoring PR that renames the project from "linax" to "discretax" and removes the configuration-based architecture in favor of direct instantiation with Cfg annotations for config system compatibility.

Key Changes:

  • Project renamed from "linax" to "discretax" across all files
  • Removed all *Config classes and their build() methods
  • Introduced Cfg type annotation for parameters that can be initialized via config systems (e.g., Hydra)
  • Models now compose with eqx.nn.Sequential instead of being monolithic
  • Tests updated to reflect new direct instantiation pattern

Reviewed changes

Copilot reviewed 80 out of 84 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| `uv.lock` | Package name changed from "linax" to "discretax" |
| `pyproject.toml` | Project metadata updated with new name and URLs |
| `README.md` | Documentation updated with new project name and examples |
| `mkdocs.yml` | Documentation configuration updated, strict mode set to false |
| `src/discretax/utils/config_mixin.py` | New config system with Cfg annotation and PartialLoaderMixin |
| `src/discretax/models/*.py` | Models refactored to direct instantiation, compose with Sequential |
| `src/discretax/**/base.py` | Base classes updated, removed config classes, added PartialLoaderMixin |
| `src/discretax/**/*.py` | All components updated to new pattern without configs |
| `tests/*.py` | Tests updated to use direct instantiation instead of configs |
| `docs/` | All API documentation updated to reflect new structure |
| `docs/examples/` | Example notebook updated to show new usage patterns |

Comments suppressed due to low confidence (10)

src/discretax/heads/classification.py:14


from discretax.utils.config_mixin import Cfg


class EmbeddingEncoder(Encoder):
Copilot AI Dec 5, 2025

This class does not call Encoder.__init__ during initialization. (EmbeddingEncoder.__init__ may be missing a call to a base class __init__)

Collaborator

how about this?

from discretax.sequence_mixers.base import SequenceMixer


class IdentitySequenceMixer(SequenceMixer):
Copilot AI Dec 5, 2025

This class does not call SequenceMixer.__init__ during initialization. (IdentitySequenceMixer.__init__ may be missing a call to a base class __init__)

from discretax.channel_mixers.base import ChannelMixer


class IdentityChannelMixer(ChannelMixer):
Copilot AI Dec 5, 2025

This class does not call ChannelMixer.__init__ during initialization. (IdentityChannelMixer.__init__ may be missing a call to a base class __init__)

Collaborator

@benedict-armstrong does this comment make sense?

from discretax.utils.config_mixin import Cfg


class LinearEncoder(Encoder):
Copilot AI Dec 5, 2025

This class does not call Encoder.__init__ during initialization. (LinearEncoder.__init__ may be missing a call to a base class __init__)

Collaborator

@francescoshox francescoshox left a comment

Just checked the docs for now, sorry.

assets/logo.png Outdated
Collaborator

@phnazari we would need to change the logo.

Collaborator

Sure! I will do that once I am back in Tüb



def register(name: str):
    """Register a class in the registry."""
Collaborator

I know this is a nitpick, but we should say what the args are.



def build_from_config(cfg: Any) -> Any:
    """Recursively transforms a config dict into a generic Class Factory (partial)."""
Collaborator

Same for args here. Given that this is a library, I think it is always good practice to say:

Args:
    cfg: general something config

"""Mixin to load a class from a config."""

@classmethod
def from_config(cls, cfg: dict | Any) -> functools.partial:
Collaborator

Is the output type correct?
Usually partial returns a specific object, partially instantiated.

Collaborator

The output type can be a callable returning an instance of a class.
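
To make this exchange concrete: a functools.partial object is itself a callable, so annotating the return value as a callable that yields an instance is also accurate. The sketch below uses a hypothetical `Layer` class and `make_factory` helper, not discretax types.

```python
import functools
from collections.abc import Callable
from dataclasses import dataclass


@dataclass
class Layer:
    """Hypothetical stand-in for a model component."""

    hidden_dim: int
    key: int  # stands in for a JAX PRNG key


def make_factory(hidden_dim: int) -> Callable[..., Layer]:
    # Bind the config-provided argument now; `key` is supplied later.
    return functools.partial(Layer, hidden_dim=hidden_dim)


factory = make_factory(hidden_dim=64)
layer = factory(key=0)  # calling the partial produces the actual instance
```

Annotating with Callable[..., Layer] tells readers and type checkers what calling the factory produces, which functools.partial alone does not convey.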

"""Mixin to load a class from a config."""

@classmethod
def from_config(cls, cfg: dict | Any) -> functools.partial:
Collaborator

I don't like dict | Any too much. Any already includes dict.



class PartialLoaderMixin:
    """Mixin to load a class from a config."""
Collaborator

Maybe we can add a small example here of what this does. To me it is OK, but I am not sure it is easy for everyone to understand. I would guess Claude can write an example in one second.
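
The kind of example asked for here might look like the following. This is a sketch under assumptions, not the PR's implementation: it assumes from_config simply binds the entries of a plain dict as keyword arguments and leaves runtime-only arguments such as the key to the caller. `TinyEncoder` is hypothetical.

```python
import functools
from typing import Any


class PartialLoaderMixin:
    """Mixin to load a class from a config (simplified sketch)."""

    @classmethod
    def from_config(cls, cfg: dict[str, Any]) -> functools.partial:
        # Assumption: cfg maps __init__ keyword names to values.
        return functools.partial(cls, **cfg)


class TinyEncoder(PartialLoaderMixin):
    """Hypothetical component; `key` stands in for a JAX PRNG key."""

    def __init__(self, key: int, *, hidden_dim: int = 64, use_bias: bool = False):
        self.key = key
        self.hidden_dim = hidden_dim
        self.use_bias = use_bias


# The config supplies hyperparameters; the caller supplies the key later.
make_encoder = TinyEncoder.from_config({"hidden_dim": 128})
encoder = make_encoder(key=0)
```

The split between config-time and call-time arguments is the point of returning a partial rather than a finished instance.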

Collaborator

@francescoshox francescoshox left a comment

Generally OK, once we address the comments.

Where/how do you change the name of the codebase? I hope this is also doable. It would be bad to have discretax code inside a codebase called linax.

Tests have to be developed further; as they are, these are not good.



- class StandardBlock[ConfigType: StandardBlockConfig](Block):
+ class StandardBlock(Block):
Collaborator

Do we have anything different from a StandardBlock? Should we add anything different?
Maybe S4D does not have a standard block
@phnazari @benedict-armstrong

Collaborator

It does indeed not have the standard block. You can find the block here: https://github.com/state-spaces/s4/blob/main/models/s4/s4d.py

key: PRNGKeyArray,
*,
- out_features: int | None = None,
+ out_features: Cfg[int | None] = None,
Collaborator

Is this correct? Or is it Cfg[int] | None?


Args:
key: JAX random key for initialization.
out_features: output dimensionality (embedding dimension).
Collaborator

In LinearEncoder we call this hidden_dim, which I would also prefer here.

cfg: ConfigType,
key: PRNGKeyArray,
*,
reduce: Cfg[bool] = True,
Collaborator

Do we need reduce to be a Cfg type so that the mixin can initialize correctly, or for any other reason?

from discretax.models.lru import LRU
from discretax.models.s5 import S5

__all__ = [
Collaborator

Hi, S4D is missing here; why is that?
What do we need to introduce S4D?
@phnazari @benedict-armstrong

state_dim: Cfg[int] = 64,
r_min: Cfg[float] = 0.0,
r_max: Cfg[float] = 1.0,
max_phase: Cfg[float] = 6.28,
Collaborator

Do we need to hardcode 2 pi?
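
For reference, the literal 6.28 is a rounded value of 2π. Below is a sketch of deriving the default from math.pi instead (jnp.pi would work equally well in the JAX code). The helper `init_phase_bounds` is hypothetical and only illustrates the default value.

```python
import math

TWO_PI = 2 * math.pi  # 6.283185307179586


def init_phase_bounds(max_phase: float = TWO_PI) -> tuple[float, float]:
    """Hypothetical helper returning the (min, max) phase range."""
    return 0.0, max_phase


# Difference between the exact constant and the hardcoded literal.
rounding_error = TWO_PI - 6.28  # roughly 3.2e-3
```

Using the named constant keeps the full float precision and makes the intent (a full turn) explicit.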

from discretax.utils.config_mixin import Cfg, PartialLoaderMixin


class S5(eqx.nn.StatefulLayer, PartialLoaderMixin):
Collaborator

If S4D is a subclass of S5, we should probably also define an S4D child class.
Moreover, S4D probably has a different block or something.

@@ -77,56 +42,68 @@ class LinOSSSequenceMixer[ConfigType: LinOSSSequenceMixerConfig](eqx.Module):
C: jax.Array
D: jax.Array
steps: jax.Array
Collaborator

Is steps learnable?

Collaborator

Thanks for cleaning

key: PRNGKeyArray,
*,
state_dim: Cfg[int] = 64,
ssm_blocks: Cfg[int] = 1,
Collaborator

What does ssm_blocks refer to?

Collaborator

@phnazari phnazari left a comment

Just reviewed some parts. Will approve, but the comments still have to be resolved.

# Contributing to Linax
# Contributing to Discretax

Thank you for your interest in contributing to Linax! 🎉
Collaborator

We should definitely remove most emojis and un-slop the code.



Copilot AI review requested due to automatic review settings December 10, 2025 21:25
Copilot AI left a comment

Pull request overview

Copilot reviewed 85 out of 91 changed files in this pull request and generated no new comments.

Comments suppressed due to low confidence (2)

src/discretax/models/linoss.py:1

  • The default value uses jnp.pi (3.14159265359) but in the old code it was also jnp.pi. However, this differs from the hardcoded value of 3.14159265359 shown in line 59 of the new file. While functionally equivalent, using jnp.pi is clearer and more maintainable than a hardcoded float value.

mkdocs.yml:1

  • The strict mode has been changed from true to false. This setting controls whether MkDocs treats warnings as errors during the build process. Disabling strict mode may allow documentation build issues to go unnoticed. Consider keeping strict mode enabled to catch documentation problems early, or document why it needs to be disabled.

Collaborator

@francescoshox francescoshox left a comment

Hi, I didn't check models and sequence_mixer again, as I think there are no relevant changes since what I reviewed before. Please tell me if I should also take a look at them.

Just some general comments:

  • bias is usually False, except in one place; is that correct?
  • the printing logic should be connected to the model directly.

Everything else is good.

Collaborator

examples is already in docs. Having 2 copies is difficult to maintain.
As far as I understand from @phnazari we need it in docs but happy to discuss.

_RESET = "\033[0m"


def _colorize(text: str, color: str | None) -> str:
Collaborator

There is no docstring.

return f"{_COLORS[color]}{text}{_RESET}"


def _print_tree(
Collaborator

Both type annotations and a docstring are missing; why is ruff not catching it?

print(f"\nTotal: {total:,}")


if __name__ == "__main__":
Collaborator

Shouldn't we add print_param_tree to the __repr__ or something?

import equinox as eqx
import jax

from discretax.models import LinOSS
Collaborator

This should be imported inside the

if __name__ == "__main__":

block, otherwise we will have a circular import the moment we put the printing in __repr__.

**kwargs,
):
"""Initialize the channel mixer."""
raise NotImplementedError("Subclasses must implement __init__")
Collaborator

I think this will never be triggered.
If someone instantiates a child class without an __init__, Python will just throw an error like "Can't instantiate class with abstract method."
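
The behaviour described here can be checked with the standard abc module. This sketch assumes the base class marks __init__ with @abstractmethod, which the excerpt does not show; the classes are hypothetical and plain Python rather than eqx.Module.

```python
from abc import ABC, abstractmethod


class AbstractMixer(ABC):
    @abstractmethod
    def __init__(self, **kwargs):
        """Initialize the mixer."""
        raise NotImplementedError("Subclasses must implement __init__")


class NoInitMixer(AbstractMixer):
    """Child that forgets to define __init__."""


try:
    NoInitMixer()
    error = None
except TypeError as exc:
    # Raised at instantiation time, before the base __init__ body runs.
    error = exc
```

Under this assumption the NotImplementedError line is unreachable for such subclasses, because instantiation fails with a TypeError first.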

"""Initialize the channel mixer."""
raise NotImplementedError("Subclasses must implement __init__")

# TODO: right now we are not using this lambda. But we should! Also return is_inexact_array.
Collaborator

What do we do with this TODO?

*,
out_features: int | None = None,
non_linearity: activation = "gelu",
use_bias: bool = False,
Collaborator

Here the bias is False, while in GLU it is True; is everything correct?
Maybe we should decide on a convention, like True everywhere.

out_features: int | None = None,
hidden_ratio: int | float | None = None,
intermediate_dim: int | None = None,
use_bias: bool = False,
Collaborator

OK, sorry, maybe it makes sense that it is sometimes True and sometimes False. Since I usually don't run the code, I am not too sure about it.

**kwargs,
):
"""Initialize the encoder."""
raise NotImplementedError("Subclasses must implement __init__")
Collaborator

Same as before

Copilot AI review requested due to automatic review settings February 18, 2026 09:12
@phnazari phnazari closed this Feb 18, 2026
Copilot AI left a comment

Pull request overview

Copilot reviewed 90 out of 94 changed files in this pull request and generated 5 comments.

Comments suppressed due to low confidence (7)

src/discretax/utils/param_count.py:1

  • class_names uses a mutable default ({}), which can leak state between calls. Use None as the default and initialize an empty dict inside the function when needed.

src/discretax/utils/from_config.py:1

  • _resolve_target imports every discretax.* submodule to find a class when target has no module path. This can be slow and can trigger side effects from imports. Consider requiring fully-qualified targets, or adding a caching layer (e.g., memoize resolved targets), or maintaining an explicit registry of exposed types.

src/discretax/utils/from_config.py:1

  • The return annotation says PartialModule, but the function can return Partial, arbitrary nested dicts, or leaf values. This mismatch will confuse type checkers and API consumers. Either adjust the return type (e.g., Any / a Union[...]) or change behavior so it always returns a Partial/PartialModule wrapper.

src/discretax/utils/from_config.py:1

  • New config-driven construction (build_from_dict + _resolve_target) is core API surface and currently isn’t covered by tests. Add pytest coverage for: (1) fully-qualified targets, (2) short-name targets, (3) nested dict composition, and (4) failure mode/error message when a class can’t be resolved.

mkdocs.yml:1

  • Disabling strict will hide broken mkdocstrings targets / missing pages and allow docs to build with silent errors. Once the API doc directives are updated to the new Abstract* symbols, re-enable strict: true to keep docs correctness enforced in CI.

src/discretax/blocks/s4d.py:1

  • A placeholder module in a shipped package can confuse users and API docs tooling (it may be imported/discovered). Consider either implementing the S4D block, removing this module until ready, or adding a proper docstring explaining that it’s intentionally stubbed and not part of the public API.

docs/examples:1

  • This appears to introduce a symlink-like entry from docs/examples to ../examples. Symlinks can be brittle across platforms (notably Windows) and some tooling/packagers. If the intent is to reuse examples in docs, consider configuring MkDocs to include the external directory, or copying/duplicating only the necessary example artifacts during docs build.

---

- ::: linax.sequence_mixers.base.SequenceMixer
+ ::: discretax.sequence_mixers.base.SequenceMixer
Copilot AI Feb 18, 2026

This mkdocstrings target appears to reference a non-existent symbol after the refactor: the base class is AbstractSequenceMixer (per discretax.sequence_mixers.base). Update the directive to ::: discretax.sequence_mixers.base.AbstractSequenceMixer so docs generation doesn’t fail.

Suggested change
- ::: discretax.sequence_mixers.base.SequenceMixer
+ ::: discretax.sequence_mixers.base.AbstractSequenceMixer

---

- ::: linax.channel_mixers.base.ChannelMixer
+ ::: discretax.channel_mixers.base.ChannelMixer
Copilot AI Feb 18, 2026

This mkdocstrings target appears to reference a non-existent symbol after the refactor: the base class is AbstractChannelMixer. Update the directive to ::: discretax.channel_mixers.base.AbstractChannelMixer.

Suggested change
- ::: discretax.channel_mixers.base.ChannelMixer
+ ::: discretax.channel_mixers.base.AbstractChannelMixer

---

- ::: linax.blocks.base.Block
+ ::: discretax.blocks.base.Block
Copilot AI Feb 18, 2026

This mkdocstrings target appears to reference a non-existent symbol after the refactor: the base class is AbstractBlock. Update the directive to ::: discretax.blocks.base.AbstractBlock.

Suggested change
- ::: discretax.blocks.base.Block
+ ::: discretax.blocks.base.AbstractBlock

---

- ::: linax.encoder.base.Encoder
+ ::: discretax.encoder.base.Encoder
Copilot AI Feb 18, 2026

This mkdocstrings target appears to reference a non-existent symbol after the refactor: the base class is AbstractEncoder. Update the directive to ::: discretax.encoder.base.AbstractEncoder.

Suggested change
- ::: discretax.encoder.base.Encoder
+ ::: discretax.encoder.base.AbstractEncoder

---

- ::: linax.heads.base.Head
+ ::: discretax.heads.base.Head
Copilot AI Feb 18, 2026

This mkdocstrings target appears to reference a non-existent symbol after the refactor: the base class is AbstractHead. Update the directive to ::: discretax.heads.base.AbstractHead.

Suggested change
- ::: discretax.heads.base.Head
+ ::: discretax.heads.base.AbstractHead
