Passing a plain callable to add_weight(constraint=...) raises ValueError in Keras 3 (torch backend) #22221

@griffinstalha

Bug description

When using Keras 3 with the PyTorch backend, passing a plain Python callable to the constraint argument of Layer.add_weight() raises a ValueError, stating that the constraint must be an instance of keras.constraints.Constraint.

Historically (and in many examples), constraints are simple callables of the form f(w) -> w'. However, with Keras 3.12.1 (torch backend), plain callables are rejected, even though a proper subclass of keras.constraints.Constraint works correctly.

This appears to be one of the following:

  • A regression in callable constraint support,
  • A backend-specific inconsistency (if other backends accept callables), or
  • A documentation mismatch, if callables are no longer supported but examples imply otherwise.

A control test confirms that the constraint mechanism itself works when using a proper Constraint subclass.

Reproduction script

import os
os.environ["KERAS_BACKEND"] = "torch"

import keras
from keras import ops
import torch

def masked_constraint(mask):
    # Plain callable constraint
    def _apply(w):
        return w * mask
    return _apply

class CustomLayer(keras.layers.Layer):
    def build(self, input_shape):
        mask = torch.ones((2, 2), dtype=torch.float32)
        c = masked_constraint(mask)

        # This line raises ValueError
        self.w = self.add_weight(
            name="w",
            shape=(2, 2),
            initializer="ones",
            trainable=True,
            constraint=c,  # <-- plain callable
        )

    def call(self, inputs):
        return inputs

x = ops.ones((1, 2), dtype="float32")
layer = CustomLayer()
_ = layer(x)

Commands
KERAS_BACKEND=torch python keras_testcase.py

Actual behavior

The script raises a ValueError:
Invalid value for attribute constraint. Expected an instance of keras.constraints.Constraint, or None. Received: constraint=<function ...>

Traceback (full)

`ValueError`: Invalid value for attribute `constraint`. Expected an instance of `keras.constraints.Constraint`, or `None`. 
Received: constraint=<function main.<locals>.masked_constraint.<locals>._apply at 0x75e5ccdf13f0>

Expected behavior

One of the following:

  • Plain callables should be accepted as valid constraints (if that is the intended behavior), or
  • The documentation should clearly state that only keras.constraints.Constraint subclasses are supported and plain callables are not allowed.

If this is an intentional API restriction in Keras 3, clarification in docs would help prevent confusion.

Environment

- Python: 3.10.19
- Keras: 3.12.1
- PyTorch: 2.10.0+cu128
- CUDA available: True
- Backend: torch

GPU / Driver

- GPU: NVIDIA RTX 3090 (x4 system)
- CUDA Runtime: 12.8 (from PyTorch build)
- torch.cuda.is_available(): True
