
Add LinOSS model (ICLR 2025 oral) #749

Open
Phoenix8215 wants to merge 9 commits into fla-org:main from Phoenix8215:linOSS

Conversation

@Phoenix8215 commented Feb 17, 2026

Add support for the LinOSS (Linear Oscillatory State-Space) model (ICLR 2025 Oral), a second-order state space model that discretizes a damped oscillator ODE using implicit midpoint (IM) or implicit-explicit (IMEX) methods.
Reference implementation: https://github.com/tk-rusch/linoss
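For readers skimming the PR, the core recurrence is small enough to sketch in plain Python. The snippet below is a simplified, real-valued, single-state-dimension version of the IM transition used by the PR's PyTorch fallback (the name `linoss_im_step` is ours, not the library's); the actual op is batched, vectorized over P state dimensions, and carries complex B/C projections plus the D skip path:

```python
import math

def linoss_im_step(h1, h2, bu, a, step):
    # One IM step of the second-order recurrence. The 2x2 transition terms
    # below mirror the M11..M22 expressions in the PR's PyTorch fallback.
    schur = 1.0 / (1.0 + step * step * a)
    m11 = 1.0 - step * step * a * schur
    m12 = -step * a * schur
    m21 = step * schur
    m22 = schur
    # The input forcing enters both state components, scaled by step
    # (and by M11/M21 in the IM discretization).
    f1 = m11 * bu * step
    f2 = m21 * bu * step
    return (m11 * h1 + m12 * h2 + f1,
            m21 * h1 + m22 * h2 + f2)

# Toy rollout with a = relu(a_diag) = 0.5 and step = sigmoid(0) = 0.5,
# matching the kernel's pre-relu / pre-sigmoid parametrization.
a, step = 0.5, 1.0 / (1.0 + math.exp(0.0))
h1, h2 = 0.0, 0.0
for u in (1.0, 0.0, 0.0):
    h1, h2 = linoss_im_step(h1, h2, u, a, step)
# An input impulse decays instead of accumulating:
# h1 == 160/729 and h2 == 368/729 up to float rounding.
```

The damping comes from the `schur` factor 1/(1 + step²·a) of the implicit update, which keeps the state transition contractive for a > 0.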

Summary by CodeRabbit

  • New Features

    • Added the LinOSS model family (config, base model, and causal LM) with configurable attention, improved caching/state handling, and optional short-convolution path.
    • Added high-performance fused LinOSS recurrence and a chunked/naive fallback for efficient sequence processing; exposed these ops in the public API.
  • Tests

    • Added comprehensive tests covering forward/backward correctness, chunked and recurrent paths, caching/state behavior, discretizations, and multiple dtypes.


coderabbitai bot commented Feb 17, 2026

Walkthrough

Adds LinOSS: new LinOSSAttention layer, LinOSS model family (LinOSSConfig, LinOSSModel, LinOSSForCausalLM), Triton-backed fused recurrent kernels with autograd plus a pure-PyTorch naive reference, chunked ops, exports/registrations across public APIs, and tests validating fused vs naive implementations.

Changes

  • Public API & Registration (fla/__init__.py, fla/layers/__init__.py, fla/models/__init__.py, fla/models/linoss/__init__.py, fla/ops/__init__.py, fla/ops/linoss/__init__.py): Expose and register LinOSS symbols (LinOSSAttention, LinOSSConfig, LinOSSModel, LinOSSForCausalLM) and ops (chunk_linoss, fused_recurrent_linoss); update __all__ and register with HF AutoConfig/AutoModel/AutoModelForCausalLM.
  • Attention Implementation (fla/layers/linoss.py): New LinOSSAttention class: SSM params (A_diag, B, C, D, dt), modes (fused_recurrent/chunk), optional ShortConvolution input path, caching (past_key_values), gating/RMSNorm and output projection, parameter init and state sizing.
  • Model Config & Architecture (fla/models/linoss/configuration_linoss.py, fla/models/linoss/modeling_linoss.py): New LinOSSConfig with many hyperparameters/validations; LinOSSBlock, LinOSSPreTrainedModel, LinOSSModel, LinOSSForCausalLM including generation, fused loss options, caching, checkpoint support, and embedding/decoder accessors.
  • Fused Ops, Triton + Autograd (fla/ops/linoss/fused_recurrent.py, fla/ops/__init__.py, fla/ops/linoss/__init__.py): Adds a Triton kernel wrapper and high-level fused API, FusedRecurrentLinOSSFunction with forward/backward, a PyTorch fallback, and a non-grad entrypoint; expands ops exports.
  • Chunked Ops (fla/ops/linoss/chunk.py): Adds chunk_linoss: chunked/parallel LinOSS forward with IM/IMEX discretizations, padding/reshaping, per-chunk state propagation, and optional final-state return.
  • Naive Reference (fla/ops/linoss/naive.py): Adds naive_recurrent_linoss: a pure-PyTorch reference recurrence (IM/IMEX) supporting initial/final states for validation and gradients.
  • Tests (tests/ops/test_linoss.py): Adds tests comparing fused vs naive implementations for forward/backward, initial/final state handling, chunked variants, discretizations, and dtypes.

Sequence Diagram

sequenceDiagram
    participant User as User
    participant Layer as LinOSSAttention
    participant Ops as fused_recurrent_linoss
    participant Cache as PastKeyValues/Cache
    participant Output as Model Output

    User->>Layer: forward(hidden_states, attention_mask, past_key_values?, use_cache)
    activate Layer

    Layer->>Cache: load layer state (layer_idx)
    Cache-->>Layer: recurrent_state, conv_state, offset

    Layer->>Layer: compute projection (i_proj or ShortConvolution)
    Layer->>Layer: apply mask, gating, RMSNorm

    Layer->>Ops: fused_recurrent_linoss(i, A_diag, B, C, D, dt, discretization)
    activate Ops
    Ops->>Ops: run Triton kernel or PyTorch fallback (forward/backward)
    Ops-->>Layer: outputs, final_recurrent_state
    deactivate Ops

    Layer->>Cache: update recurrent_state, conv_state, offset
    Layer->>Layer: apply output projection
    Layer-->>Output: return (o, None, past_key_values)
    deactivate Layer

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Suggested reviewers

  • yzhangcs

Poem

🐇 I hop through states both real and deep,
Triton drums while recurrences keep,
Naive checks whisper gradients bright,
LinOSS gleams in day and night,
Rabbity cheers — kernels take flight! 🎉

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 30.77%, below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.

✅ Passed checks (2 passed)

  • Description Check ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title 'Add LinOSS model (ICLR 2025 oral)' accurately describes the main changeset, which adds comprehensive LinOSS model support including architecture, operations, and configurations.


@gemini-code-assist

Summary of Changes

Hello @Phoenix8215, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates the LinOSS (Linear Oscillatory State-Space) model, a second-order state space model, into the project. This addition expands the library's capabilities by providing a new, efficient attention mechanism and a complete model implementation suitable for causal language modeling, leveraging optimized recurrent operations for performance.

Highlights

  • LinOSS Model Integration: Introduced the Linear Oscillatory State-Space (LinOSS) model, an ICLR 2025 oral paper, into the framework. This includes its core attention mechanism, configuration, and full modeling capabilities for causal language modeling.
  • Triton-based Fused Recurrent Operation: Implemented a highly optimized fused recurrent operation for the LinOSS model using Triton, enhancing computational efficiency for the state-space recurrence.
  • Comprehensive Model Structure: Added dedicated files for LinOSS configuration (LinOSSConfig), attention layer (LinOSSAttention), and the full model (LinOSSModel, LinOSSForCausalLM), ensuring proper integration and extensibility within the existing architecture.
  • Reference Implementation and Testing: Included a naive PyTorch reference implementation for the LinOSS recurrence and comprehensive unit tests to validate the correctness of the fused Triton kernel against the naive version, covering both forward and backward passes, and initial state handling.


Changelog
  • fla/__init__.py
    • Imported and exposed LinOSSAttention, LinOSSForCausalLM, and LinOSSModel to the top-level package.
  • fla/layers/__init__.py
    • Imported and exposed LinOSSAttention within the layers module.
  • fla/layers/linoss.py
    • Added the LinOSSAttention module, implementing the core attention logic for the LinOSS model, including parameter initialization and the forward pass using fused recurrent operations.
  • fla/models/__init__.py
    • Imported and exposed LinOSSConfig, LinOSSForCausalLM, and LinOSSModel within the models module.
  • fla/models/linoss/__init__.py
    • Added initialization file for the LinOSS model, registering LinOSSConfig, LinOSSModel, and LinOSSForCausalLM with Hugging Face's AutoConfig, AutoModel, and AutoModelForCausalLM.
  • fla/models/linoss/configuration_linoss.py
    • Added LinOSSConfig class, defining the model's hyperparameters and configuration options.
  • fla/models/linoss/modeling_linoss.py
    • Added LinOSSBlock, LinOSSPreTrainedModel, LinOSSModel, and LinOSSForCausalLM classes, providing the full model architecture and its components for sequence processing and causal language modeling.
  • fla/ops/__init__.py
    • Imported and exposed fused_recurrent_linoss within the operations module.
  • fla/ops/linoss/__init__.py
    • Added initialization file for LinOSS operations, exposing fused_recurrent_linoss.
  • fla/ops/linoss/fused_recurrent.py
    • Implemented the fused_recurrent_linoss function using Triton for optimized recurrent computation, including a forward kernel and an autograd function for backward pass.
  • fla/ops/linoss/naive.py
    • Added naive_recurrent_linoss, a pure PyTorch reference implementation of the LinOSS recurrence for validation purposes.
  • tests/ops/test_linoss.py
    • Added unit tests for the fused_recurrent_linoss operation, comparing its output and gradients against the naive_recurrent_linoss implementation for various configurations and initial states.
Activity
  • No human activity has been recorded on this pull request yet.

@gemini-code-assist bot left a comment

Code Review

This pull request introduces the LinOSS model, a second-order state space model. The implementation includes the LinOSSAttention layer, LinOSSModel, a custom Triton-based fused recurrent op, and corresponding tests. The changes are well-structured and follow the project's conventions.

My review has identified a few issues:

  • A critical issue in the fused_recurrent_linoss op where the logic for training and inference paths is inverted, preventing the use of the optimized Triton kernel during training.
  • A high-severity bug in LinOSSAttention where the cache update offset is incorrect.
  • A minor issue regarding an unused import.

I have provided suggestions to fix these issues. Once addressed, this will be a great addition to the library.

Comment on lines +322 to +430
def fused_recurrent_linoss(
    x: torch.Tensor,
    B_re: torch.Tensor,
    B_im: torch.Tensor,
    C_re: torch.Tensor,
    C_im: torch.Tensor,
    a_diag: torch.Tensor,
    dt: torch.Tensor,
    d_skip: torch.Tensor,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    discretization: str = 'IM',
) -> tuple[torch.Tensor, torch.Tensor | None]:
    r"""
    Fused recurrent implementation of LinOSS (Linear Oscillatory State-Space).

    LinOSS models a second-order ODE system discretized with either IM or IMEX methods.
    Each state dimension has a 2-component state [position, velocity].

    Args:
        x (torch.Tensor):
            Input sequence of shape `[B, T, H]`.
        B_re (torch.Tensor):
            Real part of input matrix B, shape `[P, H]`.
        B_im (torch.Tensor):
            Imaginary part of input matrix B, shape `[P, H]`.
        C_re (torch.Tensor):
            Real part of output matrix C, shape `[H, P]`.
        C_im (torch.Tensor):
            Imaginary part of output matrix C, shape `[H, P]`.
        a_diag (torch.Tensor):
            Diagonal state matrix (pre-relu), shape `[P]`.
        dt (torch.Tensor):
            Discretization step sizes (pre-sigmoid), shape `[P]`.
        d_skip (torch.Tensor):
            Skip connection weights, shape `[H]`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[B, 2, P]`. Default: `None`.
        output_final_state (bool):
            Whether to output the final state. Default: `False`.
        discretization (str):
            Discretization method, either 'IM' or 'IMEX'. Default: `'IM'`.

    Returns:
        o (torch.Tensor):
            Output of shape `[B, T, H]`.
        final_state (Optional[torch.Tensor]):
            Final state of shape `[B, 2, P]` if `output_final_state=True`.
    """
    if x.requires_grad:
        o = _linoss_recurrent_torch(
            x, B_re, B_im, C_re, C_im, a_diag, dt, d_skip, initial_state, discretization
        )
        final_state = None
        if output_final_state:
            B_t, T_t, H_t = x.shape
            P_t = a_diag.shape[0]
            a = torch.relu(a_diag)
            step = torch.sigmoid(dt)

            if discretization == 'IMEX':
                M11 = torch.ones_like(a)
                M12 = -step * a
                M21 = step.clone()
                M22 = 1.0 - step * step * a
            else:
                schur = 1.0 / (1.0 + step * step * a)
                M11 = 1.0 - step * step * a * schur
                M12 = -step * a * schur
                M21 = step * schur
                M22 = schur

            Bu_re = torch.einsum('bth,ph->btp', x, B_re)
            Bu_im = torch.einsum('bth,ph->btp', x, B_im)

            with torch.no_grad():
                h1_re = x.new_zeros(B_t, P_t)
                h1_im = x.new_zeros(B_t, P_t)
                h2_re = x.new_zeros(B_t, P_t)
                h2_im = x.new_zeros(B_t, P_t)
                if initial_state is not None:
                    h1_re = initial_state[:, 0].to(x.dtype)
                    h2_re = initial_state[:, 1].to(x.dtype)
                for t in range(T_t):
                    bu_re_t = Bu_re[:, t]
                    bu_im_t = Bu_im[:, t]
                    if discretization == 'IMEX':
                        f1_re = bu_re_t * step.unsqueeze(0)
                        f1_im = bu_im_t * step.unsqueeze(0)
                        f2_re = bu_re_t * (step * step).unsqueeze(0)
                        f2_im = bu_im_t * (step * step).unsqueeze(0)
                    else:
                        f1_re = M11.unsqueeze(0) * bu_re_t * step.unsqueeze(0)
                        f1_im = M11.unsqueeze(0) * bu_im_t * step.unsqueeze(0)
                        f2_re = M21.unsqueeze(0) * bu_re_t * step.unsqueeze(0)
                        f2_im = M21.unsqueeze(0) * bu_im_t * step.unsqueeze(0)
                    h1_re_n = M11.unsqueeze(0) * h1_re + M12.unsqueeze(0) * h2_re + f1_re
                    h1_im_n = M11.unsqueeze(0) * h1_im + M12.unsqueeze(0) * h2_im + f1_im
                    h2_re_n = M21.unsqueeze(0) * h1_re + M22.unsqueeze(0) * h2_re + f2_re
                    h2_im_n = M21.unsqueeze(0) * h1_im + M22.unsqueeze(0) * h2_im + f2_im
                    h1_re, h1_im = h1_re_n, h1_im_n
                    h2_re, h2_im = h2_re_n, h2_im_n
                final_state = torch.stack([h1_re, h2_re], dim=1)
        return o, final_state
    else:
        return FusedRecurrentLinOSSFunction.apply(
            x, B_re, B_im, C_re, C_im, a_diag, dt, d_skip,
            initial_state, output_final_state, discretization,
        )

critical

The logic in fused_recurrent_linoss appears to be inverted. When x.requires_grad is true (i.e., during training), it should use FusedRecurrentLinOSSFunction.apply to leverage the custom Triton kernel for the forward pass and the custom implementation for the backward pass. Instead, it falls back to a pure PyTorch implementation (_linoss_recurrent_torch), which will be slow and bypasses the custom backward. Conversely, during inference (x.requires_grad is false), it unnecessarily calls FusedRecurrentLinOSSFunction.apply, which has autograd overhead. It should call the forward kernel wrapper fused_recurrent_linoss_fwd directly for optimal performance.

The suggested change simplifies the function to always use the autograd.Function, which is the standard and correct pattern.

def fused_recurrent_linoss(
    x: torch.Tensor,
    B_re: torch.Tensor,
    B_im: torch.Tensor,
    C_re: torch.Tensor,
    C_im: torch.Tensor,
    a_diag: torch.Tensor,
    dt: torch.Tensor,
    d_skip: torch.Tensor,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    discretization: str = 'IM',
) -> tuple[torch.Tensor, torch.Tensor | None]:
    r"""
    Fused recurrent implementation of LinOSS (Linear Oscillatory State-Space).

    LinOSS models a second-order ODE system discretized with either IM or IMEX methods.
    Each state dimension has a 2-component state [position, velocity].

    Args:
        x (torch.Tensor):
            Input sequence of shape `[B, T, H]`.
        B_re (torch.Tensor):
            Real part of input matrix B, shape `[P, H]`.
        B_im (torch.Tensor):
            Imaginary part of input matrix B, shape `[P, H]`.
        C_re (torch.Tensor):
            Real part of output matrix C, shape `[H, P]`.
        C_im (torch.Tensor):
            Imaginary part of output matrix C, shape `[H, P]`.
        a_diag (torch.Tensor):
            Diagonal state matrix (pre-relu), shape `[P]`.
        dt (torch.Tensor):
            Discretization step sizes (pre-sigmoid), shape `[P]`.
        d_skip (torch.Tensor):
            Skip connection weights, shape `[H]`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[B, 2, P]`. Default: `None`.
        output_final_state (bool):
            Whether to output the final state. Default: `False`.
        discretization (str):
            Discretization method, either 'IM' or 'IMEX'. Default: `'IM'`.

    Returns:
        o (torch.Tensor):
            Output of shape `[B, T, H]`.
        final_state (Optional[torch.Tensor]):
            Final state of shape `[B, 2, P]` if `output_final_state=True`.
    """
    return FusedRecurrentLinOSSFunction.apply(
        x, B_re, B_im, C_re, C_im, a_diag, dt, d_skip,
        initial_state, output_final_state, discretization,
    )

recurrent_state=recurrent_state,
conv_state=conv_state if self.use_short_conv else None,
layer_idx=self.layer_idx,
offset=i.shape[2] if len(i.shape) > 2 else 1,

high

The offset for updating the cache is calculated as i.shape[2], which corresponds to the hidden dimension (H). The offset should represent the number of new tokens, which is the sequence length. This should be i.shape[1].

Suggested change
offset=i.shape[2] if len(i.shape) > 2 else 1,
offset=i.shape[1],

import torch.nn.functional as F

from fla.modules import FusedRMSNormGated, ShortConvolution
from fla.modules.activations import swiglu

medium

The import swiglu from fla.modules.activations is unused in this file and can be removed to keep the code clean.

@coderabbitai bot left a comment


Actionable comments posted: 9

🧹 Nitpick comments (5)
tests/ops/test_linoss.py (2)

54-54: Consider extracting os.environ['TRITON_F32_DEFAULT'] = 'ieee' into a shared fixture or conftest.py.

This is repeated in all three test functions. A session- or module-scoped fixture would be cleaner and avoid global env mutation in each test body.

Also applies to: 91-91, 146-146

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/test_linoss.py` at line 54, Extract the repeated environment
mutation os.environ['TRITON_F32_DEFAULT'] = 'ieee' into a pytest fixture (e.g.,
def triton_f32_default():) with scope="session" or "module" and set autouse=True
(or explicitly use the fixture in the tests) so tests no longer mutate globals;
in the fixture set the env var, yield, and restore the original value afterward
to avoid leaking state; update test functions to remove the in-test assignment
and rely on the fixture (fixture name: triton_f32_default).
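As a sketch, the fixture described above could live in tests/ops/conftest.py (assuming pytest; the fixture name and module scope follow the suggestion):

```python
import os

import pytest

@pytest.fixture(scope="module", autouse=True)
def triton_f32_default():
    # Set TRITON_F32_DEFAULT once for the whole test module, then restore
    # the previous value so the env mutation does not leak into other suites.
    old = os.environ.get("TRITON_F32_DEFAULT")
    os.environ["TRITON_F32_DEFAULT"] = "ieee"
    yield
    if old is None:
        os.environ.pop("TRITON_F32_DEFAULT", None)
    else:
        os.environ["TRITON_F32_DEFAULT"] = old
```

With autouse=True, the three test functions can drop their in-body os.environ assignments entirely.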

82-124: Backward test validates _linoss_recurrent_torch, not the Triton kernel backward.

Since requires_grad_(True) is set, fused_recurrent_linoss takes the _linoss_recurrent_torch path (not FusedRecurrentLinOSSFunction.apply). This is fine since FusedRecurrentLinOSSFunction.backward itself delegates to _linoss_recurrent_torch, but it's worth noting that the Triton forward + torch backward round-trip isn't directly tested here.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/test_linoss.py` around lines 82 - 124, The test currently calls
fused_recurrent_linoss(...) which, when inputs have requires_grad=True, routes
to the Python `_linoss_recurrent_torch` path; to exercise the Triton custom
backward use the custom autograd Function directly. In test_fused_recurrent_bwd
replace the fused_recurrent_linoss(...) call with an explicit call to
FusedRecurrentLinOSSFunction.apply(...) (passing the same inputs and
discretization) so the Triton forward/backward implementation in
FusedRecurrentLinOSSFunction is exercised.
fla/models/linoss/modeling_linoss.py (1)

270-283: Improve exception handling in generate.

Two issues flagged by Ruff:

  • Line 275: The re-raised AttributeError should chain the original exception with from exception (B904).
  • Line 283: raise exception should be bare raise to preserve the original traceback (TRY201).
Proposed fix
             if 'past_key_values' in str(exception):
                 raise AttributeError(
                     f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                     f"which is not supported for {self.__class__.__name__}. "
                     f"Try another generation strategy instead. "
                     f"For the available generation strategies, check this doc: "
                     f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies",
-                )
+                ) from exception
             else:
-                raise exception
+                raise
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/models/linoss/modeling_linoss.py` around lines 270 - 283, The generate
method's exception handling should preserve original tracebacks and chain
exceptions: in the AttributeError branch that constructs a new AttributeError,
raise it using "from exception" to chain the original (i.e., raise
AttributeError(...) from exception), and in the else branch re-raise the caught
exception with a bare "raise" instead of "raise exception" so the original
traceback is preserved; update the generate method in the class (the method
named generate) accordingly.
fla/ops/linoss/fused_recurrent.py (2)

214-239: Backward of FusedRecurrentLinOSSFunction may be unreachable from the public API.

fused_recurrent_linoss dispatches to FusedRecurrentLinOSSFunction.apply only when x.requires_grad is False (line 426). In that case, autograd will never invoke backward. The backward implementation (which re-runs the PyTorch recurrence) is effectively dead code when called exclusively through the public API.

This isn't a bug—it's a safety net if someone calls FusedRecurrentLinOSSFunction.apply directly—but worth noting for maintainability.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/linoss/fused_recurrent.py` around lines 214 - 239, The backward
implementation in FusedRecurrentLinOSSFunction (method backward) is effectively
unreachable because fused_recurrent_linoss currently only dispatches to
FusedRecurrentLinOSSFunction.apply when x.requires_grad is False; change
fused_recurrent_linoss so it calls FusedRecurrentLinOSSFunction.apply when
x.requires_grad is True (or unconditionally) so that autograd can register the
Function and invoke backward, or alternatively add a clear comment and guard in
FusedRecurrentLinOSSFunction.apply to document intended direct-use-only behavior
and avoid confusion; update the dispatch logic in fused_recurrent_linoss (the
call site referencing FusedRecurrentLinOSSFunction.apply) to ensure the backward
path is reachable from the public API.

371-425: Large code duplication for output_final_state when x.requires_grad.

When gradients are needed and output_final_state=True, the entire recurrence is re-run (lines 397–424) just to extract the final hidden state. This duplicates the logic already in _linoss_recurrent_torch and is also inefficient: Bu_re/Bu_im (lines 394–395) are computed with grad-tracking enabled but only consumed inside a torch.no_grad() block, wasting memory on the autograd graph.

Consider either:

  1. Extending _linoss_recurrent_torch to optionally return final_state alongside outputs, or
  2. At minimum, moving the Bu_re/Bu_im einsums inside the no_grad block.
Option 2 – move einsums inside no_grad
         if output_final_state:
             B_t, T_t, H_t = x.shape
             P_t = a_diag.shape[0]
-            a = torch.relu(a_diag)
-            step = torch.sigmoid(dt)
-
-            if discretization == 'IMEX':
-                M11 = torch.ones_like(a)
-                M12 = -step * a
-                M21 = step.clone()
-                M22 = 1.0 - step * step * a
-            else:
-                schur = 1.0 / (1.0 + step * step * a)
-                M11 = 1.0 - step * step * a * schur
-                M12 = -step * a * schur
-                M21 = step * schur
-                M22 = schur
-
-            Bu_re = torch.einsum('bth,ph->btp', x, B_re)
-            Bu_im = torch.einsum('bth,ph->btp', x, B_im)
-
             with torch.no_grad():
+                a = torch.relu(a_diag)
+                step = torch.sigmoid(dt)
+
+                if discretization == 'IMEX':
+                    M11 = torch.ones_like(a)
+                    M12 = -step * a
+                    M21 = step.clone()
+                    M22 = 1.0 - step * step * a
+                else:
+                    schur = 1.0 / (1.0 + step * step * a)
+                    M11 = 1.0 - step * step * a * schur
+                    M12 = -step * a * schur
+                    M21 = step * schur
+                    M22 = schur
+
+                Bu_re = torch.einsum('bth,ph->btp', x, B_re)
+                Bu_im = torch.einsum('bth,ph->btp', x, B_im)
+
                 h1_re = x.new_zeros(B_t, P_t)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/linoss/fused_recurrent.py` around lines 371 - 425, The current code
recomputes the recurrence under x.requires_grad to build final_state, and
computes Bu_re/Bu_im with grad tracking but only uses them inside a
torch.no_grad() block; to fix this either (preferred) extend
_linoss_recurrent_torch to optionally return final_state alongside o (add a flag
like return_final_state and have callers use it) or (minimal) move the Bu_re and
Bu_im einsum computations into the with torch.no_grad(): block and keep the rest
of the final-state-only loop inside no_grad to avoid building an unnecessary
autograd graph; ensure you reference _linoss_recurrent_torch,
output_final_state, Bu_re, Bu_im, and the torch.no_grad() region when making the
change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/layers/__init__.py`:
- Line 68: Add a single trailing newline to the end of fla/layers/__init__.py so
the file ends with exactly one newline character; locate the file (contains the
closing list bracket "]") and ensure there is a newline after that bracket to
satisfy the end-of-file fixer and Ruff.

In `@fla/layers/linoss.py`:
- Around line 110-114: The code indexes past_key_values with self.layer_idx
which can be None; either validate layer_idx in the constructor or guard every
access. Fix by adding a check that self.layer_idx is not None (and is an int
within range) before using it in linoss.py: replace direct uses like
past_key_values[self.layer_idx] and past_key_values.update(...,
layer_idx=self.layer_idx, ...) with guarded logic (e.g., if self.layer_idx is
None: skip indexing/update or raise a clear ValueError in __init__); reference
symbols: self.layer_idx, past_key_values, and the methods where these appear to
ensure no NoneType indexing occurs.
- Around line 34-35: Update the __init__ signature of the LinOSSAttention class:
change the parameter annotation for layer_idx from "int = None" to "int | None =
None" to explicitly allow None, and change the return annotation from "->
LinOSSAttention" to "-> None" since __init__ must not declare a class return
type; update the signature in the LinOSSAttention.__init__ definition
accordingly.
- Around line 40-42: The constructor allows expand_ratio: float | None but uses
it directly; update the constructor in the Linoss class to guard against
expand_ratio being None by normalizing it to a float before use (e.g., if
expand_ratio is None: expand_ratio = 1.0), then compute self.input_dim =
int(hidden_size * expand_ratio) and set self.ssm_size = ssm_size if ssm_size is
not None else self.input_dim; reference symbols: expand_ratio, hidden_size,
self.input_dim, self.ssm_size.
- Around line 8-11: Remove the unused imports to satisfy the linter: delete the
"import torch.nn.functional as F" and the "from fla.modules.activations import
swiglu" lines in fla/layers/linoss.py while keeping the required imports
FusedRMSNormGated and ShortConvolution intact; if swiglu is actually needed
elsewhere, instead reference it where used (e.g., in any function that should
call swiglu) or add a usage comment to justify keeping the import.
- Around line 158-164: The offset is using the wrong axis: change the update in
past_key_values.update(...) so that offset is set from the sequence length
dimension of the input tensor i (use i.shape[1]) instead of the feature dim
(i.shape[2]); locate the call to past_key_values.update in this block
(references: past_key_values.update, recurrent_state, conv_state,
self.use_short_conv, self.layer_idx, offset) and replace the offset expression
with i.shape[1] to match other attention implementations.

In `@fla/models/linoss/configuration_linoss.py`:
- Around line 71-76: Fix the typo in the warning string inside the
fuse_linear_cross_entropy branch: update the message passed to warnings.warn
(refer to the fuse_linear_cross_entropy flag and the warnings.warn call in
configuration_linoss.py) so "can improves memory efficiency" becomes "can
improve memory efficiency" while preserving the rest of the warning text and
formatting.

In `@fla/ops/linoss/fused_recurrent.py`:
- Around line 80-86: The kernel currently only loads/stores real parts into the
initial/final state (see b_h1_re, b_h2_re and the USE_INITIAL_STATE branch
loading from h0) while b_h1_im and b_h2_im remain zeroed and never persisted, so
imaginary components are lost across segments; update the USE_INITIAL_STATE
loading to also load the imaginary parts from h0 (into b_h1_im and b_h2_im) and
ensure the code path that writes final_state writes both real and imaginary
parts in the same tensor layout as initial_state; update any state
packing/unpacking functions and the final-state emit logic so
initial_state/final_state contain both h1_re/h1_im and h2_re/h2_im in the agreed
order.

In `@tests/ops/test_linoss.py`:
- Around line 165-167: The file ends without a trailing newline, which causes
the end-of-file-fixer CI hook to modify the file; simply add a single newline
character at the end of the file (after the final assert_close call involving
ref_ht and tri_ht) so the file terminates with a newline and the pipeline stops
modifying it.

---

Nitpick comments:
In `@fla/models/linoss/modeling_linoss.py`:
- Around line 270-283: The generate method's exception handling should preserve
original tracebacks and chain exceptions: in the AttributeError branch that
constructs a new AttributeError, raise it using "from exception" to chain the
original (i.e., raise AttributeError(...) from exception), and in the else
branch re-raise the caught exception with a bare "raise" instead of "raise
exception" so the original traceback is preserved; update the generate method in
the class (the method named generate) accordingly.

In `@fla/ops/linoss/fused_recurrent.py`:
- Around line 214-239: The backward implementation in
FusedRecurrentLinOSSFunction (method backward) is effectively unreachable
because fused_recurrent_linoss currently only dispatches to
FusedRecurrentLinOSSFunction.apply when x.requires_grad is False; change
fused_recurrent_linoss so it calls FusedRecurrentLinOSSFunction.apply when
x.requires_grad is True (or unconditionally) so that autograd can register the
Function and invoke backward, or alternatively add a clear comment and guard in
FusedRecurrentLinOSSFunction.apply to document intended direct-use-only behavior
and avoid confusion; update the dispatch logic in fused_recurrent_linoss (the
call site referencing FusedRecurrentLinOSSFunction.apply) to ensure the backward
path is reachable from the public API.
- Around line 371-425: The current code recomputes the recurrence under
x.requires_grad to build final_state, and computes Bu_re/Bu_im with grad
tracking but only uses them inside a torch.no_grad() block; to fix this either
(preferred) extend _linoss_recurrent_torch to optionally return final_state
alongside o (add a flag like return_final_state and have callers use it) or
(minimal) move the Bu_re and Bu_im einsum computations into the with
torch.no_grad(): block and keep the rest of the final-state-only loop inside
no_grad to avoid building an unnecessary autograd graph; ensure you reference
_linoss_recurrent_torch, output_final_state, Bu_re, Bu_im, and the
torch.no_grad() region when making the change.

In `@tests/ops/test_linoss.py`:
- Line 54: Extract the repeated environment mutation
os.environ['TRITON_F32_DEFAULT'] = 'ieee' into a pytest fixture (e.g., def
triton_f32_default():) with scope="session" or "module" and set autouse=True (or
explicitly use the fixture in the tests) so tests no longer mutate globals; in
the fixture set the env var, yield, and restore the original value afterward to
avoid leaking state; update test functions to remove the in-test assignment and
rely on the fixture (fixture name: triton_f32_default).
- Around line 82-124: The test currently calls fused_recurrent_linoss(...)
which, when inputs have requires_grad=True, routes to the Python
`_linoss_recurrent_torch` path; to exercise the Triton custom backward use the
custom autograd Function directly. In test_fused_recurrent_bwd replace the
fused_recurrent_linoss(...) call with an explicit call to
FusedRecurrentLinOSSFunction.apply(...) (passing the same inputs and
discretization) so the Triton forward/backward implementation in
FusedRecurrentLinOSSFunction is exercised.

Comment on lines +34 to +35
layer_idx: int = None,
) -> LinOSSAttention:

⚠️ Potential issue | 🟡 Minor

Type annotation issues on __init__.

  1. layer_idx: int = None should be int | None = None (PEP 484 prohibits implicit Optional, also flagged by Ruff RUF013).
  2. The return type -> LinOSSAttention is incorrect for __init__; it should be -> None.
Proposed fix
-        layer_idx: int = None,
-    ) -> LinOSSAttention:
+        layer_idx: int | None = None,
+    ) -> None:
🧰 Tools
🪛 Ruff (0.15.0)

[warning] 34-34: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/layers/linoss.py` around lines 34 - 35, Update the __init__ signature of
the LinOSSAttention class: change the parameter annotation for layer_idx from
"int = None" to "int | None = None" to explicitly allow None, and change the
return annotation from "-> LinOSSAttention" to "-> None" since __init__ must not
declare a class return type; update the signature in the
LinOSSAttention.__init__ definition accordingly.

Comment on lines +40 to +42
self.expand_ratio = expand_ratio
self.input_dim = int(hidden_size * expand_ratio)
self.ssm_size = ssm_size if ssm_size is not None else self.input_dim

⚠️ Potential issue | 🟡 Minor

expand_ratio typed as float | None but used without None guard.

expand_ratio defaults to 1. but the type signature allows None. Line 41 does int(hidden_size * expand_ratio) which will raise TypeError if None is passed explicitly. Either tighten the type to float or add a guard.

Proposed fix (guard)
         self.expand_ratio = expand_ratio
-        self.input_dim = int(hidden_size * expand_ratio)
+        self.input_dim = int(hidden_size * (expand_ratio if expand_ratio is not None else 1.0))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/layers/linoss.py` around lines 40 - 42, The constructor allows
expand_ratio: float | None but uses it directly; update the constructor in the
Linoss class to guard against expand_ratio being None by normalizing it to a
float before use (e.g., if expand_ratio is None: expand_ratio = 1.0), then
compute self.input_dim = int(hidden_size * expand_ratio) and set self.ssm_size =
ssm_size if ssm_size is not None else self.input_dim; reference symbols:
expand_ratio, hidden_size, self.input_dim, self.ssm_size.

Comment on lines +110 to +114
        mode = self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

⚠️ Potential issue | 🟡 Minor

Potential NoneType access if layer_idx is None.

If layer_idx is not set (defaults to None), the expression past_key_values[self.layer_idx] at line 114 and past_key_values.update(..., layer_idx=self.layer_idx, ...) at line 162 will use None as an index. In practice, the model builder always passes layer_idx, but the constructor allows None, creating a latent bug if someone instantiates the layer standalone.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/layers/linoss.py` around lines 110 - 114, The code indexes
past_key_values with self.layer_idx which can be None; either validate layer_idx
in the constructor or guard every access. Fix by adding a check that
self.layer_idx is not None (and is an int within range) before using it in
linoss.py: replace direct uses like past_key_values[self.layer_idx] and
past_key_values.update(..., layer_idx=self.layer_idx, ...) with guarded logic
(e.g., if self.layer_idx is None: skip indexing/update or raise a clear
ValueError in __init__); reference symbols: self.layer_idx, past_key_values, and
the methods where these appear to ensure no NoneType indexing occurs.

Comment on lines +71 to +76
        if fuse_linear_cross_entropy:
            warnings.warn(
                "`fuse_linear_cross_entropy` is enabled, which can improves memory efficiency "
                "at the potential cost of reduced precision. "
                "If you observe issues like loss divergence, consider disabling this setting.",
            )

⚠️ Potential issue | 🟡 Minor

Typo in warning message.

Line 73: "can improves memory efficiency""can improve memory efficiency".

Proposed fix
-                "`fuse_linear_cross_entropy` is enabled, which can improves memory efficiency "
+                "`fuse_linear_cross_entropy` is enabled, which can improve memory efficiency "
🧰 Tools
🪛 Ruff (0.15.0)

[warning] 72-72: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/models/linoss/configuration_linoss.py` around lines 71 - 76, Fix the typo
in the warning string inside the fuse_linear_cross_entropy branch: update the
message passed to warnings.warn (refer to the fuse_linear_cross_entropy flag and
the warnings.warn call in configuration_linoss.py) so "can improves memory
efficiency" becomes "can improve memory efficiency" while preserving the rest of
the warning text and formatting.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
tests/ops/test_linoss.py (1)

54-54: Consider scoping TRITON_F32_DEFAULT at the module or fixture level.

Setting os.environ['TRITON_F32_DEFAULT'] inside each test function is repetitive and relies on the process-wide environment. A @pytest.fixture(autouse=True) or a module-level assignment would be cleaner and less error-prone.

Also applies to: 91-91, 146-146

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/test_linoss.py` at line 54, Multiple tests set
os.environ['TRITON_F32_DEFAULT'] repetitively inside test functions; replace
that with a single module-scoped solution by adding a pytest fixture with
`@pytest.fixture`(autouse=True) that sets os.environ['TRITON_F32_DEFAULT'] =
'ieee' (and restores previous value on teardown) or by assigning it once at
module import time, so remove the per-test assignments and centralize the
environment setup used by the tests referencing TRITON_F32_DEFAULT.
fla/models/linoss/modeling_linoss.py (1)

270-283: Use raise ... from for chained exceptions.

When re-raising as a different AttributeError, chain with `from exception` so the original traceback is preserved. Also, `raise exception` on line 283 should be a bare `raise`.

Proposed fix
-                raise AttributeError(
+                raise AttributeError(
                     f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                     f"which is not supported for {self.__class__.__name__}. "
                     f"Try another generation strategy instead. "
                     f"For the available generation strategies, check this doc: "
                     f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies",
-                )
+                ) from exception
             else:
-                raise exception
+                raise
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/models/linoss/modeling_linoss.py` around lines 270 - 283, In the generate
method replace the unchained re-raise by chaining the new AttributeError to the
original (use "raise AttributeError(...) from exception") so the original
traceback is preserved when you raise the custom message for 'past_key_values',
and change the final "raise exception" branch to a bare "raise" to re-raise the
original exception; refer to the generate method and the AttributeError handling
block in modeling_linoss.py to locate the changes.
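A minimal, self-contained illustration of the two idioms (the function and message below are hypothetical, not the actual modeling_linoss.py code):

```python
def fetch(state):
    try:
        return state['past_key_values']
    except KeyError as exception:
        # `from exception` chains the original error as __cause__,
        # so the original traceback survives the translation.
        raise AttributeError(
            'decoding strategies that manipulate `past_key_values` are unsupported'
        ) from exception
```

In the other branch, a bare `raise` (no argument) re-raises the active exception unchanged, whereas `raise exception` restarts the traceback at the re-raise site.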
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/models/linoss/modeling_linoss.py`:
- Around line 320-339: The forward logic creates a local criterion when
self.criterion is None (using FusedLinearCrossEntropyLoss,
FusedCrossEntropyLoss, or nn.CrossEntropyLoss) but never caches it; update the
block where criterion is instantiated (the branch that sets criterion when
getattr(self, 'criterion', None) is None) to assign the new instance to
self.criterion (e.g., self.criterion = criterion) so subsequent calls reuse the
cached criterion instead of recreating it each forward pass; ensure subsequent
code uses self.criterion where appropriate.

---

Duplicate comments:
In `@fla/layers/linoss.py`:
- Around line 32-33: Update the LinOSSAttention.__init__ signature: change the
parameter annotation from "layer_idx: int = None" to "layer_idx: int | None =
None" to satisfy the optional-int typing (RUF013), and change the __init__
return annotation from "-> LinOSSAttention" to "-> None" so the constructor has
the correct return type; locate these in the LinOSSAttention class __init__
definition and update the type hints accordingly.
- Around line 156-162: The offset is incorrectly using the feature dim
(i.shape[2]) instead of sequence length; update the past_key_values.update call
so offset uses i.shape[1] when available (e.g. offset=i.shape[1] if len(i.shape)
> 1 else 1) while keeping the other fields (recurrent_state, conv_state when
self.use_short_conv, layer_idx) unchanged—locate the past_key_values.update
invocation in linoss.py and replace the offset expression accordingly.
- Around line 38-40: expand_ratio is annotated as float | None but used directly
in computing input_dim, which will raise TypeError if expand_ratio is None;
update the initialization in the LinOSS (or containing) class so you first check
if self.expand_ratio is None and set self.input_dim = hidden_size in that case,
otherwise compute self.input_dim = int(hidden_size * self.expand_ratio); then
compute self.ssm_size = ssm_size if ssm_size is not None else self.input_dim to
preserve the existing fallback behavior. Ensure you reference and update the
assignments to self.expand_ratio, self.input_dim, and self.ssm_size in
linoss.py.
- Around line 110-112: last_state indexing can raise if self.layer_idx is None;
update the guard around past_key_values to ensure self.layer_idx is an int and
within range before indexing (e.g., check isinstance(self.layer_idx, int) and 0
<= self.layer_idx < len(past_key_values)) so that last_state =
past_key_values[self.layer_idx] only runs when safe; modify the block
referencing past_key_values and self.layer_idx (the variables last_state,
past_key_values, and self.layer_idx) accordingly to handle None or invalid
layer_idx values.

In `@fla/models/linoss/configuration_linoss.py`:
- Line 73: Replace the typo in the user-facing string that mentions
`fuse_linear_cross_entropy`: change "can improves" to "can improve" wherever
that phrase appears (e.g., the message or docstring referencing
`fuse_linear_cross_entropy` in configuration_linoss.py) so the sentence reads
correctly.

---

Nitpick comments:
In `@fla/models/linoss/modeling_linoss.py`:
- Around line 270-283: In the generate method replace the unchained re-raise by
chaining the new AttributeError to the original (use "raise AttributeError(...)
from exception") so the original traceback is preserved when you raise the
custom message for 'past_key_values', and change the final "raise exception"
branch to a bare "raise" to re-raise the original exception; refer to the
generate method and the AttributeError handling block in modeling_linoss.py to
locate the changes.

In `@tests/ops/test_linoss.py`:
- Line 54: Multiple tests set os.environ['TRITON_F32_DEFAULT'] repetitively
inside test functions; replace that with a single module-scoped solution by
adding a pytest fixture with `@pytest.fixture`(autouse=True) that sets
os.environ['TRITON_F32_DEFAULT'] = 'ieee' (and restores previous value on
teardown) or by assigning it once at module import time, so remove the per-test
assignments and centralize the environment setup used by the tests referencing
TRITON_F32_DEFAULT.

Comment on lines +320 to +339
        loss, logits = None, None
        if not self.config.fuse_linear_cross_entropy or labels is None:
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if self.config.fuse_linear_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp)
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if self.config.fuse_linear_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
            loss = l2_warp(loss, logits) if self.config.use_l2warp else loss

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

cat -n fla/models/linoss/modeling_linoss.py | sed -n '315,345p'

Repository: fla-org/flash-linear-attention

Length of output: 1794


🏁 Script executed:

rg -n 'self\.criterion\s*=' fla/models/ -g '*.py' -A 2 | head -50

Repository: fla-org/flash-linear-attention

Length of output: 2885


🏁 Script executed:

rg -n 'criterion\s*=' fla/models/linoss/modeling_linoss.py -B 2 -A 2

Repository: fla-org/flash-linear-attention

Length of output: 978


Add self.criterion = criterion assignment after line 330 to cache the criterion.

The criterion is instantiated locally (lines 326–330) but never assigned to self.criterion, causing it to be re-created on every forward pass with labels. Line 248 initializes self.criterion = None, and line 324 checks for it, indicating the intent to cache. This pattern matches all other models in the codebase (gated_deltaproduct, transformer, rwkv7, etc.), which properly assign the criterion after creation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/models/linoss/modeling_linoss.py` around lines 320 - 339, The forward
logic creates a local criterion when self.criterion is None (using
FusedLinearCrossEntropyLoss, FusedCrossEntropyLoss, or nn.CrossEntropyLoss) but
never caches it; update the block where criterion is instantiated (the branch
that sets criterion when getattr(self, 'criterion', None) is None) to assign the
new instance to self.criterion (e.g., self.criterion = criterion) so subsequent
calls reuse the cached criterion instead of recreating it each forward pass;
ensure subsequent code uses self.criterion where appropriate.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
fla/ops/linoss/fused_recurrent.py (2)

88-98: Inner loop over H is sequential in the kernel — expected for this architecture but worth documenting.

The for h in range(H) loops at lines 93 and 122 iterate element-by-element over the hidden dimension within each Triton program. For large H (e.g., 1024+), this may become a bottleneck since Triton unrolls these into sequential scalar operations. The P-dimension is parallelized across programs and the atomic_add at line 130 handles cross-block output accumulation.

This is a known pattern in the fla codebase's fused recurrent kernels and acceptable for now, but a brief comment noting this design choice would help future contributors.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/linoss/fused_recurrent.py` around lines 88 - 98, Add a brief in-code
comment above the inner "for h in range(H)" loop in fused_recurrent.py (the loop
that loads b_u_h, B_re/B_im and accumulates b_bu_re/b_bu_im) stating that this
inner H loop is intentionally sequential in the Triton kernel, that P-dimension
is parallelized across programs and cross-block accumulation is handled via
atomic_add, and that this design may be a bottleneck for very large H (e.g.,
1024+) but is acceptable for this architecture; reference the
variables/operations b_u_h, B_re/B_im loads, b_bu_re/b_bu_im accumulation and
the atomic_add used for output accumulation so future contributors understand
the tradeoff.

371-380: Gradient dispatch checks only x.requires_grad, which may skip autograd for parameter-only gradients.

If x is detached (e.g., in certain eval-with-grad or mixed scenarios), the else-branch calls fused_recurrent_linoss_fwd directly, bypassing FusedRecurrentLinOSSFunction.apply. This means gradients for B_re, C_re, a_diag, dt, d_skip, etc. would not be computed even if those tensors require grad.

In typical training this is fine (since x comes from a nn.Linear and carries grad), but a more robust check would be:

-    if x.requires_grad:
+    if torch.is_grad_enabled():

or check any input's requires_grad. This aligns with a more defensive dispatch pattern.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/linoss/fused_recurrent.py` around lines 371 - 380, The current
dispatch only checks x.requires_grad before choosing between
FusedRecurrentLinOSSFunction.apply and fused_recurrent_linoss_fwd, which misses
cases where other tensors (B_re, B_im, C_re, C_im, a_diag, dt, d_skip,
initial_state, etc.) require gradients; change the condition to test whether any
of the relevant inputs require grad (e.g., any(t is not None and getattr(t,
"requires_grad", False) for t in (x, B_re, B_im, C_re, C_im, a_diag, dt, d_skip,
initial_state))) and call FusedRecurrentLinOSSFunction.apply when true,
otherwise call fused_recurrent_linoss_fwd to ensure parameter-only gradients are
preserved.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/ops/linoss/fused_recurrent.py`:
- Around line 214-239: The backward currently recomputes the forward via
_linoss_recurrent_torch under torch.enable_grad() and calls o.backward(do), but
this can produce incorrect gradients if the recomputed forward uses a different
dtype/accumulation than the Triton kernel and it entirely ignores the incoming
dht (gradient w.r.t. final state). Fix by ensuring the recomputation uses the
exact same dtype and accumulation settings as the Triton kernel (cast tensors
like x, B_re, B_im, C_re, C_im, a_diag, dt, d_skip to the original kernel's
dtype and set deterministic accumulation if needed) and by including the
final-state gradient: have the recomputed forward return both outputs and
final_state, then call backward with both gradients (o.backward((do, dht)) or
use torch.autograd.grad to request gradients w.r.t. inputs including
initial_state while passing dht for the final state), and return the resulting
grads for initial_state (and all other saved tensors). Reference symbols:
backward (staticmethod), ctx.saved_tensors, dht, _linoss_recurrent_torch,
initial_state.
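To make the `dht` handling concrete, here is a toy recurrence (not the real LinOSS recurrence; shapes and names are illustrative) showing how `torch.autograd.grad` can take cotangents for both the per-step outputs and the final state:

```python
import torch


def toy_recurrence(x, a, h0):
    # h_t = a * h_{t-1} + x_t ; o_t = h_t
    h = h0
    outs = []
    for t in range(x.shape[0]):
        h = a * h + x[t]
        outs.append(h)
    return torch.stack(outs), h  # (o, final_state)


x = torch.ones(4, requires_grad=True)
a = torch.tensor(0.5, requires_grad=True)
h0 = torch.zeros(())
do, dht = torch.ones(4), torch.ones(())

o, ht = toy_recurrence(x, a, h0)
# Supply cotangents for BOTH outputs so dht is not silently dropped:
dx, da = torch.autograd.grad((o, ht), (x, a), grad_outputs=(do, dht))
```

The same pattern applies to the recomputed `_linoss_recurrent_torch` forward: have it return `(o, final_state)` and pass `(do, dht)` as `grad_outputs` instead of calling `o.backward(do)` alone.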

---

Duplicate comments:
In `@fla/layers/linoss.py`:
- Around line 32-33: The __init__ signature for class LinOSSAttention is
incorrectly annotated: change the parameter annotation from "layer_idx: int =
None" to "layer_idx: int | None = None" (or use Optional[int]) and update the
__init__ return annotation from "-> LinOSSAttention" to "-> None"; locate and
edit the __init__ method of LinOSSAttention to apply these fixes.
- Around line 38-40: The constructor currently multiplies hidden_size by
expand_ratio without guarding against expand_ratio being None; change the logic
in the Linoss initializer so expand_ratio is normalized first (e.g.,
effective_expand = expand_ratio if expand_ratio is not None else 1.0) and then
compute self.input_dim = int(hidden_size * effective_expand); keep
self.expand_ratio set to the original value or the normalized value per your API
choice, and ensure self.ssm_size still defaults to self.input_dim when ssm_size
is None.
- Around line 110-112: The code assumes self.layer_idx is an int; guard against
None by checking it before indexing or updating past_key_values: change the
condition to check that self.layer_idx is not None and is an int within range
(e.g., if past_key_values is not None and self.layer_idx is not None and 0 <=
self.layer_idx < len(past_key_values): last_state =
past_key_values[self.layer_idx]) and similarly ensure any call that does
past_key_values.update(..., layer_idx=self.layer_idx) only passes layer_idx when
self.layer_idx is not None (or use a different update key), so references to
self.layer_idx (in methods like where last_state is set and where
past_key_values is updated) never attempt to index with None.

In `@fla/ops/linoss/fused_recurrent.py`:
- Around line 80-86: The kernel initializes b_h1_im and b_h2_im to zeros but
only loads/stores the real components when USE_INITIAL_STATE / final-state
handling occurs, so imaginary state is lost across segments; update the
initial-state load to also load the imaginary components into b_h1_im and
b_h2_im (use the appropriate h0 offsets for the imaginary parts, mirroring the
real-part loads) and likewise update the final-state store logic to write both
real and imaginary parts; apply the same fix to the torch fallback code path
where h1_im/h2_im are currently left zero so both real and imaginary components
are persisted across segments.
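As a shape-agnostic sketch of the suggested state round-trip (the 4-plane stacking order below is an assumption for illustration, not necessarily the kernel's agreed layout):

```python
import torch


def pack_state(h1: torch.Tensor, h2: torch.Tensor) -> torch.Tensor:
    # Persist both real and imaginary halves of the oscillator state
    # in a single real-valued tensor: [h1_re, h1_im, h2_re, h2_im].
    return torch.stack([h1.real, h1.imag, h2.real, h2.imag])


def unpack_state(state: torch.Tensor):
    # Inverse of pack_state: rebuild the complex (h1, h2) pair.
    h1 = torch.complex(state[0], state[1])
    h2 = torch.complex(state[2], state[3])
    return h1, h2
```

Whatever layout is chosen, `initial_state` loading and `final_state` emission (in both the Triton kernel and the torch fallback) must use the same packing so state survives segment boundaries.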

---

Nitpick comments:
In `@fla/ops/linoss/fused_recurrent.py`:
- Around line 88-98: Add a brief in-code comment above the inner "for h in
range(H)" loop in fused_recurrent.py (the loop that loads b_u_h, B_re/B_im and
accumulates b_bu_re/b_bu_im) stating that this inner H loop is intentionally
sequential in the Triton kernel, that P-dimension is parallelized across
programs and cross-block accumulation is handled via atomic_add, and that this
design may be a bottleneck for very large H (e.g., 1024+) but is acceptable for
this architecture; reference the variables/operations b_u_h, B_re/B_im loads,
b_bu_re/b_bu_im accumulation and the atomic_add used for output accumulation so
future contributors understand the tradeoff.
- Around line 371-380: The current dispatch only checks x.requires_grad before
choosing between FusedRecurrentLinOSSFunction.apply and
fused_recurrent_linoss_fwd, which misses cases where other tensors (B_re, B_im,
C_re, C_im, a_diag, dt, d_skip, initial_state, etc.) require gradients; change
the condition to test whether any of the relevant inputs require grad (e.g.,
any(t is not None and getattr(t, "requires_grad", False) for t in (x, B_re,
B_im, C_re, C_im, a_diag, dt, d_skip, initial_state))) and call
FusedRecurrentLinOSSFunction.apply when true, otherwise call
fused_recurrent_linoss_fwd to ensure parameter-only gradients are preserved.

@zhiyuan1i
Collaborator

Hi @Phoenix8215, thanks for this great contribution bringing LinOSS to FLA! 🎉

I noticed that the current implementation primarily focuses on the fused_recurrent mode. To make this implementation more complete and aligned with other models in FLA, I would like to suggest adding the following features:
1. Chunk-wise parallel mode (chunk)
Most FLA models support a chunk mode for efficient training with long sequences. This parallelizes computation across chunks and can significantly speed up training compared to the recurrent mode.
2. Backward for naive implementation
While the fused recurrent mode has autograd support, having a PyTorch-native backward pass for the naive implementation would be helpful for debugging and gradient verification.

Would you or anyone from the community be interested in implementing these? I am happy to help review or collaborate! 🙏

@Phoenix8215
Author

Hi @Phoenix8215, thanks for this great contribution bringing LinOSS to FLA! 🎉

I noticed that the current implementation primarily focuses on the fused_recurrent mode. To make this implementation more complete and aligned with other models in FLA, I would like to suggest adding the following features: 1. Chunk-wise parallel mode (chunk) Most FLA models support a chunk mode for efficient training with long sequences. This parallelizes computation across chunks and can significantly speed up training compared to the recurrent mode. 2. Backward for naive implementation While the fused recurrent mode has autograd support, having a PyTorch-native backward pass for the naive implementation would be helpful for debugging and gradient verification.

Would you or anyone from the community be interested in implementing these? I am happy to help review or collaborate! 🙏

Thanks for the suggestions!🥰

Both features are now implemented. Let me know if further refinements would be helpful.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (2)
fla/ops/linoss/fused_recurrent.py (1)

268-268: H is unpacked but never used — prefix with _.

Ruff RUF059 flags this. H is not referenced in the body of _linoss_recurrent_torch; the einsum contractions are expressed with string subscripts, not via the variable.

♻️ Proposed fix
-    Bat, T, H = x.shape
+    Bat, T, _ = x.shape
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/linoss/fused_recurrent.py` at line 268, The tuple unpacking in
_linoss_recurrent_torch currently does "Bat, T, H = x.shape" but H is never
used; rename the third variable to a throwaway name (e.g., _ or _H) to satisfy
RUF059. Update the unpacking in the _linoss_recurrent_torch function from "Bat,
T, H = x.shape" to "Bat, T, _ = x.shape" (or "_H") and ensure no other
references to H exist in that function.
fla/ops/linoss/chunk.py (1)

66-66: H is unpacked but never used — prefix with _.

Same Ruff RUF059 pattern as in _linoss_recurrent_torch: H is assigned from x.shape but not referenced anywhere in the function body.

♻️ Proposed fix
-    Bat, T, H = x.shape
+    Bat, T, _ = x.shape
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/linoss/chunk.py` at line 66, The tuple unpacking assigns Bat, T, H =
x.shape but H is never used; update the unpacking in the same function (the line
with "Bat, T, H = x.shape" in fla.ops.linoss.chunk.py) to use a prefixed
underscore for the unused dimension (e.g., Bat, T, _H or Bat, T, _) to satisfy
the Ruff RUF059 rule and clearly mark the variable as intentionally unused;
mirror the same change pattern used in _linoss_recurrent_torch if present.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/ops/linoss/fused_recurrent.py`:
- Around line 244-248: The backward is currently returning None for
initial_state which drops gradients for a learnable or non-detached
initial_state; update the return tuple in the backward of
_linoss_recurrent_torch (the block that currently returns x.grad, B_re.grad,
B_im.grad, C_re.grad, C_im.grad, a_diag.grad, dt.grad, d_skip.grad, None, None,
None) to return the gradient for initial_state in the correct slot (replace the
first of those trailing Nones with the gradient tensor computed for
initial_state or initial_state.grad), preserving the original argument order so
callers receive the initial_state gradient when it requires_grad. Ensure you
reference the saved/rewound initial_state used in the forward so you return its
grad consistently with other saved inputs.

In `@fla/ops/linoss/naive.py`:
- Line 95: The multiplication uses d_skip with x (which is cast to float32),
causing dtype mismatches in half/bf16 contexts; update the expression that
computes y_t to cast d_skip to float (e.g., d_skip.float() or
d_skip.to(x.dtype)) before multiplying by x[:, t] so the types match—modify the
line where y_t is computed (references: y_t, d_skip, x, h2, C_complex) similar
to how _linoss_recurrent_torch and chunk_linoss explicitly cast d_skip to float.

---

Duplicate comments:
In `@fla/layers/linoss.py`:
- Around line 32-33: Update the type annotations in the LinOSSAttention class
initializer: change the parameter annotation from "layer_idx: int = None" to
"layer_idx: int | None = None" to make the optionality explicit per PEP 484, and
correct the __init__ signature return annotation from "-> LinOSSAttention" to
"-> None". Adjust only the annotations on the __init__ method (the parameter
named layer_idx and the method return) so function name __init__ and class
LinOSSAttention remain unchanged.
- Around line 38-40: expand_ratio is allowed to be None but you immediately
multiply it in the input_dim computation (self.input_dim = int(hidden_size *
expand_ratio)), causing a TypeError; guard against None by normalizing
expand_ratio before use (e.g., set self.expand_ratio = 1.0 if expand_ratio is
None) and then compute self.input_dim and self.ssm_size using the normalized
value (ensure ssm_size fallback still uses the computed self.input_dim when
ssm_size is None).
- Around line 110-112: self.layer_idx may be None which causes a TypeError when
used in comparisons and passed into past_key_values; to fix, validate and
normalize layer_idx in the class __init__ (raise a ValueError if layer_idx is
None or not an int) so self.layer_idx is always a valid integer, and then the
code that reads last_state (the block using self.layer_idx to index
past_key_values) and the call to past_key_values.update(...,
layer_idx=self.layer_idx, ...) will no longer receive None; alternatively, if
you prefer not to raise, guard the usage site by checking "if self.layer_idx is
not None and len(past_key_values) > self.layer_idx:" and only pass layer_idx
into past_key_values.update when it is not None.


Comment on lines +244 to +248
return (
x.grad, B_re.grad, B_im.grad, C_re.grad, C_im.grad,
a_diag.grad, dt.grad, d_skip.grad,
None, None, None,
)
Contributor

⚠️ Potential issue | 🟡 Minor

Gradient for initial_state is unconditionally None — silently wrong for learnable initial state or truncated BPTT.

initial_state is the only saved input that is neither detached nor given requires_grad_(True) before the re-run of _linoss_recurrent_torch (all other eight inputs are explicitly prepared at lines 226–233). Consequently, even if a caller passes initial_state with requires_grad=True (learnable h₀, or gradient passed back from a later segment), that gradient is silently discarded.

Standard inference and single-pass training (where initial_state is None or a detached cache tensor) are unaffected, but the issue would produce silent wrong gradients for truncated BPTT or learnable-h₀ experiments.

🛡️ Minimal fix to propagate `initial_state` gradient
     with torch.enable_grad():
         x = x.detach().requires_grad_(True)
         ...
         d_skip = d_skip.detach().requires_grad_(True)
+        if initial_state is not None:
+            initial_state = initial_state.detach().requires_grad_(True)

         o, ht = _linoss_recurrent_torch(...)
         ...
         torch.autograd.backward(targets, grads)

     return (
         x.grad, B_re.grad, B_im.grad, C_re.grad, C_im.grad,
         a_diag.grad, dt.grad, d_skip.grad,
-        None, None, None,
+        initial_state.grad if initial_state is not None else None, None, None,
     )
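To see why the dropped h₀ gradient is a real rather than cosmetic issue, here is a torch-free finite-difference sketch on a toy recurrence (illustrative only): for h_t = a·h_{t-1} + x_t with loss L = h_T, the gradient with respect to the initial state is a^T, which is generically nonzero.

```python
# For h_t = a * h_{t-1} + x_t and loss L = h_T, dL/dh0 = a^T analytically.
# A backward that returns None for the initial state behaves as if this were zero.

def run(a, xs, h0):
    h = h0
    for x in xs:
        h = a * h + x
    return h

a, xs, h0 = 0.9, [1.0, 2.0, 3.0], 0.5
eps = 1e-6
# Central finite difference standing in for autograd:
fd_grad = (run(a, xs, h0 + eps) - run(a, xs, h0 - eps)) / (2 * eps)
analytic = a ** len(xs)  # a^T = 0.9^3 = 0.729
assert abs(fd_grad - analytic) < 1e-6  # clearly nonzero
```

In truncated BPTT this a^T factor is exactly what should flow back into the previous segment's final state; returning None silently replaces it with zero.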

h1 = h1_new
h2 = h2_new

y_t = (h2 @ C_complex.t()).real + d_skip.unsqueeze(0) * x[:, t]
Contributor

⚠️ Potential issue | 🟡 Minor

d_skip missing .float() cast — inconsistent dtype handling in half-precision contexts.

x is cast to float32 at line 42, but d_skip is used raw. With model.half() or bf16 parameters, d_skip (fp16/bf16) is multiplied against x[:, t] (fp32); PyTorch's type promotion silently upcasts the product rather than raising an error, but the naive path then handles dtypes differently from _linoss_recurrent_torch (line 330 of fused_recurrent.py) and chunk_linoss (line 205 of chunk.py), which both explicitly call .float() on d_skip.

🐛 Proposed fix
-        y_t = (h2 @ C_complex.t()).real + d_skip.unsqueeze(0) * x[:, t]
+        y_t = (h2 @ C_complex.t()).real + d_skip.float().unsqueeze(0) * x[:, t]

@Phoenix8215
Author

Hi @zhiyuan1i, just checking whether you had a chance to review the recent updates. 😙😙 Happy to address any comments.

@Phoenix8215
Author

Hi @zhiyuan1i, @yzhangcs , could you please help review this PR when you have time?
All checks are passing and I’ve addressed previous comments.
Thanks a lot!

@zhiyuan1i
Collaborator

Hi @Phoenix8215, thanks for the PR and congrats on the ICLR Oral! 🎉

I'm wondering if we could adjust the implementation approach:

fused_recurrent for inference, chunk for training - Currently the fused version has a PyTorch autograd backward, which doesn't fit FLA's philosophy. Maybe we can:

  • Keep fused_recurrent for inference only (fast forward)
  • Use chunk implementation for training with proper backward support

Also, I might be misunderstanding this - does LinOSS not support TensorCore-parallel training? The recurrent nature makes it hard to parallelize across sequence dimension, right?

Would love to hear your thoughts on the chunk-based training approach. Thanks!
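On the parallelization question: a diagonal linear recurrence is not inherently sequential, because composing steps is associative, which is exactly what parallel prefix scans and chunkwise training exploit. A minimal torch-free sketch of the combine operator (illustrative names, not FLA's API):

```python
# Each step of h_t = a_t * h_{t-1} + b_t can be represented as the pair (a_t, b_t).
# Composing two steps is associative:
#   (a1, b1) then (a2, b2)  ==  (a2 * a1, a2 * b1 + b2)
# Associativity is what permits a tree-shaped (log-depth) parallel scan.

def combine(p, q):
    a1, b1 = p
    a2, b2 = q
    return (a2 * a1, a2 * b1 + b2)

def prefix_scan(steps):
    # Sequential reference for the inclusive scan; a parallel version would
    # apply `combine` pairwise in a tree since it is associative.
    out, acc = [], (1.0, 0.0)  # identity element: h stays unchanged
    for s in steps:
        acc = combine(acc, s)
        out.append(acc)
    return out

steps = [(0.9, 1.0), (0.8, 2.0), (0.7, 3.0)]
# With h0 = 0, h_t is the b-component of each prefix:
hs = [b for _, b in prefix_scan(steps)]

# Left vs right grouping agree, confirming associativity on this example:
l = combine(combine(steps[0], steps[1]), steps[2])
r = combine(steps[0], combine(steps[1], steps[2]))
assert all(abs(x - y) < 1e-9 for x, y in zip(l, r))
```

A parallel scan evaluates this in O(log T) depth; the chunkwise form instead turns intra-chunk work into matmuls and scans only across chunk boundaries, which is the usual route to TensorCore-parallel training.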
