Conversation
Walkthrough

Adds LinOSS: a new LinOSSAttention layer, a LinOSS model family (LinOSSConfig, LinOSSModel, LinOSSForCausalLM), Triton-backed fused recurrent kernels with autograd support plus a pure-PyTorch naive reference, chunked ops, exports and registrations across the public APIs, and tests validating the fused implementation against the naive one.
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User
    participant Layer as LinOSSAttention
    participant Ops as fused_recurrent_linoss
    participant Cache as PastKeyValues/Cache
    participant Output as Model Output
    User->>Layer: forward(hidden_states, attention_mask, past_key_values?, use_cache)
    activate Layer
    Layer->>Cache: load layer state (layer_idx)
    Cache-->>Layer: recurrent_state, conv_state, offset
    Layer->>Layer: compute projection (i_proj or ShortConvolution)
    Layer->>Layer: apply mask, gating, RMSNorm
    Layer->>Ops: fused_recurrent_linoss(i, A_diag, B, C, D, dt, discretization)
    activate Ops
    Ops->>Ops: run Triton kernel or PyTorch fallback (forward/backward)
    Ops-->>Layer: outputs, final_recurrent_state
    deactivate Ops
    Layer->>Cache: update recurrent_state, conv_state, offset
    Layer->>Layer: apply output projection
    Layer-->>Output: return (o, None, past_key_values)
    deactivate Layer
```
Estimated code review effort: 5 (Critical), ~120 minutes.
Pre-merge checks: 2 passed, 1 failed (warning).
Summary of Changes

Hello @Phoenix8215, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request integrates the novel LinOSS (Linear Ordinary State Space) model, a second-order state space model, into the project. The addition expands the library's capabilities with a new, efficient attention mechanism and a complete model implementation suitable for causal language modeling, leveraging optimized recurrent operations for performance.
Code Review
This pull request introduces the LinOSS model, a second-order state space model. The implementation includes the LinOSSAttention layer, LinOSSModel, a custom Triton-based fused recurrent op, and corresponding tests. The changes are well-structured and follow the project's conventions.
My review has identified a few issues:

- A critical issue in the fused_recurrent_linoss op, where the logic for the training and inference paths is inverted, preventing use of the optimized Triton kernel during training.
- A high-severity bug in LinOSSAttention, where the cache update offset is incorrect.
- A minor issue regarding an unused import.
I have provided suggestions to fix these issues. Once addressed, this will be a great addition to the library.
```python
def fused_recurrent_linoss(
    x: torch.Tensor,
    B_re: torch.Tensor,
    B_im: torch.Tensor,
    C_re: torch.Tensor,
    C_im: torch.Tensor,
    a_diag: torch.Tensor,
    dt: torch.Tensor,
    d_skip: torch.Tensor,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    discretization: str = 'IM',
) -> tuple[torch.Tensor, torch.Tensor | None]:
    r"""
    Fused recurrent implementation of LinOSS (Linear Ordinary State Space).

    LinOSS models a second-order ODE system discretized with either IM or IMEX methods.
    Each state dimension has a 2-component state [position, velocity].

    Args:
        x (torch.Tensor):
            Input sequence of shape `[B, T, H]`.
        B_re (torch.Tensor):
            Real part of input matrix B, shape `[P, H]`.
        B_im (torch.Tensor):
            Imaginary part of input matrix B, shape `[P, H]`.
        C_re (torch.Tensor):
            Real part of output matrix C, shape `[H, P]`.
        C_im (torch.Tensor):
            Imaginary part of output matrix C, shape `[H, P]`.
        a_diag (torch.Tensor):
            Diagonal state matrix (pre-relu), shape `[P]`.
        dt (torch.Tensor):
            Discretization step sizes (pre-sigmoid), shape `[P]`.
        d_skip (torch.Tensor):
            Skip connection weights, shape `[H]`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[B, 2, P]`. Default: `None`.
        output_final_state (bool):
            Whether to output the final state. Default: `False`.
        discretization (str):
            Discretization method, either 'IM' or 'IMEX'. Default: `'IM'`.

    Returns:
        o (torch.Tensor):
            Output of shape `[B, T, H]`.
        final_state (Optional[torch.Tensor]):
            Final state of shape `[B, 2, P]` if `output_final_state=True`.
    """
    if x.requires_grad:
        o = _linoss_recurrent_torch(
            x, B_re, B_im, C_re, C_im, a_diag, dt, d_skip, initial_state, discretization
        )
        final_state = None
        if output_final_state:
            B_t, T_t, H_t = x.shape
            P_t = a_diag.shape[0]
            a = torch.relu(a_diag)
            step = torch.sigmoid(dt)

            if discretization == 'IMEX':
                M11 = torch.ones_like(a)
                M12 = -step * a
                M21 = step.clone()
                M22 = 1.0 - step * step * a
            else:
                schur = 1.0 / (1.0 + step * step * a)
                M11 = 1.0 - step * step * a * schur
                M12 = -step * a * schur
                M21 = step * schur
                M22 = schur

            Bu_re = torch.einsum('bth,ph->btp', x, B_re)
            Bu_im = torch.einsum('bth,ph->btp', x, B_im)

            with torch.no_grad():
                h1_re = x.new_zeros(B_t, P_t)
                h1_im = x.new_zeros(B_t, P_t)
                h2_re = x.new_zeros(B_t, P_t)
                h2_im = x.new_zeros(B_t, P_t)
                if initial_state is not None:
                    h1_re = initial_state[:, 0].to(x.dtype)
                    h2_re = initial_state[:, 1].to(x.dtype)
                for t in range(T_t):
                    bu_re_t = Bu_re[:, t]
                    bu_im_t = Bu_im[:, t]
                    if discretization == 'IMEX':
                        f1_re = bu_re_t * step.unsqueeze(0)
                        f1_im = bu_im_t * step.unsqueeze(0)
                        f2_re = bu_re_t * (step * step).unsqueeze(0)
                        f2_im = bu_im_t * (step * step).unsqueeze(0)
                    else:
                        f1_re = M11.unsqueeze(0) * bu_re_t * step.unsqueeze(0)
                        f1_im = M11.unsqueeze(0) * bu_im_t * step.unsqueeze(0)
                        f2_re = M21.unsqueeze(0) * bu_re_t * step.unsqueeze(0)
                        f2_im = M21.unsqueeze(0) * bu_im_t * step.unsqueeze(0)
                    h1_re_n = M11.unsqueeze(0) * h1_re + M12.unsqueeze(0) * h2_re + f1_re
                    h1_im_n = M11.unsqueeze(0) * h1_im + M12.unsqueeze(0) * h2_im + f1_im
                    h2_re_n = M21.unsqueeze(0) * h1_re + M22.unsqueeze(0) * h2_re + f2_re
                    h2_im_n = M21.unsqueeze(0) * h1_im + M22.unsqueeze(0) * h2_im + f2_im
                    h1_re, h1_im = h1_re_n, h1_im_n
                    h2_re, h2_im = h2_re_n, h2_im_n
                final_state = torch.stack([h1_re, h2_re], dim=1)
        return o, final_state
    else:
        return FusedRecurrentLinOSSFunction.apply(
            x, B_re, B_im, C_re, C_im, a_diag, dt, d_skip,
            initial_state, output_final_state, discretization,
        )
```
The logic in fused_recurrent_linoss appears to be inverted. When x.requires_grad is true (i.e., during training), it should use FusedRecurrentLinOSSFunction.apply to leverage the custom Triton kernel for the forward pass and the custom implementation for the backward pass. Instead, it falls back to a pure PyTorch implementation (_linoss_recurrent_torch), which will be slow and bypasses the custom backward. Conversely, during inference (x.requires_grad is false), it unnecessarily calls FusedRecurrentLinOSSFunction.apply, which has autograd overhead. It should call the forward kernel wrapper fused_recurrent_linoss_fwd directly for optimal performance.
The suggested change simplifies the function to always use the autograd.Function, which is the standard and correct pattern.
```python
def fused_recurrent_linoss(
    x: torch.Tensor,
    B_re: torch.Tensor,
    B_im: torch.Tensor,
    C_re: torch.Tensor,
    C_im: torch.Tensor,
    a_diag: torch.Tensor,
    dt: torch.Tensor,
    d_skip: torch.Tensor,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    discretization: str = 'IM',
) -> tuple[torch.Tensor, torch.Tensor | None]:
    r"""
    Fused recurrent implementation of LinOSS (Linear Ordinary State Space).

    LinOSS models a second-order ODE system discretized with either IM or IMEX methods.
    Each state dimension has a 2-component state [position, velocity].

    Args:
        x (torch.Tensor):
            Input sequence of shape `[B, T, H]`.
        B_re (torch.Tensor):
            Real part of input matrix B, shape `[P, H]`.
        B_im (torch.Tensor):
            Imaginary part of input matrix B, shape `[P, H]`.
        C_re (torch.Tensor):
            Real part of output matrix C, shape `[H, P]`.
        C_im (torch.Tensor):
            Imaginary part of output matrix C, shape `[H, P]`.
        a_diag (torch.Tensor):
            Diagonal state matrix (pre-relu), shape `[P]`.
        dt (torch.Tensor):
            Discretization step sizes (pre-sigmoid), shape `[P]`.
        d_skip (torch.Tensor):
            Skip connection weights, shape `[H]`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[B, 2, P]`. Default: `None`.
        output_final_state (bool):
            Whether to output the final state. Default: `False`.
        discretization (str):
            Discretization method, either 'IM' or 'IMEX'. Default: `'IM'`.

    Returns:
        o (torch.Tensor):
            Output of shape `[B, T, H]`.
        final_state (Optional[torch.Tensor]):
            Final state of shape `[B, 2, P]` if `output_final_state=True`.
    """
    return FusedRecurrentLinOSSFunction.apply(
        x, B_re, B_im, C_re, C_im, a_diag, dt, d_skip,
        initial_state, output_final_state, discretization,
    )
```
fla/layers/linoss.py (Outdated)
```python
            recurrent_state=recurrent_state,
            conv_state=conv_state if self.use_short_conv else None,
            layer_idx=self.layer_idx,
            offset=i.shape[2] if len(i.shape) > 2 else 1,
```
The offset for updating the cache is calculated as i.shape[2], which corresponds to the hidden dimension (H). The offset should represent the number of new tokens, which is the sequence length. This should be i.shape[1].
```diff
-            offset=i.shape[2] if len(i.shape) > 2 else 1,
+            offset=i.shape[1],
```
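To make the axis mix-up concrete (hypothetical dimensions, plain Python):

```python
# The projected input `i` has shape [batch, seq_len, hidden] = [B, T, H].
B, T, H = 2, 5, 16
shape = (B, T, H)

# shape[2] is the hidden width H, which is not a token count:
assert shape[2] == H == 16
# shape[1] is the number of new tokens the cache offset must advance by:
assert shape[1] == T == 5
```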
fla/layers/linoss.py (Outdated)
```python
import torch.nn.functional as F

from fla.modules import FusedRMSNormGated, ShortConvolution
from fla.modules.activations import swiglu
```
Actionable comments posted: 9
🧹 Nitpick comments (5)
tests/ops/test_linoss.py (2)
54-54: Consider extracting os.environ['TRITON_F32_DEFAULT'] = 'ieee' into a shared fixture or conftest.py. This is repeated in all three test functions. A session- or module-scoped fixture would be cleaner and avoid global env mutation in each test body.

Also applies to: 91-91, 146-146
82-124: Backward test validates _linoss_recurrent_torch, not the Triton kernel backward. Since requires_grad_(True) is set, fused_recurrent_linoss takes the _linoss_recurrent_torch path (not FusedRecurrentLinOSSFunction.apply). This is fine since FusedRecurrentLinOSSFunction.backward itself delegates to _linoss_recurrent_torch, but it's worth noting that the Triton forward + torch backward round-trip isn't directly tested here.
fla/models/linoss/modeling_linoss.py (1)
270-283: Improve exception handling in generate. Two issues flagged by Ruff:

- Line 275: The re-raised AttributeError should chain the original exception with from exception (B904).
- Line 283: raise exception should be a bare raise to preserve the original traceback (TRY201).

Proposed fix:

```diff
             if 'past_key_values' in str(exception):
                 raise AttributeError(
                     f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                     f"which is not supported for {self.__class__.__name__}. "
                     f"Try another generation strategy instead. "
                     f"For the available generation strategies, check this doc: "
                     f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies",
-                )
+                ) from exception
             else:
-                raise exception
+                raise
```
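Both Ruff fixes can be shown in a self-contained sketch (a hypothetical helper, not the actual generate method):

```python
def lookup(cache: dict, key: str):
    try:
        return cache[key]
    except KeyError as exception:
        if key == "past_key_values":
            # B904: chain the replacement error to the original with `from`.
            raise AttributeError(f"unsupported key: {key!r}") from exception
        # TRY201: a bare `raise` re-raises with the original traceback intact.
        raise

try:
    lookup({}, "past_key_values")
except AttributeError as err:
    # The original KeyError is preserved as the cause for debugging.
    assert isinstance(err.__cause__, KeyError)
```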
fla/ops/linoss/fused_recurrent.py (2)
214-239: Backward of FusedRecurrentLinOSSFunction may be unreachable from the public API. fused_recurrent_linoss dispatches to FusedRecurrentLinOSSFunction.apply only when x.requires_grad is False (line 426). In that case, autograd will never invoke backward, so the backward implementation (which re-runs the PyTorch recurrence) is effectively dead code when called exclusively through the public API.

This isn't a bug; it's a safety net for anyone calling FusedRecurrentLinOSSFunction.apply directly. But it's worth noting for maintainability.
371-425: Large code duplication for output_final_state when x.requires_grad. When gradients are needed and output_final_state=True, the entire recurrence is re-run (lines 397-424) just to extract the final hidden state. This duplicates the logic already in _linoss_recurrent_torch and is also inefficient: Bu_re/Bu_im (lines 394-395) are computed with grad tracking enabled but only consumed inside a torch.no_grad() block, wasting memory on the autograd graph.

Consider either:

- Extending _linoss_recurrent_torch to optionally return final_state alongside outputs, or
- At minimum, moving the Bu_re/Bu_im einsums inside the no_grad block.

Option 2 (move the einsums inside no_grad):

```diff
         if output_final_state:
             B_t, T_t, H_t = x.shape
             P_t = a_diag.shape[0]
-            a = torch.relu(a_diag)
-            step = torch.sigmoid(dt)
-
-            if discretization == 'IMEX':
-                M11 = torch.ones_like(a)
-                M12 = -step * a
-                M21 = step.clone()
-                M22 = 1.0 - step * step * a
-            else:
-                schur = 1.0 / (1.0 + step * step * a)
-                M11 = 1.0 - step * step * a * schur
-                M12 = -step * a * schur
-                M21 = step * schur
-                M22 = schur
-
-            Bu_re = torch.einsum('bth,ph->btp', x, B_re)
-            Bu_im = torch.einsum('bth,ph->btp', x, B_im)
-
             with torch.no_grad():
+                a = torch.relu(a_diag)
+                step = torch.sigmoid(dt)
+
+                if discretization == 'IMEX':
+                    M11 = torch.ones_like(a)
+                    M12 = -step * a
+                    M21 = step.clone()
+                    M22 = 1.0 - step * step * a
+                else:
+                    schur = 1.0 / (1.0 + step * step * a)
+                    M11 = 1.0 - step * step * a * schur
+                    M12 = -step * a * schur
+                    M21 = step * schur
+                    M22 = schur
+
+                Bu_re = torch.einsum('bth,ph->btp', x, B_re)
+                Bu_im = torch.einsum('bth,ph->btp', x, B_im)
+
                 h1_re = x.new_zeros(B_t, P_t)
```
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/layers/__init__.py`:
- Line 68: Add a single trailing newline to the end of fla/layers/__init__.py so
the file ends with exactly one newline character; locate the file (contains the
closing list bracket "]") and ensure there is a newline after that bracket to
satisfy the end-of-file fixer and Ruff.
In `@fla/layers/linoss.py`:
- Around line 110-114: The code indexes past_key_values with self.layer_idx
which can be None; either validate layer_idx in the constructor or guard every
access. Fix by adding a check that self.layer_idx is not None (and is an int
within range) before using it in linoss.py: replace direct uses like
past_key_values[self.layer_idx] and past_key_values.update(...,
layer_idx=self.layer_idx, ...) with guarded logic (e.g., if self.layer_idx is
None: skip indexing/update or raise a clear ValueError in __init__); reference
symbols: self.layer_idx, past_key_values, and the methods where these appear to
ensure no NoneType indexing occurs.
- Around line 34-35: Update the __init__ signature of the LinOSSAttention class:
change the parameter annotation for layer_idx from "int = None" to "int | None =
None" to explicitly allow None, and change the return annotation from "->
LinOSSAttention" to "-> None" since __init__ must not declare a class return
type; update the signature in the LinOSSAttention.__init__ definition
accordingly.
- Around line 40-42: The constructor allows expand_ratio: float | None but uses
it directly; update the constructor in the Linoss class to guard against
expand_ratio being None by normalizing it to a float before use (e.g., if
expand_ratio is None: expand_ratio = 1.0), then compute self.input_dim =
int(hidden_size * expand_ratio) and set self.ssm_size = ssm_size if ssm_size is
not None else self.input_dim; reference symbols: expand_ratio, hidden_size,
self.input_dim, self.ssm_size.
- Around line 8-11: Remove the unused imports to satisfy the linter: delete the
"import torch.nn.functional as F" and the "from fla.modules.activations import
swiglu" lines in fla/layers/linoss.py while keeping the required imports
FusedRMSNormGated and ShortConvolution intact; if swiglu is actually needed
elsewhere, instead reference it where used (e.g., in any function that should
call swiglu) or add a usage comment to justify keeping the import.
- Around line 158-164: The offset is using the wrong axis: change the update in
past_key_values.update(...) so that offset is set from the sequence length
dimension of the input tensor i (use i.shape[1]) instead of the feature dim
(i.shape[2]); locate the call to past_key_values.update in this block
(references: past_key_values.update, recurrent_state, conv_state,
self.use_short_conv, self.layer_idx, offset) and replace the offset expression
with i.shape[1] to match other attention implementations.
In `@fla/models/linoss/configuration_linoss.py`:
- Around line 71-76: Fix the typo in the warning string inside the
fuse_linear_cross_entropy branch: update the message passed to warnings.warn
(refer to the fuse_linear_cross_entropy flag and the warnings.warn call in
configuration_linoss.py) so "can improves memory efficiency" becomes "can
improve memory efficiency" while preserving the rest of the warning text and
formatting.
In `@fla/ops/linoss/fused_recurrent.py`:
- Around line 80-86: The kernel currently only loads/stores real parts into the
initial/final state (see b_h1_re, b_h2_re and the USE_INITIAL_STATE branch
loading from h0) while b_h1_im and b_h2_im remain zeroed and never persisted, so
imaginary components are lost across segments; update the USE_INITIAL_STATE
loading to also load the imaginary parts from h0 (into b_h1_im and b_h2_im) and
ensure the code path that writes final_state writes both real and imaginary
parts in the same tensor layout as initial_state; update any state
packing/unpacking functions and the final-state emit logic so
initial_state/final_state contain both h1_re/h1_im and h2_re/h2_im in the agreed
order.
In `@tests/ops/test_linoss.py`:
- Around line 165-167: The file ends without a trailing newline causing the
end-of-file-fixer CI change; simply add a single newline character at the end of
the file (after the final assert_close call involving ref_ht and tri_ht) so the
file terminates with a newline and the pipeline stops modifying it.
```python
        layer_idx: int = None,
    ) -> LinOSSAttention:
```
Type annotation issues on __init__.

- layer_idx: int = None should be int | None = None (PEP 484 prohibits implicit Optional; also flagged by Ruff RUF013).
- The return type -> LinOSSAttention is incorrect for __init__; it should be -> None.
Proposed fix:

```diff
-        layer_idx: int = None,
-    ) -> LinOSSAttention:
+        layer_idx: int | None = None,
+    ) -> None:
```
🧰 Tools
🪛 Ruff (0.15.0)
[warning] 34-34: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
```python
        self.expand_ratio = expand_ratio
        self.input_dim = int(hidden_size * expand_ratio)
        self.ssm_size = ssm_size if ssm_size is not None else self.input_dim
```
expand_ratio typed as float | None but used without a None guard.

expand_ratio defaults to 1. but the type signature allows None. Line 41 does int(hidden_size * expand_ratio), which will raise a TypeError if None is passed explicitly. Either tighten the type to float or add a guard.
Proposed fix (guard):

```diff
         self.expand_ratio = expand_ratio
-        self.input_dim = int(hidden_size * expand_ratio)
+        self.input_dim = int(hidden_size * (expand_ratio if expand_ratio is not None else 1.0))
```
fla/layers/linoss.py (Outdated)

```python
        mode = self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]
```
Potential NoneType access if layer_idx is None.

If layer_idx is not set (it defaults to None), the expression past_key_values[self.layer_idx] at line 114 and past_key_values.update(..., layer_idx=self.layer_idx, ...) at line 162 will use None as an index. In practice the model builder always passes layer_idx, but the constructor allows None, creating a latent bug if someone instantiates the layer standalone.
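A minimal sketch of the suggested guard (a toy class, not the real LinOSSAttention):

```python
from __future__ import annotations

class ToyLayer:
    def __init__(self, layer_idx: int | None = None) -> None:
        self.layer_idx = layer_idx

    def forward(self, past_key_values: list | None = None):
        if past_key_values is not None:
            if self.layer_idx is None:
                # Fail loudly instead of indexing with None.
                raise ValueError(
                    "layer_idx must be set when passing past_key_values; "
                    "construct the layer with an explicit layer_idx."
                )
            if len(past_key_values) > self.layer_idx:
                return past_key_values[self.layer_idx]
        return None

layer = ToyLayer()               # standalone instantiation, layer_idx=None
assert layer.forward() is None   # fine without a cache
try:
    layer.forward(past_key_values=["state0"])
except ValueError:
    pass
else:
    raise AssertionError("expected a ValueError for layer_idx=None")
assert ToyLayer(layer_idx=0).forward(["state0"]) == "state0"
```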
```python
        if fuse_linear_cross_entropy:
            warnings.warn(
                "`fuse_linear_cross_entropy` is enabled, which can improves memory efficiency "
                "at the potential cost of reduced precision. "
                "If you observe issues like loss divergence, consider disabling this setting.",
            )
```
Typo in warning message.
Line 73: "can improves memory efficiency" → "can improve memory efficiency".
Proposed fix:

```diff
-                "`fuse_linear_cross_entropy` is enabled, which can improves memory efficiency "
+                "`fuse_linear_cross_entropy` is enabled, which can improve memory efficiency "
```
🧰 Tools
🪛 Ruff (0.15.0)
[warning] 72-72: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
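For reference, the B028 fix is a one-argument change; a minimal sketch (the helper name is illustrative):

```python
import warnings


def warn_fused_cross_entropy():
    # stacklevel=2 attributes the warning to *our* caller rather than to this
    # helper itself, which is exactly what Ruff's B028 rule asks for.
    warnings.warn(
        "`fuse_linear_cross_entropy` is enabled, which can improve memory efficiency "
        "at the potential cost of reduced precision.",
        stacklevel=2,
    )
```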
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/models/linoss/configuration_linoss.py` around lines 71 - 76, Fix the typo
in the warning string inside the fuse_linear_cross_entropy branch: update the
message passed to warnings.warn (refer to the fuse_linear_cross_entropy flag and
the warnings.warn call in configuration_linoss.py) so "can improves memory
efficiency" becomes "can improve memory efficiency" while preserving the rest of
the warning text and formatting.
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/ops/test_linoss.py (1)
54-54: Consider scoping `TRITON_F32_DEFAULT` at the module or fixture level. Setting `os.environ['TRITON_F32_DEFAULT']` inside each test function is repetitive and relies on the process-wide environment. A `@pytest.fixture(autouse=True)` or a module-level assignment would be cleaner and less error-prone.
Also applies to: 91-91, 146-146
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/test_linoss.py` at line 54, Multiple tests set os.environ['TRITON_F32_DEFAULT'] repetitively inside test functions; replace that with a single module-scoped solution by adding a pytest fixture with `@pytest.fixture`(autouse=True) that sets os.environ['TRITON_F32_DEFAULT'] = 'ieee' (and restores previous value on teardown) or by assigning it once at module import time, so remove the per-test assignments and centralize the environment setup used by the tests referencing TRITON_F32_DEFAULT.
fla/models/linoss/modeling_linoss.py (1)
270-283: Use `raise ... from` for chained exceptions. When re-raising as a different AttributeError, chain with `from exception` so the original traceback is preserved. Also, bare `raise exception` on line 283 should be just `raise`.
Proposed fix

```diff
-            raise AttributeError(
+            raise AttributeError(
                 f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                 f"which is not supported for {self.__class__.__name__}. "
                 f"Try another generation strategy instead. "
                 f"For the available generation strategies, check this doc: "
                 f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies",
-            )
+            ) from exception
         else:
-            raise exception
+            raise
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/models/linoss/modeling_linoss.py` around lines 270 - 283, In the generate method replace the unchained re-raise by chaining the new AttributeError to the original (use "raise AttributeError(...) from exception") so the original traceback is preserved when you raise the custom message for 'past_key_values', and change the final "raise exception" branch to a bare "raise" to re-raise the original exception; refer to the generate method and the AttributeError handling block in modeling_linoss.py to locate the changes.
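The behavioral difference can be seen in a self-contained snippet (the message text below is illustrative, not the model's actual wording):

```python
def reraise_with_context(original: BaseException) -> None:
    # `raise ... from original` records the original exception on
    # `__cause__`, so the full traceback chain survives for debugging.
    raise AttributeError("Try another generation strategy instead.") from original


try:
    try:
        raise AttributeError("object has no attribute 'past_key_values'")
    except AttributeError as exception:
        reraise_with_context(exception)
except AttributeError as chained:
    # The original error is preserved instead of being swallowed.
    assert "past_key_values" in str(chained.__cause__)
```

A bare `raise` inside an `except` block likewise re-raises the active exception with its original traceback, which is why it is preferred over `raise exception`.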
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/models/linoss/modeling_linoss.py`:
- Around line 320-339: The forward logic creates a local criterion when
self.criterion is None (using FusedLinearCrossEntropyLoss,
FusedCrossEntropyLoss, or nn.CrossEntropyLoss) but never caches it; update the
block where criterion is instantiated (the branch that sets criterion when
getattr(self, 'criterion', None) is None) to assign the new instance to
self.criterion (e.g., self.criterion = criterion) so subsequent calls reuse the
cached criterion instead of recreating it each forward pass; ensure subsequent
code uses self.criterion where appropriate.
---
Duplicate comments:
In `@fla/layers/linoss.py`:
- Around line 32-33: Update the LinOSSAttention.__init__ signature: change the
parameter annotation from "layer_idx: int = None" to "layer_idx: int | None =
None" to satisfy the optional-int typing (RUF013), and change the __init__
return annotation from "-> LinOSSAttention" to "-> None" so the constructor has
the correct return type; locate these in the LinOSSAttention class __init__
definition and update the type hints accordingly.
- Around line 156-162: The offset is incorrectly using the feature dim
(i.shape[2]) instead of sequence length; update the past_key_values.update call
so offset uses i.shape[1] when available (e.g. offset=i.shape[1] if len(i.shape)
> 1 else 1) while keeping the other fields (recurrent_state, conv_state when
self.use_short_conv, layer_idx) unchanged—locate the past_key_values.update
invocation in linoss.py and replace the offset expression accordingly.
- Around line 38-40: expand_ratio is annotated as float | None but used directly
in computing input_dim, which will raise TypeError if expand_ratio is None;
update the initialization in the LinOSS (or containing) class so you first check
if self.expand_ratio is None and set self.input_dim = hidden_size in that case,
otherwise compute self.input_dim = int(hidden_size * self.expand_ratio); then
compute self.ssm_size = ssm_size if ssm_size is not None else self.input_dim to
preserve the existing fallback behavior. Ensure you reference and update the
assignments to self.expand_ratio, self.input_dim, and self.ssm_size in
linoss.py.
- Around line 110-112: last_state indexing can raise if self.layer_idx is None;
update the guard around past_key_values to ensure self.layer_idx is an int and
within range before indexing (e.g., check isinstance(self.layer_idx, int) and 0
<= self.layer_idx < len(past_key_values)) so that last_state =
past_key_values[self.layer_idx] only runs when safe; modify the block
referencing past_key_values and self.layer_idx (the variables last_state,
past_key_values, and self.layer_idx) accordingly to handle None or invalid
layer_idx values.
In `@fla/models/linoss/configuration_linoss.py`:
- Line 73: Replace the typo in the user-facing string that mentions
`fuse_linear_cross_entropy`: change "can improves" to "can improve" wherever
that phrase appears (e.g., the message or docstring referencing
`fuse_linear_cross_entropy` in configuration_linoss.py) so the sentence reads
correctly.
---
Nitpick comments:
In `@fla/models/linoss/modeling_linoss.py`:
- Around line 270-283: In the generate method replace the unchained re-raise by
chaining the new AttributeError to the original (use "raise AttributeError(...)
from exception") so the original traceback is preserved when you raise the
custom message for 'past_key_values', and change the final "raise exception"
branch to a bare "raise" to re-raise the original exception; refer to the
generate method and the AttributeError handling block in modeling_linoss.py to
locate the changes.
In `@tests/ops/test_linoss.py`:
- Line 54: Multiple tests set os.environ['TRITON_F32_DEFAULT'] repetitively
inside test functions; replace that with a single module-scoped solution by
adding a pytest fixture with `@pytest.fixture`(autouse=True) that sets
os.environ['TRITON_F32_DEFAULT'] = 'ieee' (and restores previous value on
teardown) or by assigning it once at module import time, so remove the per-test
assignments and centralize the environment setup used by the tests referencing
TRITON_F32_DEFAULT.
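The suggested autouse fixture reduces to a plain generator — in `tests/ops/test_linoss.py` it would be registered with `@pytest.fixture(autouse=True)` (assumed usage; the restore-on-teardown logic is the part the per-test assignments were missing):

```python
import os


def triton_f32_default():
    # Set the env var once per test, restoring the previous value afterwards.
    # In the test module this generator would be decorated with
    # `@pytest.fixture(autouse=True)` so every test picks it up implicitly.
    previous = os.environ.get('TRITON_F32_DEFAULT')
    os.environ['TRITON_F32_DEFAULT'] = 'ieee'
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop('TRITON_F32_DEFAULT', None)
        else:
            os.environ['TRITON_F32_DEFAULT'] = previous
```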
```python
loss, logits = None, None
if not self.config.fuse_linear_cross_entropy or labels is None:
    logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
if labels is not None:
    if getattr(self, 'criterion', None) is None:
        if self.config.fuse_linear_cross_entropy:
            criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp)
        elif self.config.fuse_cross_entropy:
            criterion = FusedCrossEntropyLoss(inplace_backward=True)
        else:
            criterion = nn.CrossEntropyLoss()
    else:
        criterion = self.criterion
    labels = labels.to(hidden_states.device)
    labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
    if self.config.fuse_linear_cross_entropy:
        loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
    else:
        loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
    loss = l2_warp(loss, logits) if self.config.use_l2warp else loss
```
🧩 Analysis chain
🏁 Script executed:

```shell
cat -n fla/models/linoss/modeling_linoss.py | sed -n '315,345p'
```

Repository: fla-org/flash-linear-attention
Length of output: 1794

🏁 Script executed:

```shell
rg -n 'self\.criterion\s*=' fla/models/ -g '*.py' -A 2 | head -50
```

Repository: fla-org/flash-linear-attention
Length of output: 2885

🏁 Script executed:

```shell
rg -n 'criterion\s*=' fla/models/linoss/modeling_linoss.py -B 2 -A 2
```

Repository: fla-org/flash-linear-attention
Length of output: 978
Add self.criterion = criterion assignment after line 330 to cache the criterion.
The criterion is instantiated locally (lines 326–330) but never assigned to self.criterion, causing it to be re-created on every forward pass with labels. Line 248 initializes self.criterion = None, and line 324 checks for it, indicating the intent to cache. This pattern matches all other models in the codebase (gated_deltaproduct, transformer, rwkv7, etc.), which properly assign the criterion after creation.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/models/linoss/modeling_linoss.py` around lines 320 - 339, The forward
logic creates a local criterion when self.criterion is None (using
FusedLinearCrossEntropyLoss, FusedCrossEntropyLoss, or nn.CrossEntropyLoss) but
never caches it; update the block where criterion is instantiated (the branch
that sets criterion when getattr(self, 'criterion', None) is None) to assign the
new instance to self.criterion (e.g., self.criterion = criterion) so subsequent
calls reuse the cached criterion instead of recreating it each forward pass;
ensure subsequent code uses self.criterion where appropriate.
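The caching pattern the review asks for, reduced to a minimal sketch (the class and factory are stand-ins for the model head and the loss classes, not the repo's actual API):

```python
class CausalLMHeadSketch:
    def __init__(self, criterion_factory):
        self.criterion = None  # mirrors `self.criterion = None` in __init__
        self._factory = criterion_factory

    def get_criterion(self):
        # Build once, then reuse: `self.criterion = criterion` is the
        # assignment the LinOSS forward pass is missing.
        if self.criterion is None:
            self.criterion = self._factory()
        return self.criterion
```

With this, the (potentially expensive) loss-module construction happens on the first labeled forward pass only.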
Actionable comments posted: 1
🧹 Nitpick comments (2)
fla/ops/linoss/fused_recurrent.py (2)
88-98: Inner loop over `H` is sequential in the kernel — expected for this architecture but worth documenting. The `for h in range(H)` loops at lines 93 and 122 iterate element-by-element over the hidden dimension within each Triton program. For large `H` (e.g., 1024+), this may become a bottleneck since Triton unrolls these into sequential scalar operations. The P-dimension is parallelized across programs and the `atomic_add` at line 130 handles cross-block output accumulation. This is a known pattern in the fla codebase's fused recurrent kernels and acceptable for now, but a brief comment noting this design choice would help future contributors.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/linoss/fused_recurrent.py` around lines 88 - 98, Add a brief in-code comment above the inner "for h in range(H)" loop in fused_recurrent.py (the loop that loads b_u_h, B_re/B_im and accumulates b_bu_re/b_bu_im) stating that this inner H loop is intentionally sequential in the Triton kernel, that P-dimension is parallelized across programs and cross-block accumulation is handled via atomic_add, and that this design may be a bottleneck for very large H (e.g., 1024+) but is acceptable for this architecture; reference the variables/operations b_u_h, B_re/B_im loads, b_bu_re/b_bu_im accumulation and the atomic_add used for output accumulation so future contributors understand the tradeoff.
371-380: Gradient dispatch checks only `x.requires_grad`, which may skip autograd for parameter-only gradients. If `x` is detached (e.g., in certain eval-with-grad or mixed scenarios), the else-branch calls `fused_recurrent_linoss_fwd` directly, bypassing `FusedRecurrentLinOSSFunction.apply`. This means gradients for `B_re`, `C_re`, `a_diag`, `dt`, `d_skip`, etc. would not be computed even if those tensors require grad. In typical training this is fine (since `x` comes from an `nn.Linear` and carries grad), but a more robust check would be:

```diff
-    if x.requires_grad:
+    if torch.is_grad_enabled():
```

or check any input's `requires_grad`. This aligns with a more defensive dispatch pattern.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/linoss/fused_recurrent.py` around lines 371 - 380, The current dispatch only checks x.requires_grad before choosing between FusedRecurrentLinOSSFunction.apply and fused_recurrent_linoss_fwd, which misses cases where other tensors (B_re, B_im, C_re, C_im, a_diag, dt, d_skip, initial_state, etc.) require gradients; change the condition to test whether any of the relevant inputs require grad (e.g., any(t is not None and getattr(t, "requires_grad", False) for t in (x, B_re, B_im, C_re, C_im, a_diag, dt, d_skip, initial_state))) and call FusedRecurrentLinOSSFunction.apply when true, otherwise call fused_recurrent_linoss_fwd to ensure parameter-only gradients are preserved.
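The "any input requires grad" predicate suggested above can be written framework-agnostically; `requires_grad` is duck-typed here so the sketch runs without torch (the function name is illustrative):

```python
def needs_autograd(*tensors) -> bool:
    # Route through the autograd Function if ANY input carries
    # requires_grad, not just x. None entries (optional inputs such as
    # initial_state) are skipped.
    return any(
        t is not None and getattr(t, "requires_grad", False)
        for t in tensors
    )
```

The dispatch site would then read `if needs_autograd(x, B_re, B_im, C_re, C_im, a_diag, dt, d_skip, initial_state): ...`.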
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/ops/linoss/fused_recurrent.py`:
- Around line 214-239: The backward currently recomputes the forward via
_linoss_recurrent_torch under torch.enable_grad() and calls o.backward(do), but
this can produce incorrect gradients if the recomputed forward uses a different
dtype/accumulation than the Triton kernel and it entirely ignores the incoming
dht (gradient w.r.t. final state). Fix by ensuring the recomputation uses the
exact same dtype and accumulation settings as the Triton kernel (cast tensors
like x, B_re, B_im, C_re, C_im, a_diag, dt, d_skip to the original kernel's
dtype and set deterministic accumulation if needed) and by including the
final-state gradient: have the recomputed forward return both outputs and
final_state, then call backward with both gradients (o.backward((do, dht)) or
use torch.autograd.grad to request gradients w.r.t. inputs including
initial_state while passing dht for the final state), and return the resulting
grads for initial_state (and all other saved tensors). Reference symbols:
backward (staticmethod), ctx.saved_tensors, dht, _linoss_recurrent_torch,
initial_state.
---
Duplicate comments:
In `@fla/layers/linoss.py`:
- Around line 32-33: The __init__ signature for class LinOSSAttention is
incorrectly annotated: change the parameter annotation from "layer_idx: int =
None" to "layer_idx: int | None = None" (or use Optional[int]) and update the
__init__ return annotation from "-> LinOSSAttention" to "-> None"; locate and
edit the __init__ method of LinOSSAttention to apply these fixes.
- Around line 38-40: The constructor currently multiplies hidden_size by
expand_ratio without guarding against expand_ratio being None; change the logic
in the Linoss initializer so expand_ratio is normalized first (e.g.,
effective_expand = expand_ratio if expand_ratio is not None else 1.0) and then
compute self.input_dim = int(hidden_size * effective_expand); keep
self.expand_ratio set to the original value or the normalized value per your API
choice, and ensure self.ssm_size still defaults to self.input_dim when ssm_size
is None.
- Around line 110-112: The code assumes self.layer_idx is an int; guard against
None by checking it before indexing or updating past_key_values: change the
condition to check that self.layer_idx is not None and is an int within range
(e.g., if past_key_values is not None and self.layer_idx is not None and 0 <=
self.layer_idx < len(past_key_values): last_state =
past_key_values[self.layer_idx]) and similarly ensure any call that does
past_key_values.update(..., layer_idx=self.layer_idx) only passes layer_idx when
self.layer_idx is not None (or use a different update key), so references to
self.layer_idx (in methods like where last_state is set and where
past_key_values is updated) never attempt to index with None.
In `@fla/ops/linoss/fused_recurrent.py`:
- Around line 80-86: The kernel initializes b_h1_im and b_h2_im to zeros but
only loads/stores the real components when USE_INITIAL_STATE / final-state
handling occurs, so imaginary state is lost across segments; update the
initial-state load to also load the imaginary components into b_h1_im and
b_h2_im (use the appropriate h0 offsets for the imaginary parts, mirroring the
real-part loads) and likewise update the final-state store logic to write both
real and imaginary parts; apply the same fix to the torch fallback code path
where h1_im/h2_im are currently left zero so both real and imaginary components
are persisted across segments.
---
Nitpick comments:
In `@fla/ops/linoss/fused_recurrent.py`:
- Around line 88-98: Add a brief in-code comment above the inner "for h in
range(H)" loop in fused_recurrent.py (the loop that loads b_u_h, B_re/B_im and
accumulates b_bu_re/b_bu_im) stating that this inner H loop is intentionally
sequential in the Triton kernel, that P-dimension is parallelized across
programs and cross-block accumulation is handled via atomic_add, and that this
design may be a bottleneck for very large H (e.g., 1024+) but is acceptable for
this architecture; reference the variables/operations b_u_h, B_re/B_im loads,
b_bu_re/b_bu_im accumulation and the atomic_add used for output accumulation so
future contributors understand the tradeoff.
- Around line 371-380: The current dispatch only checks x.requires_grad before
choosing between FusedRecurrentLinOSSFunction.apply and
fused_recurrent_linoss_fwd, which misses cases where other tensors (B_re, B_im,
C_re, C_im, a_diag, dt, d_skip, initial_state, etc.) require gradients; change
the condition to test whether any of the relevant inputs require grad (e.g.,
any(t is not None and getattr(t, "requires_grad", False) for t in (x, B_re,
B_im, C_re, C_im, a_diag, dt, d_skip, initial_state))) and call
FusedRecurrentLinOSSFunction.apply when true, otherwise call
fused_recurrent_linoss_fwd to ensure parameter-only gradients are preserved.
Hi @Phoenix8215, thanks for this great contribution bringing LinOSS to FLA! 🎉 I noticed that the current implementation primarily focuses on the Would you or anyone from the community be interested in implementing these? I am happy to help review or collaborate! 🙏
Thanks for the suggestions! 🥰 Both features are now implemented. Let me know if further refinements would be helpful.
Actionable comments posted: 2
🧹 Nitpick comments (2)
fla/ops/linoss/fused_recurrent.py (1)
268-268: `H` is unpacked but never used — prefix with `_`. Ruff RUF059 flags this. `H` is not referenced in the body of `_linoss_recurrent_torch`; the einsum contractions are expressed with string subscripts, not via the variable.
♻️ Proposed fix

```diff
- Bat, T, H = x.shape
+ Bat, T, _ = x.shape
```
🤖 Prompt for AI Agents
- Bat, T, H = x.shape + Bat, T, _ = x.shape🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/linoss/fused_recurrent.py` at line 268, The tuple unpacking in _linoss_recurrent_torch currently does "Bat, T, H = x.shape" but H is never used; rename the third variable to a throwaway name (e.g., _ or _H) to satisfy RUF059. Update the unpacking in the _linoss_recurrent_torch function from "Bat, T, H = x.shape" to "Bat, T, _ = x.shape" (or "_H") and ensure no other references to H exist in that function.fla/ops/linoss/chunk.py (1)
66-66:His unpacked but never used — prefix with_.Same Ruff RUF059 pattern as in
_linoss_recurrent_torch:His assigned fromx.shapebut not referenced anywhere in the function body.♻️ Proposed fix
- Bat, T, H = x.shape + Bat, T, _ = x.shape🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/linoss/chunk.py` at line 66, The tuple unpacking assigns Bat, T, H = x.shape but H is never used; update the unpacking in the same function (the line with "Bat, T, H = x.shape" in fla.ops.linoss.chunk.py) to use a prefixed underscore for the unused dimension (e.g., Bat, T, _H or Bat, T, _) to satisfy the Ruff RUF059 rule and clearly mark the variable as intentionally unused; mirror the same change pattern used in _linoss_recurrent_torch if present.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/ops/linoss/fused_recurrent.py`:
- Around line 244-248: The backward is currently returning None for
initial_state which drops gradients for a learnable or non-detached
initial_state; update the return tuple in the backward of
_linoss_recurrent_torch (the block that currently returns x.grad, B_re.grad,
B_im.grad, C_re.grad, C_im.grad, a_diag.grad, dt.grad, d_skip.grad, None, None,
None) to return the gradient for initial_state in the correct slot (replace the
first of those trailing Nones with the gradient tensor computed for
initial_state or initial_state.grad), preserving the original argument order so
callers receive the initial_state gradient when it requires_grad. Ensure you
reference the saved/rewound initial_state used in the forward so you return its
grad consistently with other saved inputs.
In `@fla/ops/linoss/naive.py`:
- Line 95: The multiplication uses d_skip with x (which is cast to float32),
causing dtype mismatches in half/bf16 contexts; update the expression that
computes y_t to cast d_skip to float (e.g., d_skip.float() or
d_skip.to(x.dtype)) before multiplying by x[:, t] so the types match—modify the
line where y_t is computed (references: y_t, d_skip, x, h2, C_complex) similar
to how _linoss_recurrent_torch and chunk_linoss explicitly cast d_skip to float.
---
Duplicate comments:
In `@fla/layers/linoss.py`:
- Around line 32-33: Update the type annotations in the LinOSSAttention class
initializer: change the parameter annotation from "layer_idx: int = None" to
"layer_idx: int | None = None" to make the optionality explicit per PEP 484, and
correct the __init__ signature return annotation from "-> LinOSSAttention" to
"-> None". Adjust only the annotations on the __init__ method (the parameter
named layer_idx and the method return) so function name __init__ and class
LinOSSAttention remain unchanged.
- Around line 38-40: expand_ratio is allowed to be None but you immediately
multiply it in the input_dim computation (self.input_dim = int(hidden_size *
expand_ratio)), causing a TypeError; guard against None by normalizing
expand_ratio before use (e.g., set self.expand_ratio = 1.0 if expand_ratio is
None) and then compute self.input_dim and self.ssm_size using the normalized
value (ensure ssm_size fallback still uses the computed self.input_dim when
ssm_size is None).
- Around line 110-112: self.layer_idx may be None which causes a TypeError when
used in comparisons and passed into past_key_values; to fix, validate and
normalize layer_idx in the class __init__ (raise a ValueError if layer_idx is
None or not an int) so self.layer_idx is always a valid integer, and then the
code that reads last_state (the block using self.layer_idx to index
past_key_values) and the call to past_key_values.update(...,
layer_idx=self.layer_idx, ...) will no longer receive None; alternatively, if
you prefer not to raise, guard the usage site by checking "if self.layer_idx is
not None and len(past_key_values) > self.layer_idx:" and only pass layer_idx
into past_key_values.update when it is not None.
---
Nitpick comments:
In `@fla/ops/linoss/chunk.py`:
- Line 66: The tuple unpacking assigns Bat, T, H = x.shape but H is never used;
update the unpacking in the same function (the line with "Bat, T, H = x.shape"
in fla.ops.linoss.chunk.py) to use a prefixed underscore for the unused
dimension (e.g., Bat, T, _H or Bat, T, _) to satisfy the Ruff RUF059 rule and
clearly mark the variable as intentionally unused; mirror the same change
pattern used in _linoss_recurrent_torch if present.
In `@fla/ops/linoss/fused_recurrent.py`:
- Line 268: The tuple unpacking in _linoss_recurrent_torch currently does "Bat,
T, H = x.shape" but H is never used; rename the third variable to a throwaway
name (e.g., _ or _H) to satisfy RUF059. Update the unpacking in the
_linoss_recurrent_torch function from "Bat, T, H = x.shape" to "Bat, T, _ =
x.shape" (or "_H") and ensure no other references to H exist in that function.
```python
return (
    x.grad, B_re.grad, B_im.grad, C_re.grad, C_im.grad,
    a_diag.grad, dt.grad, d_skip.grad,
    None, None, None,
)
```
Gradient for initial_state is unconditionally None — silently wrong for learnable initial state or truncated BPTT.
initial_state is the only saved input that is neither detached nor given requires_grad_(True) before the re-run of _linoss_recurrent_torch (all other eight inputs are explicitly prepared at lines 226–233). Consequently, even if a caller passes initial_state with requires_grad=True (learnable h₀, or gradient passed back from a later segment), that gradient is silently discarded.
Standard inference and single-pass training (where initial_state is None or a detached cache tensor) are unaffected, but the issue would produce silent wrong gradients for truncated BPTT or learnable-h₀ experiments.
🛡️ Minimal fix to propagate `initial_state` gradient
```diff
 with torch.enable_grad():
     x = x.detach().requires_grad_(True)
     ...
     d_skip = d_skip.detach().requires_grad_(True)
+    if initial_state is not None:
+        initial_state = initial_state.detach().requires_grad_(True)
     o, ht = _linoss_recurrent_torch(...)
     ...
     torch.autograd.backward(targets, grads)
 return (
     x.grad, B_re.grad, B_im.grad, C_re.grad, C_im.grad,
     a_diag.grad, dt.grad, d_skip.grad,
-    None, None, None,
+    initial_state.grad if initial_state is not None else None, None, None,
 )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/linoss/fused_recurrent.py` around lines 244 - 248, The backward is
currently returning None for initial_state which drops gradients for a learnable
or non-detached initial_state; update the return tuple in the backward of
_linoss_recurrent_torch (the block that currently returns x.grad, B_re.grad,
B_im.grad, C_re.grad, C_im.grad, a_diag.grad, dt.grad, d_skip.grad, None, None,
None) to return the gradient for initial_state in the correct slot (replace the
first of those trailing Nones with the gradient tensor computed for
initial_state or initial_state.grad), preserving the original argument order so
callers receive the initial_state gradient when it requires_grad. Ensure you
reference the saved/rewound initial_state used in the forward so you return its
grad consistently with other saved inputs.
```python
h1 = h1_new
h2 = h2_new

y_t = (h2 @ C_complex.t()).real + d_skip.unsqueeze(0) * x[:, t]
```
d_skip missing .float() cast — dtype error in half-precision contexts.
x is cast to float32 at line 42, but d_skip is used raw. With model.half() or bf16 parameters, d_skip (fp16/bf16) multiplied against x[:, t] (fp32) triggers a PyTorch dtype-mismatch error. Both _linoss_recurrent_torch (line 330 of fused_recurrent.py) and chunk_linoss (line 205 of chunk.py) explicitly call .float() on d_skip.
🐛 Proposed fix
```diff
- y_t = (h2 @ C_complex.t()).real + d_skip.unsqueeze(0) * x[:, t]
+ y_t = (h2 @ C_complex.t()).real + d_skip.float().unsqueeze(0) * x[:, t]
```
📝 Committable suggestion
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/linoss/naive.py` at line 95, The multiplication uses d_skip with x
(which is cast to float32), causing dtype mismatches in half/bf16 contexts;
update the expression that computes y_t to cast d_skip to float (e.g.,
d_skip.float() or d_skip.to(x.dtype)) before multiplying by x[:, t] so the types
match—modify the line where y_t is computed (references: y_t, d_skip, x, h2,
C_complex) similar to how _linoss_recurrent_torch and chunk_linoss explicitly
cast d_skip to float.
Hi @zhiyuan1i, just checking whether you had a chance to review the recent updates. 😙 Happy to address any comments.
Hi @zhiyuan1i, @yzhangcs, could you please help review this PR when you have time?
Hi @Phoenix8215, thanks for the PR and congrats on the ICLR Oral! 🎉 I'm wondering if we could adjust the implementation approach: fused_recurrent for inference, chunk for training - Currently the fused version has PyTorch autograd backward which doesn't fit FLA's philosophy. Maybe we can:
Also, I might be misunderstanding this - does LinOSS not support TensorCore-parallel training? The recurrent nature makes it hard to parallelize across sequence dimension, right? Would love to hear your thoughts on the chunk-based training approach. Thanks!
Add support for LinOSS (Linear Ordinary State Space) model (ICLR 2025 Oral), a second-order state space model that discretizes a damped oscillator ODE using implicit midpoint (IM) or implicit-explicit (IMEX) methods.


Reference implementation: https://github.com/tk-rusch/linoss
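For intuition, the IMEX update can be sketched in scalar form — a hedged sketch of the damped-oscillator recurrence described above, using a single real state channel and illustrative parameter names (the production kernels operate on a complex diagonal state):

```python
def linoss_imex_scalar(u, a, b, c, d, dt):
    """One real state channel of the LinOSS-IMEX step:
        z_k = z_{k-1} + dt * (-a * x_{k-1} + b * u_k)   # explicit in x
        x_k = x_{k-1} + dt * z_k                        # implicit in z
        y_k = c * x_k + d * u_k                         # readout + skip
    """
    z, x = 0.0, 0.0
    ys = []
    for u_k in u:
        z = z + dt * (-a * x + b * u_k)
        x = x + dt * z
        ys.append(c * x + d * u_k)
    return ys
```

The IM variant instead solves the `z` update implicitly; for a diagonal `a` this reduces to an extra division by `1 + dt**2 * a` in the `z` step, up to the paper's exact parameterization.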
Summary by CodeRabbit
New Features
Tests