[Performance] Vectorize scatter operation in NumPy backend #22208

@amadhan882

Description

In the NumPy backend, the scatter implementation currently uses a pure Python for loop to update values at target indices. This introduces a significant performance bottleneck when the number of updates is large (e.g., 10^6 or more).

Replacing the per-update loop with a single np.add.at call gave a ~87x speedup in the benchmark below (10^6 updates into a 1000x1000 array), bringing the NumPy backend's efficiency closer to the other tensor backends for large-scale operations.

Minimal Reproduction & Benchmark

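import os

# KERAS_BACKEND must be set before Keras is imported, otherwise it has no effect.
os.environ["KERAS_BACKEND"] = "numpy"

import time
import numpy as np
from keras.src.backend.numpy.core import scatter as original_scatter

# Proposed fix
def optimized_scatter(indices, values, shape):
    indices = np.asarray(indices)
    values = np.asarray(values)
    zeros = np.zeros(shape, dtype=values.dtype)
    index_length = indices.shape[-1]
    # Flatten to (num_updates, index_length), then split into one index array per axis.
    idx = tuple(indices.reshape(-1, index_length).T)
    # Keep trailing dims so slice updates (index_length < len(shape)) still work.
    values = values.reshape((-1,) + tuple(shape[index_length:]))
    # Unbuffered add: duplicate indices accumulate, exactly as in the loop.
    np.add.at(zeros, idx, values)
    return zeros

shape = (1000, 1000)
num_updates = 1_000_000
indices = np.random.randint(0, 1000, size=(num_updates, 2))
values = np.random.rand(num_updates)

# Benchmark original
t0 = time.perf_counter()
original_scatter(indices, values, shape)
t1 = time.perf_counter()
print(f"Original Time: {t1 - t0:.4f}s")

# Benchmark optimized
t2 = time.perf_counter()
optimized_scatter(indices, values, shape)
t3 = time.perf_counter()
print(f"Optimized Time: {t3 - t2:.4f}s")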

Observed Behavior

  • Original Time: 3.8618s (Scalar Loop)
  • Optimized Time: 0.0441s (Vectorized)
  • Improvement: ~87.5x faster.
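The numbers above come from a single timed call per variant, which is sensitive to warm-up and background load. A minimal sketch of a steadier measurement takes the best of several runs with timeit.repeat. It uses plain NumPy only; loop_scatter, add_at_scatter, and best_of are hypothetical helper names, the loop mirrors the current backend only for full-rank indices, and the update count is reduced to 10^5 to keep the loop run short:

```python
import timeit
import numpy as np

def loop_scatter(indices, values, shape):
    """Hypothetical reference: one Python-level update per index (full-rank indices only)."""
    out = np.zeros(shape, dtype=values.dtype)
    for i in range(indices.shape[0]):
        out[tuple(indices[i])] += values[i]
    return out

def add_at_scatter(indices, values, shape):
    """Hypothetical vectorized variant: a single unbuffered np.add.at call."""
    out = np.zeros(shape, dtype=values.dtype)
    np.add.at(out, tuple(indices.T), values)
    return out

def best_of(fn, *args, repeat=3):
    """Minimum wall-clock time over `repeat` single calls of fn(*args)."""
    return min(timeit.repeat(lambda: fn(*args), number=1, repeat=repeat))

rng = np.random.default_rng(0)
shape = (1000, 1000)
n = 100_000  # smaller than the issue's 10**6 so the loop finishes quickly
indices = rng.integers(0, 1000, size=(n, 2))
values = rng.random(n)

t_loop = best_of(loop_scatter, indices, values, shape)
t_vec = best_of(add_at_scatter, indices, values, shape)
print(f"loop: {t_loop:.4f}s  add.at: {t_vec:.4f}s  speedup: {t_loop / t_vec:.1f}x")
```

Taking the minimum rather than the mean filters out runs slowed by unrelated system activity.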

Root Cause

The current code in keras/src/backend/numpy/core.py iterates through each index manually:

for i in range(indices.shape[0]):
    index = indices[i]
    zeros[tuple(index)] += values[i]

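Each iteration performs a separate Python-level indexing operation, so the runtime is dominated by interpreter overhead rather than NumPy's compiled loops. np.add.at performs the whole accumulation in a single call implemented in C, and because it is unbuffered, repeated indices are summed rather than overwritten, matching the loop's behavior.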
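Unbuffered accumulation is what makes np.add.at a drop-in replacement for the loop; the more obvious zeros[idx] += values is not, because buffered fancy-index assignment writes each target cell only once. This matters for the benchmark, where 10^6 random indices into 10^6 cells produce many duplicates. A minimal illustration on toy data (not from the issue):

```python
import numpy as np

# Three updates; two of them target the same cell (0, 1).
indices = np.array([[0, 1], [0, 1], [1, 0]])
values = np.array([1.0, 2.0, 3.0])
idx = tuple(indices.T)  # one index array per axis: ([0, 0, 1], [1, 1, 0])

# Buffered fancy-index update: cell (0, 1) is written once, so one of the
# duplicate updates is lost (it ends up as 1.0 or 2.0, never 3.0).
buffered = np.zeros((2, 2))
buffered[idx] += values

# Unbuffered ufunc.at: every update is applied, so duplicates accumulate.
unbuffered = np.zeros((2, 2))
np.add.at(unbuffered, idx, values)

print(buffered)
print(unbuffered)
```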

Proposed Fix

I suggest replacing the manual loop with the following vectorized implementation in keras/src/backend/numpy/core.py:

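def scatter(indices, values, shape):
    indices = convert_to_tensor(indices)
    values = convert_to_tensor(values)
    zeros = np.zeros(shape, dtype=values.dtype)

    index_length = indices.shape[-1]
    value_shape = shape[index_length:]
    indices = np.reshape(indices, [-1, index_length])
    # Keep the trailing dims: when index_length < len(shape), each update is a slice.
    values = np.reshape(values, [-1] + list(value_shape))

    # np.add.at is unbuffered, so duplicate indices accumulate as in the loop.
    np.add.at(zeros, tuple(indices.T), values)
    return zeros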
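One way to confirm that the vectorized path matches the loop is to compare both on small random inputs covering duplicate indices, slice updates (index_length < len(shape)), and a batched index tensor. A sketch using plain NumPy; scatter_loop and scatter_vectorized are hypothetical stand-ins for the current and proposed backend functions, with np.asarray in place of convert_to_tensor:

```python
import numpy as np

def scatter_loop(indices, values, shape):
    """Hypothetical reference: per-update loop, mirroring the current backend logic."""
    indices = np.asarray(indices)
    values = np.asarray(values)
    zeros = np.zeros(shape, dtype=values.dtype)
    index_length = indices.shape[-1]
    value_shape = tuple(shape[index_length:])
    indices = indices.reshape(-1, index_length)
    values = values.reshape((-1,) + value_shape)
    for i in range(indices.shape[0]):
        zeros[tuple(indices[i])] += values[i]
    return zeros

def scatter_vectorized(indices, values, shape):
    """Hypothetical vectorized variant: same reshapes, one np.add.at call."""
    indices = np.asarray(indices)
    values = np.asarray(values)
    zeros = np.zeros(shape, dtype=values.dtype)
    index_length = indices.shape[-1]
    value_shape = tuple(shape[index_length:])
    indices = indices.reshape(-1, index_length)
    values = values.reshape((-1,) + value_shape)
    np.add.at(zeros, tuple(indices.T), values)
    return zeros

rng = np.random.default_rng(0)
cases = [
    # Element updates into a 5x5 array: 200 updates, so many duplicates.
    (rng.integers(0, 5, size=(200, 2)), rng.random(200), (5, 5)),
    # Row (slice) updates: index_length=1 < ndim, each value is a row of 3.
    (rng.integers(0, 4, size=(50, 1)), rng.random((50, 3)), (4, 3)),
    # Batched index tensor of shape (2, 10, 2).
    (rng.integers(0, 3, size=(2, 10, 2)), rng.random((2, 10)), (3, 3)),
]
for indices, values, shape in cases:
    assert np.allclose(scatter_loop(indices, values, shape),
                       scatter_vectorized(indices, values, shape))
print("all cases match")
```

Keeping the trailing value_shape dimensions is what lets the row-update case pass; flattening values with ravel() would make np.add.at raise a shape mismatch there.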

System Information (Google Colab)

==============================
      System Info
==============================
OS              : Linux 6.6.105+
Python version  : 3.12.12
Keras version   : 3.10.0
Keras Backend   : Numpy

==============================
      GPU Info
==============================
GPU Model       : Tesla T4
CUDA Available  : Yes
CUDA Version    : 12.8

==============================
    Library Versions
==============================
jax                                      0.7.2
jax-cuda12-pjrt                          0.7.2
jax-cuda12-plugin                        0.7.2
jaxlib                                   0.7.2
numpy                                    2.0.2
tensorflow                               2.19.0
tensorflow-datasets                      4.9.9
tensorflow_decision_forests              1.12.0
tensorflow-hub                           0.16.1
tensorflow-metadata                      1.17.3
tensorflow-probability                   0.25.0
tensorflow-text                          2.19.0
torch                                    2.9.0+cu128
torchao                                  0.10.0
torchaudio                               2.9.0+cu128
torchcodec                               0.8.0+cu128
torchdata                                0.11.0
torchsummary                             1.5.1
torchtune                                0.6.1
torchvision                              0.24.0+cu128

Conclusion

This optimization significantly reduces latency for scatter-heavy operations in the NumPy backend. It maintains full compatibility with existing logic, including handling duplicate indices via cumulative addition.
