
keras.ops.associative_scan causes infinite recursion in TensorFlow static graph mode #22058

@nexeora

I found that keras.ops.associative_scan appears to cause infinite recursion for any input in TensorFlow static graph mode (for example, under tf.function or inside model.fit).

Here is a reproducible example that demonstrates this:

import os
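os.environ["KERAS_BACKEND"] = "tensorflow"  # must be set before importing keras
import keras
import tensorflow as tf

# Wrap the op in tf.function so it is traced into a static graph.
ascan = tf.function(keras.ops.associative_scan)
add = lambda x, y: x + y
arr = keras.ops.array([1, 2, 3])
print(ascan(add, arr, axis=0))  # expected values: [1, 3, 6]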

This results in a RecursionError: maximum recursion depth exceeded.

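Here is another example, where the op is called inside a custom layer during model.fit (which runs the training step in graph mode by default), that also causes infinite recursion: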

import os
os.environ["KERAS_BACKEND"] = "tensorflow"  # must be set before importing keras
import keras

class ScanLayer(keras.layers.Layer):
    def call(self, inputs):
        # Cumulative sum along the feature axis.
        return keras.ops.associative_scan(lambda x, y: x + y, inputs, axis=1)

model = keras.models.Sequential([ScanLayer()])
model.compile(optimizer="adam", loss="mse")
# fit() runs the training step as a tf.function by default.
model.fit(
    x=keras.ops.array([[0, 1, 2]]),
    y=keras.ops.array([[3, 4, 5]]),
    epochs=1,
)

I believe the cause is that keras.src.backend.tensorflow.core.associative_scan is implemented recursively and sets no limit on recursion depth. In graph mode, TensorFlow builds a subgraph for each branch of tf.cond at trace time, so the recursive branch is traced again at every level and graph construction never terminates.
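
To make this hypothesis concrete, here is a toy model in plain Python. It is not TensorFlow and not the actual Keras code. `trace_cond` is a hypothetical stand-in for `tf.cond` in graph mode, assuming (as described above) that both branches are built whenever the predicate is a tensor; `trace_scan` is a hypothetical halving recursion of the kind an odd/even associative scan performs. With a Python-level base-case test the recursion stops after about log2(n) levels; with a tensor-valued test the recursive branch is always built again.

```python
# Toy model (pure Python, not TensorFlow): how tracing a recursive
# conditional can fail to terminate.


def trace_cond(pred, true_fn, false_fn):
    """Hypothetical model of building a conditional while tracing a graph.

    If `pred` is a Python bool, its value is known at trace time and only
    one branch is built. If `pred` is None (standing in for a tensor whose
    value is only known at run time), both branches must be built.
    Each branch returns the deepest recursion level it reached.
    """
    if pred is None:
        return max(true_fn(), false_fn())
    return true_fn() if pred else false_fn()


def trace_scan(length, tensor_predicate, depth=0):
    """Trace a halving recursion like an odd/even associative scan.

    `tensor_predicate=True` models a base-case test that reaches the
    conditional as a tensor; False models a plain Python comparison.
    Returns the deepest recursion level built.
    """
    at_base_case = None if tensor_predicate else length < 2
    return trace_cond(
        at_base_case,
        # Base case: nothing further to build.
        lambda: depth,
        # Recursive case: trace the scan of the pairwise-reduced input.
        lambda: trace_scan(length // 2, tensor_predicate, depth + 1),
    )
```

For example, `trace_scan(8, tensor_predicate=False)` returns 3, while `trace_scan(3, tensor_predicate=True)` raises `RecursionError`, mirroring the error in the reports above.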

A fix may require adjusting the API, for example by adding a maximum recursion depth as a parameter.
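
As a rough sketch of the proposed parameter, here is a pure-Python odd/even associative scan with a hypothetical `max_depth` argument. It is not the Keras API; the odd/even strategy is the one `jax.lax.associative_scan` uses, and whether the Keras TensorFlow backend follows it exactly is an assumption. Because the depth counter is an ordinary Python integer, a check like this would run while a graph is being traced, which is when the reported recursion happens.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def associative_scan(
    fn: Callable[[T, T], T],
    elems: Sequence[T],
    max_depth: int = 32,
    _depth: int = 0,
) -> List[T]:
    """Inclusive scan of `elems` under the associative operator `fn`.

    `max_depth` is a hypothetical guard on recursion depth. Each level
    halves the input, so a length-n input needs about log2(n) levels.
    """
    if _depth > max_depth:
        raise ValueError(f"associative_scan exceeded max_depth={max_depth}")
    n = len(elems)
    if n < 2:
        return list(elems)
    # Combine adjacent pairs: reduced[k] = fn(elems[2k], elems[2k + 1]).
    reduced = [fn(elems[i], elems[i + 1]) for i in range(0, n - 1, 2)]
    # scanned[k] is the prefix result ending at index 2k + 1.
    scanned = associative_scan(fn, reduced, max_depth, _depth + 1)
    # Placeholder list; result[0] = elems[0] is already correct.
    result: List[T] = [elems[0]] * n
    for k, value in enumerate(scanned):
        result[2 * k + 1] = value
    # Even positions extend the preceding odd prefix by one element.
    for i in range(2, n, 2):
        result[i] = fn(scanned[i // 2 - 1], elems[i])
    return result
```

For example, `associative_scan(lambda x, y: x + y, [1, 2, 3])` returns `[1, 3, 6]`, the values the first reproduction is expected to produce. Raising an error is only one possible design; another is to derive the number of levels up front when the scanned axis has a static length.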

One thing that confuses me, however, is that this function has existed for a long time. If my inference is correct, any attempt to build and train a model with this function would trigger infinite recursion, yet no one seems to have noticed.
