diff --git a/keras/src/layers/preprocessing/discretization.py b/keras/src/layers/preprocessing/discretization.py index 8c771a64a05..3a4d4b8dfb3 100644 --- a/keras/src/layers/preprocessing/discretization.py +++ b/keras/src/layers/preprocessing/discretization.py @@ -95,9 +95,6 @@ def __init__( dtype=None, name=None, ): - if dtype is None: - dtype = "int64" if output_mode == "int" else backend.floatx() - super().__init__(name=name, dtype=dtype) if sparse and not backend.SUPPORTS_SPARSE_TENSORS: @@ -155,6 +152,13 @@ def __init__( def input_dtype(self): return backend.floatx() + @property + def compute_dtype(self): + if self.output_mode == "int": + return "int64" + else: + return backend.floatx() + def adapt(self, data, steps=None): """Computes bin boundaries from quantiles in a input dataset. @@ -213,7 +217,10 @@ def reset_state(self): self.summary = np.array([[], []], dtype="float32") def compute_output_spec(self, inputs): - return backend.KerasTensor(shape=inputs.shape, dtype=self.compute_dtype) + output_dtype = ( + "int64" if self.output_mode == "int" else self.compute_dtype + ) + return backend.KerasTensor(shape=inputs.shape, dtype=output_dtype) def load_own_variables(self, store): if len(store) == 1: @@ -230,11 +237,14 @@ def call(self, inputs): ) indices = self.backend.numpy.digitize(inputs, self.bin_boundaries) + output_dtype = ( + "int64" if self.output_mode == "int" else self.compute_dtype + ) return numerical_utils.encode_categorical_inputs( indices, output_mode=self.output_mode, depth=len(self.bin_boundaries) + 1, - dtype=self.compute_dtype, + dtype=output_dtype, sparse=self.sparse, backend_module=self.backend, ) diff --git a/keras/src/layers/preprocessing/discretization_test.py b/keras/src/layers/preprocessing/discretization_test.py index 500c6e9ca03..c96b003833a 100644 --- a/keras/src/layers/preprocessing/discretization_test.py +++ b/keras/src/layers/preprocessing/discretization_test.py @@ -205,3 +205,24 @@ def test_call_before_adapt_raises(self): layer = layers.Discretization(num_bins=3) with self.assertRaisesRegex(ValueError, "You need .* call .*adapt"): layer([[0.1, 0.8, 0.9]]) + + def test_model_call_vs_predict_consistency(self): + """Test that model(input) and model.predict(input) produce consistent outputs.""" # noqa: E501 + # Test with int output mode + layer = layers.Discretization( + bin_boundaries=[-0.5, 0, 0.1, 0.2, 3], + output_mode="int", + ) + x = np.array([[0.0, 0.15, 0.21, 0.3], [0.0, 0.17, 0.451, 7.8]]) + + # Create model + inputs = layers.Input(shape=(4,), dtype="float32") + outputs = layer(inputs) + model = models.Model(inputs=inputs, outputs=outputs) + + # Test both execution modes + model_call_output = model(x) + predict_output = model.predict(x) + + # Check consistency + self.assertAllClose(model_call_output, predict_output)