Description
In the NumPy backend, the scatter implementation currently uses a pure Python for loop to update values at target indices. This introduces a significant performance bottleneck when the number of updates is large (e.g., 10^6 or more).
By replacing the scalar loop with NumPy's native `np.add.at`, we can achieve a ~87x speedup, bringing the NumPy backend's efficiency closer to that of the other tensor backends for large-scale operations.
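To see why `np.add.at` is the right primitive here (and why plain fancy-index assignment would be wrong), a minimal sketch of the duplicate-index behavior:

```python
import numpy as np

a = np.zeros(3)
idx = np.array([0, 1, 1])   # index 1 appears twice
vals = np.array([1.0, 2.0, 3.0])

# Plain fancy-index assignment is buffered: for a repeated index,
# only the last write survives, so the 2.0 update is lost.
b = a.copy()
b[idx] += vals
print(b)  # [1. 3. 0.]

# np.add.at is unbuffered: duplicate indices accumulate, matching
# the `zeros[tuple(index)] += values[i]` semantics of the loop.
c = a.copy()
np.add.at(c, idx, vals)
print(c)  # [1. 5. 0.]
```

This is why the proposed fix uses `np.add.at` rather than the faster but incorrect `zeros[idx] += values`.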
Minimal Reproduction & Benchmark

```python
import os

os.environ["KERAS_BACKEND"] = "numpy"  # must be set before importing keras

import time
import numpy as np
from keras.src.backend.numpy.core import scatter as original_scatter

# Proposed fix
def optimized_scatter(indices, values, shape):
    indices = np.array(indices)
    values = np.array(values)
    zeros = np.zeros(shape, dtype=values.dtype)
    idx = tuple(indices.reshape(-1, indices.shape[-1]).T)
    np.add.at(zeros, idx, values.ravel())
    return zeros

shape = (1000, 1000)
num_updates = 1_000_000
indices = np.random.randint(0, 1000, size=(num_updates, 2))
values = np.random.rand(num_updates)

# Benchmark original
t0 = time.time()
original_scatter(indices, values, shape)
t1 = time.time()
print(f"Original Time: {t1 - t0:.4f}s")

# Benchmark optimized
t2 = time.time()
optimized_scatter(indices, values, shape)
t3 = time.time()
print(f"Optimized Time: {t3 - t2:.4f}s")
```

Observed Behavior
- Original Time: 3.8618s (Scalar Loop)
- Optimized Time: 0.0441s (Vectorized)
- Improvement: ~87.5x faster.
Root Cause
The current code in `keras/src/backend/numpy/core.py` iterates through each index manually:

```python
for i in range(indices.shape[0]):
    index = indices[i]
    zeros[tuple(index)] += values[i]
```

This bypasses NumPy's internal C-optimized loops: every iteration pays Python-level dispatch and indexing overhead. Vectorizing via `np.add.at` performs the accumulation natively in C and avoids the GIL-bound Python overhead.
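The vectorized form requires converting the `(N, k)` index array into the tuple-of-`k`-arrays form that NumPy's advanced indexing expects, which is what `tuple(indices.T)` does. A small illustration (shapes and values chosen only for demonstration):

```python
import numpy as np

# Three updates into a (2, 3) array: targets (0, 2), (1, 0), (1, 0).
indices = np.array([[0, 2], [1, 0], [1, 0]])
values = np.array([10.0, 20.0, 30.0])

# Transpose (N, k) -> (k, N), then split rows into a tuple:
# (array([0, 1, 1]), array([2, 0, 0])) -- one index array per axis.
idx = tuple(indices.T)

out = np.zeros((2, 3))
np.add.at(out, idx, values)
print(out)
# [[ 0.  0. 10.]
#  [50.  0.  0.]]   <- the two updates to (1, 0) accumulate
```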
Proposed Fix
I suggest replacing the manual loop with the following vectorized implementation in `keras/src/backend/numpy/core.py`:

```python
def scatter(indices, values, shape):
    indices = convert_to_tensor(indices)
    values = convert_to_tensor(values)
    zeros = np.zeros(shape, dtype=values.dtype)
    index_length = indices.shape[-1]
    indices = np.reshape(indices, [-1, index_length])
    values = values.ravel()
    idx = tuple(indices.T)
    np.add.at(zeros, idx, values)
    return zeros
```

System Information (Google Colab)
==============================
System Info
==============================
OS : Linux 6.6.105+
Python version : 3.12.12
Keras version : 3.10.0
Keras Backend : Numpy
==============================
GPU Info
==============================
GPU Model : Tesla T4
CUDA Available : Yes
CUDA Version : 12.8
==============================
Library Versions
==============================
jax 0.7.2
jax-cuda12-pjrt 0.7.2
jax-cuda12-plugin 0.7.2
jaxlib 0.7.2
numpy 2.0.2
tensorflow 2.19.0
tensorflow-datasets 4.9.9
tensorflow_decision_forests 1.12.0
tensorflow-hub 0.16.1
tensorflow-metadata 1.17.3
tensorflow-probability 0.25.0
tensorflow-text 2.19.0
torch 2.9.0+cu128
torchao 0.10.0
torchaudio 2.9.0+cu128
torchcodec 0.8.0+cu128
torchdata 0.11.0
torchsummary 1.5.1
torchtune 0.6.1
torchvision 0.24.0+cu128
Conclusion
This optimization significantly reduces latency for scatter-heavy operations in the NumPy backend. It preserves the existing semantics, including accumulation across duplicate indices, since `np.add.at` performs unbuffered in-place addition.
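As a quick equivalence check that could accompany the change, the sketch below compares the vectorized version against a reference loop on random inputs (standalone: `np.array` stands in for `convert_to_tensor`, and the function names are illustrative):

```python
import numpy as np

def loop_scatter(indices, values, shape):
    # Reference implementation mirroring the current scalar loop.
    zeros = np.zeros(shape, dtype=values.dtype)
    for i in range(indices.shape[0]):
        zeros[tuple(indices[i])] += values[i]
    return zeros

def vectorized_scatter(indices, values, shape):
    # Proposed vectorized implementation (np.array used here in
    # place of convert_to_tensor for a standalone script).
    indices = np.array(indices)
    values = np.array(values)
    zeros = np.zeros(shape, dtype=values.dtype)
    idx = tuple(np.reshape(indices, [-1, indices.shape[-1]]).T)
    np.add.at(zeros, idx, values.ravel())
    return zeros

rng = np.random.default_rng(0)
indices = rng.integers(0, 50, size=(10_000, 2))  # many duplicates
values = rng.random(10_000)
assert np.allclose(loop_scatter(indices, values, (50, 50)),
                   vectorized_scatter(indices, values, (50, 50)))
print("outputs match")
```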