[Bug] Torch backend: RMSNormalization crashes with unsorted contiguous axes #22201

@amadhan882

Description

When using the Torch backend, RMSNormalization fails if the axis
argument contains contiguous but unsorted axes (e.g. [-1, -2]).

Sorted axes (e.g. [-2, -1]) work correctly.

Unsorted contiguous axes should behave identically to sorted axes,
but currently result in a runtime shape mismatch error in the Torch fast-path.


Minimal Reproduction

import os
import numpy as np

os.environ["KERAS_BACKEND"] = "torch"

import keras
from keras import layers

print("Backend:", keras.backend.backend())

x = np.random.randn(2, 3, 4).astype("float32")

# Works
rms_sorted = layers.RMSNormalization(axis=[-2, -1])
rms_sorted(x)
print("Sorted axis works.")

# Crashes
rms_unsorted = layers.RMSNormalization(axis=[-1, -2])
rms_unsorted(x)

Observed Behavior

With Torch backend:

  • axis=[-2, -1] works correctly.
  • axis=[-1, -2] raises a runtime error:
RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 2

Input shape:

(2, 3, 4)

Expected Behavior

Unsorted but contiguous axes (e.g. [-1, -2]) should behave
identically to sorted axes (e.g. [-2, -1]).

There should be no crash, and both configurations should produce
the same output.
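The expected invariance can be checked against a backend-independent reference. The following is a minimal NumPy sketch of RMS normalization, not the Keras implementation: the helper name `rms_normalize` and the `epsilon` value are assumptions made for illustration, and the learnable scale is omitted. Because the mean of squares is taken over a set of axes, the order in which the axes are listed cannot change the result.

```python
import numpy as np


def rms_normalize(x, axis, epsilon=1e-6):
    """Reference RMS normalization over `axis` (hypothetical helper).

    epsilon is an assumed value for this sketch; no learnable scale is applied.
    """
    axes = tuple(axis)
    # Mean of squares over the requested axes; axis order is irrelevant here.
    mean_square = np.mean(np.square(x), axis=axes, keepdims=True)
    return x / np.sqrt(mean_square + epsilon)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4)).astype("float32")

out_sorted = rms_normalize(x, [-2, -1])
out_unsorted = rms_normalize(x, [-1, -2])

# Both axis orders should agree on the same input.
print(np.allclose(out_sorted, out_unsorted))
```

A fixed Torch fast-path would be expected to agree with this reference for both axis orders, up to floating-point tolerance.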


Root Cause

normalized_shape is constructed using the axis order as provided.

For axis=[-1, -2], this produces:

normalized_shape = (4, 3)

However, the trailing dimensions of the input are (3, 4), so the weight
shape does not match in the Torch fast-path and the backend crashes.

Sorting axis before constructing normalized_shape
resolves the issue.
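The shape arithmetic described above can be sketched without Keras or Torch. This is an illustration under stated assumptions, not the Keras source: the function `normalized_shape` and its `sort_axes` flag are hypothetical names, and NumPy broadcasting stands in for the Torch fast-path because it aligns trailing dimensions in the same way.

```python
import numpy as np


def normalized_shape(input_shape, axis, sort_axes):
    """Build the shape of the normalized dimensions (hypothetical helper)."""
    ndim = len(input_shape)
    # Map negative axes to their non-negative equivalents.
    axes = [a % ndim for a in axis]
    if sort_axes:
        # Proposed fix: sort so the shape follows the input's dimension order.
        axes = sorted(axes)
    return tuple(input_shape[a] for a in axes)


input_shape = (2, 3, 4)

# Axis order as provided: shape follows [-1, -2], i.e. dims 2 then 1.
unsorted_shape = normalized_shape(input_shape, [-1, -2], sort_axes=False)
# Axes sorted first: shape follows dims 1 then 2.
sorted_shape = normalized_shape(input_shape, [-1, -2], sort_axes=True)
print(unsorted_shape, sorted_shape)

# A weight of the unsorted shape cannot broadcast against the input.
try:
    np.broadcast_shapes(input_shape, unsorted_shape)
    unsorted_broadcasts = True
except ValueError:
    unsorted_broadcasts = False

# A weight of the sorted shape broadcasts cleanly.
broadcast_result = np.broadcast_shapes(input_shape, sorted_shape)
print(unsorted_broadcasts, broadcast_result)
```

With sorting applied, `axis=[-1, -2]` and `axis=[-2, -1]` yield the same shape, which is the condition the fast-path needs.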


System Information (Google Colab)

==============================
      System Info
==============================
OS              : Linux 6.6.105+
Python version  : 3.12.12
Keras version   : 3.10.0
Keras Backend   : torch

==============================
      GPU Info
==============================
GPU Model       : Tesla T4
CUDA Available  : Yes
CUDA Version    : 12.8

==============================
    Library Versions
==============================
jax                                      0.7.2
jax-cuda12-pjrt                          0.7.2
jax-cuda12-plugin                        0.7.2
jaxlib                                   0.7.2
numpy                                    2.0.2
tensorflow                               2.19.0
tensorflow-datasets                      4.9.9
tensorflow_decision_forests              1.12.0
tensorflow-hub                           0.16.1
tensorflow-metadata                      1.17.3
tensorflow-probability                   0.25.0
tensorflow-text                          2.19.0
torch                                    2.9.0+cu128
torchao                                  0.10.0
torchaudio                               2.9.0+cu128
torchcodec                               0.8.0+cu128
torchdata                                0.11.0
torchsummary                             1.5.1
torchtune                                0.6.1
torchvision                              0.24.0+cu128

Note: The issue is reproducible when KERAS_BACKEND=torch is set.


Conclusion

This is a Torch backend fast-path bug: unsorted but contiguous axes
produce an incorrectly ordered normalized_shape, which causes a weight
shape mismatch and a runtime failure.

The behavior is inconsistent with sorted axes and violates expected
axis-order invariance for contiguous dimensions.

Sorting axis before constructing normalized_shape
fully resolves the issue.
