
Commit 2f7476c

Add crossentropy metrics.
1 parent 6c2fc07 commit 2f7476c

File tree

4 files changed: +400 additions, -9 deletions


keras_core/losses/losses.py

Lines changed: 1 addition & 1 deletion

@@ -1347,7 +1347,7 @@ def categorical_crossentropy(
     >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
     >>> loss = keras_core.losses.categorical_crossentropy(y_true, y_pred)
     >>> assert loss.shape == (2,)
-    >>> loss.numpy()
+    >>> loss
     array([0.0513, 2.303], dtype=float32)
     """
     if isinstance(axis, bool):
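The only functional change in this file is the doctest: `loss.numpy()` becomes `loss`, presumably because the returned value already prints as an array in keras-core. As a sanity check, here is a minimal NumPy sketch (not keras-core code) that reproduces the expected output, assuming the predictions are already normalized probabilities and the default Keras epsilon of 1e-7:

```python
# Standalone check of the doctest values above, using only NumPy.
# Assumption: predictions already sum to 1 per sample, epsilon = 1e-7.
import numpy as np

EPSILON = 1e-7

y_true = np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
y_pred = np.array([[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]])

# Clip predictions away from 0 and 1 so the log stays finite,
# then reduce over the class axis to get one loss value per sample.
y_pred = np.clip(y_pred, EPSILON, 1.0 - EPSILON)
loss = -np.sum(y_true * np.log(y_pred), axis=-1)
print(loss)  # approximately [0.0513, 2.3026], matching the docstring
```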

keras_core/metrics/__init__.py

Lines changed: 8 additions & 1 deletion

@@ -8,15 +8,19 @@
 from keras_core.metrics.hinge_metrics import Hinge
 from keras_core.metrics.hinge_metrics import SquaredHinge
 from keras_core.metrics.metric import Metric
+from keras_core.metrics.probabilistic_metrics import BinaryCrossentropy
+from keras_core.metrics.probabilistic_metrics import CategoricalCrossentropy
 from keras_core.metrics.probabilistic_metrics import KLDivergence
 from keras_core.metrics.probabilistic_metrics import Poisson
+from keras_core.metrics.probabilistic_metrics import (
+    SparseCategoricalCrossentropy,
+)
 from keras_core.metrics.reduction_metrics import Mean
 from keras_core.metrics.reduction_metrics import MeanMetricWrapper
 from keras_core.metrics.reduction_metrics import Sum
 from keras_core.metrics.regression_metrics import MeanSquaredError
 from keras_core.metrics.regression_metrics import mean_squared_error
 from keras_core.saving import serialization_lib
-from keras_core.utils import naming
 
 ALL_OBJECTS = {
     Metric,
@@ -30,6 +34,9 @@
     CategoricalHinge,
     KLDivergence,
     Poisson,
+    BinaryCrossentropy,
+    CategoricalCrossentropy,
+    SparseCategoricalCrossentropy,
 }
 ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
 ALL_OBJECTS_DICT.update(
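Adding the three classes to `ALL_OBJECTS` feeds the `ALL_OBJECTS_DICT` name-to-class mapping built immediately below it, which is what lets a metric be re-created from its class name. Here is a minimal, self-contained sketch of that lookup pattern with placeholder classes (the rest of the keras_core resolution code is truncated in this diff):

```python
# Illustrative registry pattern only; placeholder classes stand in for the
# real keras_core metrics, whose resolution code is not shown in this diff.
class BinaryCrossentropy:
    pass

class CategoricalCrossentropy:
    pass

class SparseCategoricalCrossentropy:
    pass

ALL_OBJECTS = {
    BinaryCrossentropy,
    CategoricalCrossentropy,
    SparseCategoricalCrossentropy,
}
# Same construction as in the diff: map each class name to the class itself.
ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}

# A config that names the metric class can now be resolved back to an instance.
metric = ALL_OBJECTS_DICT["SparseCategoricalCrossentropy"]()
print(type(metric).__name__)  # SparseCategoricalCrossentropy
```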

keras_core/metrics/probabilistic_metrics.py

Lines changed: 246 additions & 0 deletions

@@ -1,6 +1,9 @@
 from keras_core.api_export import keras_core_export
+from keras_core.losses.losses import binary_crossentropy
+from keras_core.losses.losses import categorical_crossentropy
 from keras_core.losses.losses import kl_divergence
 from keras_core.losses.losses import poisson
+from keras_core.losses.losses import sparse_categorical_crossentropy
 from keras_core.metrics import reduction_metrics
 
 
@@ -62,6 +65,8 @@ class Poisson(reduction_metrics.MeanMetricWrapper):
         name: (Optional) string name of the metric instance.
         dtype: (Optional) data type of the metric result.
 
+    Examples:
+
     Standalone usage:
 
     >>> m = keras_core.metrics.Poisson()
@@ -89,3 +94,244 @@ def __init__(self, name="poisson", dtype=None):
 
     def get_config(self):
         return {"name": self.name, "dtype": self.dtype}
+
+
+@keras_core_export("keras_core.metrics.BinaryCrossentropy")
+class BinaryCrossentropy(reduction_metrics.MeanMetricWrapper):
+    """Computes the crossentropy metric between the labels and predictions.
+
+    This is the crossentropy metric class to be used when there are only two
+    label classes (0 and 1).
+
+    Args:
+        name: (Optional) string name of the metric instance.
+        dtype: (Optional) data type of the metric result.
+        from_logits: (Optional) Whether output is expected
+            to be a logits tensor. By default, we consider
+            that output encodes a probability distribution.
+        label_smoothing: (Optional) Float in `[0, 1]`.
+            When > 0, label values are smoothed,
+            meaning the confidence on label values are relaxed.
+            e.g. `label_smoothing=0.2` means that we will use
+            a value of 0.1 for label "0" and 0.9 for label "1".
+
+    Examples:
+
+    Standalone usage:
+
+    >>> m = keras_core.metrics.BinaryCrossentropy()
+    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
+    >>> m.result()
+    0.81492424
+
+    >>> m.reset_state()
+    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
+    ...                sample_weight=[1, 0])
+    >>> m.result()
+    0.9162905
+
+    Usage with `compile()` API:
+
+    ```python
+    model.compile(
+        optimizer='sgd',
+        loss='mse',
+        metrics=[keras_core.metrics.BinaryCrossentropy()])
+    ```
+    """
+
+    def __init__(
+        self,
+        name="binary_crossentropy",
+        dtype=None,
+        from_logits=False,
+        label_smoothing=0,
+    ):
+        super().__init__(
+            binary_crossentropy,
+            name,
+            dtype=dtype,
+            from_logits=from_logits,
+            label_smoothing=label_smoothing,
+        )
+        self.from_logits = from_logits
+        self.label_smoothing = label_smoothing
+
+    def get_config(self):
+        return {
+            "name": self.name,
+            "dtype": self.dtype,
+            "from_logits": self.from_logits,
+            "label_smoothing": self.label_smoothing,
+        }
+
+
+@keras_core_export("keras_core.metrics.CategoricalCrossentropy")
+class CategoricalCrossentropy(reduction_metrics.MeanMetricWrapper):
+    """Computes the crossentropy metric between the labels and predictions.
+
+    This is the crossentropy metric class to be used when there are multiple
+    label classes (2 or more). It assumes that labels are one-hot encoded,
+    e.g., when labels values are `[2, 0, 1]`, then
+    `y_true` is `[[0, 0, 1], [1, 0, 0], [0, 1, 0]]`.
+
+    Args:
+        name: (Optional) string name of the metric instance.
+        dtype: (Optional) data type of the metric result.
+        from_logits: (Optional) Whether output is expected to be
+            a logits tensor. By default, we consider that output
+            encodes a probability distribution.
+        label_smoothing: (Optional) Float in `[0, 1]`.
+            When > 0, label values are smoothed, meaning the confidence
+            on label values are relaxed. e.g. `label_smoothing=0.2` means
+            that we will use a value of 0.1 for label
+            "0" and 0.9 for label "1".
+        axis: (Optional) Defaults to -1.
+            The dimension along which entropy is computed.
+
+    Examples:
+
+    Standalone usage:
+
+    >>> # EPSILON = 1e-7, y = y_true, y` = y_pred
+    >>> # y` = clip_ops.clip_by_value(output, EPSILON, 1. - EPSILON)
+    >>> # y` = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]]
+    >>> # xent = -sum(y * log(y'), axis = -1)
+    >>> #      = -((log 0.95), (log 0.1))
+    >>> #      = [0.051, 2.302]
+    >>> # Reduced xent = (0.051 + 2.302) / 2
+    >>> m = keras_core.metrics.CategoricalCrossentropy()
+    >>> m.update_state([[0, 1, 0], [0, 0, 1]],
+    ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
+    >>> m.result()
+    1.1769392
+
+    >>> m.reset_state()
+    >>> m.update_state([[0, 1, 0], [0, 0, 1]],
+    ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]],
+    ...                sample_weight=np.array([0.3, 0.7]))
+    >>> m.result()
+    1.6271976
+
+    Usage with `compile()` API:
+
+    ```python
+    model.compile(
+        optimizer='sgd',
+        loss='mse',
+        metrics=[keras_core.metrics.CategoricalCrossentropy()])
+    ```
+    """
+
+    def __init__(
+        self,
+        name="categorical_crossentropy",
+        dtype=None,
+        from_logits=False,
+        label_smoothing=0,
+        axis=-1,
+    ):
+        super().__init__(
+            categorical_crossentropy,
+            name,
+            dtype=dtype,
+            from_logits=from_logits,
+            label_smoothing=label_smoothing,
+            axis=axis,
+        )
+        self.from_logits = from_logits
+        self.label_smoothing = label_smoothing
+        self.axis = axis
+
+    def get_config(self):
+        return {
+            "name": self.name,
+            "dtype": self.dtype,
+            "from_logits": self.from_logits,
+            "label_smoothing": self.label_smoothing,
+            "axis": self.axis,
+        }
+
+
+@keras_core_export("keras_core.metrics.SparseCategoricalCrossentropy")
+class SparseCategoricalCrossentropy(reduction_metrics.MeanMetricWrapper):
+    """Computes the crossentropy metric between the labels and predictions.
+
+    Use this crossentropy metric when there are two or more label classes.
+    It expects labels to be provided as integers. If you want to provide
+    labels that are one-hot encoded, please use the `CategoricalCrossentropy`
+    metric instead.
+
+    There should be `num_classes` floating point values per feature for
+    `y_pred` and a single floating point value per feature for `y_true`.
+
+    Args:
+        name: (Optional) string name of the metric instance.
+        dtype: (Optional) data type of the metric result.
+        from_logits: (Optional) Whether output is expected
+            to be a logits tensor. By default, we consider that output
+            encodes a probability distribution.
+        axis: (Optional) Defaults to -1.
+            The dimension along which entropy is computed.
+
+    Examples:
+
+    Standalone usage:
+
+    >>> # y_true = one_hot(y_true) = [[0, 1, 0], [0, 0, 1]]
+    >>> # logits = log(y_pred)
+    >>> # softmax = exp(logits) / sum(exp(logits), axis=-1)
+    >>> # softmax = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]]
+    >>> # xent = -sum(y * log(softmax), 1)
+    >>> # log(softmax) = [[-2.9957, -0.0513, -16.1181],
+    >>> #                 [-2.3026, -0.2231, -2.3026]]
+    >>> # y_true * log(softmax) = [[0, -0.0513, 0], [0, 0, -2.3026]]
+    >>> # xent = [0.0513, 2.3026]
+    >>> # Reduced xent = (0.0513 + 2.3026) / 2
+    >>> m = keras_core.metrics.SparseCategoricalCrossentropy()
+    >>> m.update_state([1, 2],
+    ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
+    >>> m.result()
+    1.1769392
+
+    >>> m.reset_state()
+    >>> m.update_state([1, 2],
+    ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]],
+    ...                sample_weight=np.array([0.3, 0.7]))
+    >>> m.result()
+    1.6271976
+
+    Usage with `compile()` API:
+
+    ```python
+    model.compile(
+        optimizer='sgd',
+        loss='mse',
+        metrics=[keras_core.metrics.SparseCategoricalCrossentropy()])
+    ```
+    """
+
+    def __init__(
+        self,
+        name="sparse_categorical_crossentropy",
+        dtype=None,
+        from_logits=False,
+        axis=-1,
+    ):
+        super().__init__(
+            sparse_categorical_crossentropy,
+            name=name,
+            dtype=dtype,
+            from_logits=from_logits,
+            axis=axis,
+        )
+        self.from_logits = from_logits
+        self.axis = axis
+
+    def get_config(self):
+        return {
+            "name": self.name,
+            "dtype": self.dtype,
+            "from_logits": self.from_logits,
+            "axis": self.axis,
+        }
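All three new metrics are thin wrappers over the corresponding loss functions via `reduction_metrics.MeanMetricWrapper`: each `update_state()` call computes the per-sample loss and folds it into a running mean, optionally weighted by `sample_weight`, and `result()` returns that mean. Here is a minimal NumPy sketch (not keras-core code) that reproduces the `BinaryCrossentropy` docstring values, assuming the default epsilon of 1e-7 and a per-sample mean over the last axis:

```python
# Reproduce the BinaryCrossentropy docstring values with plain NumPy.
# Assumptions: epsilon = 1e-7, the per-sample loss is the mean over the last
# axis, and sample weights produce a weighted mean of per-sample losses.
import numpy as np

EPSILON = 1e-7

def binary_crossentropy_np(y_true, y_pred):
    y_pred = np.clip(y_pred, EPSILON, 1.0 - EPSILON)
    ce = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return ce.mean(axis=-1)  # one loss value per sample

y_true = np.array([[0.0, 1.0], [0.0, 0.0]])
y_pred = np.array([[0.6, 0.4], [0.4, 0.6]])

per_sample = binary_crossentropy_np(y_true, y_pred)
print(per_sample.mean())  # ~0.8149, matching m.result() in the docstring

# With sample_weight=[1, 0], only the first sample contributes to the mean:
weights = np.array([1.0, 0.0])
print((per_sample * weights).sum() / weights.sum())  # ~0.9163
```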
