Description
When using the Torch backend, RMSNormalization fails if the axis
argument contains contiguous but unsorted axes (e.g. [-1, -2]).
Sorted axes (e.g. [-2, -1]) work correctly.
Unsorted contiguous axes should behave identically to sorted axes,
but currently result in a runtime shape mismatch error in the Torch fast-path.
Minimal Reproduction
import os
import numpy as np
os.environ["KERAS_BACKEND"] = "torch"
import keras
from keras import layers
print("Backend:", keras.backend.backend())
x = np.random.randn(2, 3, 4).astype("float32")
# Works
rms_sorted = layers.RMSNormalization(axis=[-2, -1])
rms_sorted(x)
print("Sorted axis works.")
# Crashes
rms_unsorted = layers.RMSNormalization(axis=[-1, -2])
rms_unsorted(x)
Observed Behavior
With Torch backend:
axis=[-2, -1] works correctly.
axis=[-1, -2] raises a runtime error:
RuntimeError: The size of tensor a (4) must match the size of tensor b (3)
at non-singleton dimension 2
Input shape:
(2, 3, 4)
Expected Behavior
Unsorted but contiguous axes (e.g. [-1, -2]) should behave
identically to sorted axes (e.g. [-2, -1]).
There should be no crash, and both configurations should produce
the same output.
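As a backend-independent reference for this expectation, here is a minimal NumPy sketch of an unscaled RMS normalization over the given axes. The helper name rms_normalize_reference and the epsilon default are assumptions made for illustration, not Keras internals, and the learned scale is left out.

```python
import numpy as np


def rms_normalize_reference(x, axis, epsilon=1e-6):
    """Hypothetical reference: divide x by its root mean square over `axis`.

    The learned scale is omitted and `epsilon` is an assumed default.
    """
    # np.mean reduces over a set of axes, so the order within `axis`
    # has no effect on the result.
    mean_square = np.mean(np.square(x), axis=tuple(axis), keepdims=True)
    return x / np.sqrt(mean_square + epsilon)


x = np.random.default_rng(0).standard_normal((2, 3, 4)).astype("float32")
out_sorted = rms_normalize_reference(x, axis=[-2, -1])
out_unsorted = rms_normalize_reference(x, axis=[-1, -2])

# Both calls reduce over the same two trailing dimensions.
print(np.allclose(out_sorted, out_unsorted))
```

Because the reduction runs over the same set of dimensions in both calls, the two outputs agree, which is the behavior expected from the layer as well.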
Root Cause
normalized_shape is constructed using the axis order as provided.
For axis=[-1, -2], this produces:
normalized_shape = (4, 3)
However, the logical trailing dimensions are (3, 4), which leads to a
weight shape mismatch in the Torch fast-path and a backend crash.
Sorting axis before constructing normalized_shape
resolves the issue.
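The shape mismatch described above can be illustrated without Keras or Torch. This is a sketch only: build_normalized_shape is a hypothetical helper that mimics the described behavior, not the actual Keras code, and NumPy's broadcasting check stands in for the Torch fast-path.

```python
import numpy as np


def build_normalized_shape(input_shape, axis, sort_axes):
    """Hypothetical helper: collect the dimensions named by `axis`."""
    ndim = len(input_shape)
    # Map negative axes to their non-negative equivalents.
    axes = [a % ndim for a in axis]
    if sort_axes:
        axes = sorted(axes)
    return tuple(input_shape[a] for a in axes)


input_shape = (2, 3, 4)
as_provided = build_normalized_shape(input_shape, [-1, -2], sort_axes=False)
after_sort = build_normalized_shape(input_shape, [-1, -2], sort_axes=True)
print(as_provided)  # (4, 3)
print(after_sort)   # (3, 4)

# A weight of shape (4, 3) cannot broadcast against a (2, 3, 4) input,
# while (3, 4) lines up with the trailing dimensions.
try:
    np.broadcast_shapes(input_shape, as_provided)
    unsorted_broadcasts = True
except ValueError:
    unsorted_broadcasts = False
print(unsorted_broadcasts)  # False
print(np.broadcast_shapes(input_shape, after_sort))  # (2, 3, 4)
```

Sorting the canonicalized axes first yields a shape that matches the input's trailing dimensions regardless of the order in which the user listed them.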
System Information (Google Colab)
==============================
System Info
==============================
OS : Linux 6.6.105+
Python version : 3.12.12
Keras version : 3.10.0
Keras Backend : torch
==============================
GPU Info
==============================
GPU Model : Tesla T4
CUDA Available : Yes
CUDA Version : 12.8
==============================
Library Versions
==============================
jax 0.7.2
jax-cuda12-pjrt 0.7.2
jax-cuda12-plugin 0.7.2
jaxlib 0.7.2
numpy 2.0.2
tensorflow 2.19.0
tensorflow-datasets 4.9.9
tensorflow_decision_forests 1.12.0
tensorflow-hub 0.16.1
tensorflow-metadata 1.17.3
tensorflow-probability 0.25.0
tensorflow-text 2.19.0
torch 2.9.0+cu128
torchao 0.10.0
torchaudio 2.9.0+cu128
torchcodec 0.8.0+cu128
torchdata 0.11.0
torchsummary 1.5.1
torchtune 0.6.1
torchvision 0.24.0+cu128
Note: The issue is reproducible when KERAS_BACKEND=torch is set.
Conclusion
This is a Torch backend fast-path bug in which unsorted but contiguous axes
lead to an incorrect normalized_shape, causing a weight shape
mismatch and a runtime failure.
The behavior is inconsistent with sorted axes and violates expected
axis-order invariance for contiguous dimensions.
Sorting axis before constructing normalized_shape
fully resolves the issue.