Description
I found that keras.ops.associative_scan appears to recurse infinitely on any input when it is traced in TensorFlow static graph mode.
Here is a reproducible example that demonstrates this:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import tensorflow as tf
ascan = tf.function(keras.ops.associative_scan)
add = lambda x, y: x + y
arr = keras.ops.array([1, 2, 3])
print(ascan(add, arr, axis=0))

This results in a RecursionError: maximum recursion depth exceeded.
Here is another example which also causes infinite recursion:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
class layer(keras.layers.Layer):
    def call(self, inputs):
        return keras.ops.associative_scan(lambda x, y: x + y, inputs, axis=1)
model = keras.models.Sequential([layer()])
model.compile(optimizer="adam", loss="mse")
model.fit(
    x=keras.ops.array([[0, 1, 2]]),
    y=keras.ops.array([[3, 4, 5]]),
    epochs=1,
)

I believe the cause is that the keras.src.backend.tensorflow.core.associative_scan implementation is recursive and sets no maximum recursion depth, while TensorFlow constructs a subgraph for each branch of tf.cond. Because both branches are traced during graph construction, the recursive branch is always entered and the recursion never terminates.
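To make the suspected mechanism concrete, here is a toy model in plain Python rather than TensorFlow or Keras code. The names eager_cond, traced_cond, and count_halvings are hypothetical and exist only for this illustration. traced_cond imitates the one relevant property of tf.cond during graph construction, namely that both branch functions are called regardless of the predicate.

```python
def eager_cond(pred, true_fn, false_fn):
    # Eager-style conditional: only the selected branch is executed.
    return true_fn() if pred else false_fn()


def traced_cond(pred, true_fn, false_fn):
    # Toy stand-in for graph tracing (an assumption of this sketch):
    # both branches are built before one result is selected.
    true_result = true_fn()
    false_result = false_fn()
    return true_result if pred else false_result


def count_halvings(n, cond):
    # Recursive function whose base case is guarded only by `cond`.
    return cond(
        n < 2,
        lambda: 0,
        lambda: 1 + count_halvings(n // 2, cond),
    )


print(count_halvings(8, eager_cond))  # terminates: the base case stops the recursion

try:
    count_halvings(8, traced_cond)
except RecursionError as exc:
    # The recursive branch is built at every level, so the base case never helps.
    print("RecursionError:", exc)
```

If the real implementation guards its base case in a similar way, this would explain why the error appears for every input and not only for large ones.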
The fix may require adjusting the API and adding a maximum recursion depth limit as a parameter.
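As a rough sketch of what a depth limit could look like, here is the same kind of toy model in plain Python. It is not a patch for Keras: traced_cond, count_halvings_bounded, and the max_depth parameter are all hypothetical names. The key assumption is that the depth check is an ordinary Python if, evaluated while tracing, so it stops the recursion even though both branches of the conditional are still built.

```python
def traced_cond(pred, true_fn, false_fn):
    # Toy stand-in for graph tracing (an assumption of this sketch):
    # both branches are built before one result is selected.
    true_result = true_fn()
    false_result = false_fn()
    return true_result if pred else false_result


def count_halvings_bounded(n, max_depth, depth=0):
    if depth >= max_depth:
        # Plain Python check, not a traced conditional, so recursion stops here.
        return 0
    return traced_cond(
        n < 2,
        lambda: 0,
        lambda: 1 + count_halvings_bounded(n // 2, max_depth, depth + 1),
    )


print(count_halvings_bounded(8, max_depth=4))  # 3
print(count_halvings_bounded(8, max_depth=2))  # 2, truncated because the limit is too small
```

The second call shows the trade-off: a limit that is too small silently changes the result, so a real fix would probably need to derive the bound from the static input shape or reject inputs that exceed it.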
However, one thing confuses me: this function has existed for a long time. If my inference is correct, any attempt to build and train a model with it would trigger infinite recursion, yet no one seems to have reported it.