Fix Discretization layer graph mode bug #21514


Open · wants to merge 14 commits into master
7 changes: 5 additions & 2 deletions keras/src/layers/preprocessing/discretization.py
@@ -96,7 +96,7 @@ def __init__(
         name=None,
     ):
         if dtype is None:
-            dtype = "int64" if output_mode == "int" else backend.floatx()
+            dtype = "float32"
Collaborator:
I think you should just remove the whole if dtype is None block, the base layer class will handle None properly.

Collaborator Author:

I tried that! It results in the unit tests failing:

FAILED keras/src/layers/preprocessing/discretization_test.py::DiscretizationTest::test_discretization_basics - AssertionError: expected output dtype int64, got int32:
- int32
+ int64

Collaborator:

For what it's worth, it fails the same way on JAX with or without the change.

The thing is, this class has this:

    @property
    def input_dtype(self):
        return backend.floatx()

So I don't understand why the inputs were cast to ints.

Maybe you can try to override @property ... compute_dtype?
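A minimal, hypothetical sketch of that suggestion — the stub `BaseLayer` below is NOT the real Keras base class (which resolves dtypes through a dtype policy); it only illustrates the pattern of overriding `compute_dtype` as a read-only property:

```python
# Illustrative stubs only: BaseLayer stands in for Keras's real base
# layer. The point is overriding compute_dtype as a property so the
# layer always computes in float regardless of its stored dtype.

class BaseLayer:
    def __init__(self, dtype=None):
        self._dtype = dtype or "float32"

    @property
    def compute_dtype(self):
        return self._dtype


class DiscretizationLike(BaseLayer):
    @property
    def input_dtype(self):
        # Bin boundaries are floats, so inputs must be floats too.
        return "float32"

    @property
    def compute_dtype(self):
        # Force float compute even when the layer dtype is integer;
        # the integer output dtype is chosen in compute_output_spec.
        return "float32"


layer = DiscretizationLike(dtype="int64")
print(layer.compute_dtype)  # float32
```

With such an override, graph-mode tracing would see a float compute dtype and would not cast the inputs to ints before bucketing.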


super().__init__(name=name, dtype=dtype)

@@ -213,7 +213,8 @@ def reset_state(self):
         self.summary = np.array([[], []], dtype="float32")

     def compute_output_spec(self, inputs):
-        return backend.KerasTensor(shape=inputs.shape, dtype=self.compute_dtype)
+        output_dtype = "int64" if self.output_mode == "int" else self.compute_dtype
+        return backend.KerasTensor(shape=inputs.shape, dtype=output_dtype)

def load_own_variables(self, store):
if len(store) == 1:
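The dtype selection added to `compute_output_spec` boils down to a one-line rule; a pure-Python sketch (the helper name `spec_dtype` is hypothetical, introduced only for illustration):

```python
# Sketch of the dtype-selection rule from the patched
# compute_output_spec: integer buckets get "int64", every other
# output_mode keeps the layer's compute dtype.
def spec_dtype(output_mode, compute_dtype="float32"):
    return "int64" if output_mode == "int" else compute_dtype

print(spec_dtype("int"))      # int64
print(spec_dtype("one_hot"))  # float32
```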
@@ -229,7 +230,9 @@ def call(self, inputs):
                 "start using the `Discretization` layer."
             )

+        # Use the backend's digitize function for all backends
+        indices = self.backend.numpy.digitize(inputs, self.bin_boundaries)
+
         return numerical_utils.encode_categorical_inputs(
             indices,
             output_mode=self.output_mode,
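The patch routes bucketing through the backend's `digitize` op. NumPy's `np.digitize` illustrates the index semantics, here with the same boundaries and first input row used in the new test (with increasing boundaries `b`, a value `x` maps to index `i` such that `b[i-1] <= x < b[i]`):

```python
import numpy as np

# Same boundaries and inputs as the PR's eager-vs-graph test.
boundaries = [-0.5, 0, 0.1, 0.2, 3]
x = np.array([[0.0, 0.15, 0.21, 0.3],
              [0.0, 0.17, 0.451, 7.8]])

# np.digitize returns, for each value, the index of the bucket it
# falls into: values below -0.5 map to 0, values >= 3 map to 5.
indices = np.digitize(x, boundaries)
print(indices)
# [[2 3 4 4]
#  [2 3 4 5]]
```

Both eager and graph execution should produce these bucket indices once the layer stops casting its inputs to ints.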
29 changes: 29 additions & 0 deletions keras/src/layers/preprocessing/discretization_test.py
@@ -205,3 +205,32 @@ def test_call_before_adapt_raises(self):
layer = layers.Discretization(num_bins=3)
with self.assertRaisesRegex(ValueError, "You need .* call .*adapt"):
layer([[0.1, 0.8, 0.9]])


def test_discretization_eager_vs_graph():
    import os

    os.environ["KERAS_BACKEND"] = "tensorflow"
    import tensorflow as tf

    import keras

    layer = keras.layers.Discretization(
        bin_boundaries=[-0.5, 0, 0.1, 0.2, 3],
        name="bucket",
        output_mode="int",
    )

    x = tf.constant([[0.0, 0.15, 0.21, 0.3], [0.0, 0.17, 0.451, 7.8]])
    inputs = keras.layers.Input(name="inp", dtype="float32", shape=(4,))
    model_output = layer(inputs)
    model = keras.models.Model(inputs=[inputs], outputs=[model_output])

    print("Eager mode (layer(x)):")
    print(layer(x).numpy())

    print("Model call (model(x)):")
    print(model(x).numpy())

    print("Model predict (model.predict(x)):")
    print(model.predict(x))