Description
Bug description
When using Keras 3 with the PyTorch backend, passing a plain Python callable to the constraint argument of Layer.add_weight() raises a ValueError, stating that the constraint must be an instance of keras.constraints.Constraint.
Historically (and in many examples), constraints are simple callables of the form f(w) -> w'. However, with Keras 3.12.1 (torch backend), plain callables are rejected, even though a proper subclass of keras.constraints.Constraint works correctly.
This appears to be one of the following:
- A regression in callable constraint support,
- A backend-specific inconsistency (if other backends accept callables), or
- A documentation mismatch, if callables are no longer supported but examples imply otherwise.
A control test confirms that the constraint mechanism itself works when using a proper Constraint subclass.
Reproduction script
import os
os.environ["KERAS_BACKEND"] = "torch"

import keras
from keras import ops
import torch


def masked_constraint(mask):
    # Plain callable constraint
    def _apply(w):
        return w * mask
    return _apply


class CustomLayer(keras.layers.Layer):
    def build(self, input_shape):
        mask = torch.ones((2, 2), dtype=torch.float32)
        c = masked_constraint(mask)
        # This line raises ValueError
        self.w = self.add_weight(
            name="w",
            shape=(2, 2),
            initializer="ones",
            trainable=True,
            constraint=c,  # <-- plain callable
        )

    def call(self, inputs):
        return inputs


x = ops.ones((1, 2), dtype="float32")
layer = CustomLayer()
_ = layer(x)
Commands
KERAS_BACKEND=torch python keras_testcase.py
Actual behavior
The script raises a ValueError:
Invalid value for attribute constraint. Expected an instance of keras.constraints.Constraint, or None. Received: constraint=<function ...>
Traceback (full)
`ValueError`: Invalid value for attribute `constraint`. Expected an instance of `keras.constraints.Constraint`, or `None`.
Received: constraint=<function main.<locals>.masked_constraint.<locals>._apply at 0x75e5ccdf13f0>
Expected behavior
One of the following:
- Plain callables should be accepted as valid constraints (if that is the intended behavior), OR
- The documentation should clearly state that only keras.constraints.Constraint subclasses are supported and plain callables are not allowed.
If this is an intentional API restriction in Keras 3, clarification in docs would help prevent confusion.
Environment
- Python: 3.10.19
- Keras: 3.12.1
- PyTorch: 2.10.0+cu128
- CUDA available: True
- Backend: torch
GPU / Driver
- GPU: NVIDIA RTX 3090 (x4 system)
- CUDA Runtime: 12.8 (from PyTorch build)
- torch.cuda.is_available(): True