Skip to content

Commit 2b6c800

Browse files
authored
test: add test for class weights (py_dataset adapter) (#20638)
* test: add test for class weights (py_dataset adapter) * "call _standardize_batch from enqueuer" m * add more tests, handle pytorch astype issue m * convert to numpy to ensure consistent handling of operations
1 parent d8afc05 commit 2b6c800

File tree

5 files changed

+183
-13
lines changed

5 files changed

+183
-13
lines changed

keras/src/trainers/data_adapters/data_adapter_utils.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
from keras.src import backend
4+
from keras.src import ops
45
from keras.src import tree
56
from keras.src.api_export import keras_export
67

@@ -115,15 +116,20 @@ def check_data_cardinality(data):
115116

116117

117118
def class_weight_to_sample_weights(y, class_weight):
118-
sample_weight = np.ones(shape=(y.shape[0],), dtype=backend.floatx())
119-
if len(y.shape) > 1:
120-
if y.shape[-1] != 1:
121-
y = np.argmax(y, axis=-1)
119+
# Convert to numpy to ensure consistent handling of operations
120+
# (e.g., np.round()) across frameworks like TensorFlow, JAX, and PyTorch
121+
122+
y_numpy = ops.convert_to_numpy(y)
123+
sample_weight = np.ones(shape=(y_numpy.shape[0],), dtype=backend.floatx())
124+
if len(y_numpy.shape) > 1:
125+
if y_numpy.shape[-1] != 1:
126+
y_numpy = np.argmax(y_numpy, axis=-1)
122127
else:
123-
y = np.squeeze(y, axis=-1)
124-
y = np.round(y).astype("int32")
125-
for i in range(y.shape[0]):
126-
sample_weight[i] = class_weight.get(int(y[i]), 1.0)
128+
y_numpy = np.squeeze(y_numpy, axis=-1)
129+
y_numpy = np.round(y_numpy).astype("int32")
130+
131+
for i in range(y_numpy.shape[0]):
132+
sample_weight[i] = class_weight.get(int(y_numpy[i]), 1.0)
127133
return sample_weight
128134

129135

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import numpy as np
2+
import pytest
3+
from absl.testing import parameterized
4+
5+
from keras.src import backend
6+
from keras.src import testing
7+
from keras.src.trainers.data_adapters.data_adapter_utils import (
8+
class_weight_to_sample_weights,
9+
)
10+
11+
12+
class TestClassWeightToSampleWeights(testing.TestCase):
13+
@parameterized.named_parameters(
14+
[
15+
# Simple case, where y is flat
16+
(
17+
"simple_class_labels",
18+
np.array([0, 1, 0, 2]),
19+
{0: 1.0, 1: 2.0, 2: 3.0},
20+
np.array([1.0, 2.0, 1.0, 3.0]),
21+
),
22+
# Testing with one-hot encoded labels,
23+
# so basically the argmax statement
24+
(
25+
"one_hot_encoded_labels",
26+
np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]]),
27+
{0: 1.0, 1: 2.0, 2: 3.0},
28+
np.array([1.0, 2.0, 1.0, 3.0]),
29+
),
30+
# 3 is not mapped, so it's assigned the default weight (1)
31+
(
32+
"unmapped_class",
33+
np.array([0, 3, 0, 2]),
34+
{0: 1.0, 1: 2.0, 2: 3.0},
35+
np.array([1.0, 1.0, 1.0, 3.0]),
36+
),
37+
(
38+
"multi_dimensional_input",
39+
np.array([[0], [1], [0], [2]]),
40+
{0: 1.0, 1: 2.0, 2: 3.0},
41+
np.array([1.0, 2.0, 1.0, 3.0]),
42+
),
43+
(
44+
"all_unmapped",
45+
np.array([0, 1, 0, 2]),
46+
{},
47+
np.array([1.0, 1.0, 1.0, 1.0]),
48+
),
49+
]
50+
)
51+
def test_class_weight_to_sample_weights(self, y, class_weight, expected):
52+
self.assertAllClose(
53+
class_weight_to_sample_weights(y, class_weight), expected
54+
)
55+
56+
@pytest.mark.skipif(backend.backend() != "torch", reason="torch only")
57+
def test_class_weight_to_sample_weights_torch_specific(self):
58+
import torch
59+
60+
y = torch.from_numpy(np.array([0, 1, 0, 2]))
61+
self.assertAllClose(
62+
class_weight_to_sample_weights(y, {0: 1.0, 1: 2.0, 2: 3.0}),
63+
np.array([1.0, 2.0, 1.0, 3.0]),
64+
)
65+
y_one_hot = torch.from_numpy(
66+
np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]])
67+
)
68+
self.assertAllClose(
69+
class_weight_to_sample_weights(y_one_hot, {0: 1.0, 1: 2.0, 2: 3.0}),
70+
np.array([1.0, 2.0, 1.0, 3.0]),
71+
)
72+
73+
@pytest.mark.skipif(backend.backend() != "jax", reason="jax only")
74+
def test_class_weight_to_sample_weights_jax_specific(self):
75+
import jax
76+
77+
y = jax.numpy.asarray(np.array([0, 1, 0, 2]))
78+
self.assertAllClose(
79+
class_weight_to_sample_weights(y, {0: 1.0, 1: 2.0, 2: 3.0}),
80+
np.array([1.0, 2.0, 1.0, 3.0]),
81+
)
82+
y_one_hot = jax.numpy.asarray(
83+
np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]])
84+
)
85+
self.assertAllClose(
86+
class_weight_to_sample_weights(y_one_hot, {0: 1.0, 1: 2.0, 2: 3.0}),
87+
np.array([1.0, 2.0, 1.0, 3.0]),
88+
)
89+
90+
@pytest.mark.skipif(
91+
backend.backend() != "tensorflow", reason="tensorflow only"
92+
)
93+
def test_class_weight_to_sample_weights_tf_specific(self):
94+
import tensorflow as tf
95+
96+
y = tf.convert_to_tensor(np.array([0, 1, 0, 2]))
97+
self.assertAllClose(
98+
class_weight_to_sample_weights(y, {0: 1.0, 1: 2.0, 2: 3.0}),
99+
np.array([1.0, 2.0, 1.0, 3.0]),
100+
)
101+
y_one_hot = tf.convert_to_tensor(
102+
np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]])
103+
)
104+
self.assertAllClose(
105+
class_weight_to_sample_weights(y_one_hot, {0: 1.0, 1: 2.0, 2: 3.0}),
106+
np.array([1.0, 2.0, 1.0, 3.0]),
107+
)

keras/src/trainers/data_adapters/py_dataset_adapter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def _standardize_batch(self, batch):
236236

237237
def _infinite_generator(self):
238238
for i in itertools.count():
239-
yield self.py_dataset[i]
239+
yield self._standardize_batch(self.py_dataset[i])
240240

241241
def _finite_generator(self):
242242
indices = range(self.py_dataset.num_batches)
@@ -245,18 +245,18 @@ def _finite_generator(self):
245245
random.shuffle(indices)
246246

247247
for i in indices:
248-
yield self.py_dataset[i]
248+
yield self._standardize_batch(self.py_dataset[i])
249249

250250
def _infinite_enqueuer_generator(self):
251251
self.enqueuer.start()
252252
for batch in self.enqueuer.get():
253-
yield batch
253+
yield self._standardize_batch(batch)
254254

255255
def _finite_enqueuer_generator(self):
256256
self.enqueuer.start()
257257
num_batches = self.py_dataset.num_batches
258258
for i, batch in enumerate(self.enqueuer.get()):
259-
yield batch
259+
yield self._standardize_batch(batch)
260260
if i >= num_batches - 1:
261261
self.enqueuer.stop()
262262
return

keras/src/trainers/data_adapters/py_dataset_adapter_test.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,41 @@ def test_basic_flow(
217217
else:
218218
self.assertAllClose(sample_order, expected_order)
219219

220-
# TODO: test class_weight
221220
# TODO: test sample weights
222221
# TODO: test inference mode (single output)
223222

223+
def test_class_weight(self):
224+
x = np.random.randint(1, 100, (4, 5))
225+
y = np.array([0, 1, 2, 1])
226+
class_w = {0: 2, 1: 1, 2: 3}
227+
py_dataset = ExamplePyDataset(x, y, batch_size=2)
228+
adapter = py_dataset_adapter.PyDatasetAdapter(
229+
py_dataset, shuffle=False, class_weight=class_w
230+
)
231+
if backend.backend() == "numpy":
232+
gen = adapter.get_numpy_iterator()
233+
elif backend.backend() == "tensorflow":
234+
gen = adapter.get_tf_dataset()
235+
elif backend.backend() == "jax":
236+
gen = adapter.get_jax_iterator()
237+
elif backend.backend() == "torch":
238+
gen = adapter.get_torch_dataloader()
239+
240+
for index, batch in enumerate(gen):
241+
# Batch is a tuple of (x, y, class_weight)
242+
self.assertLen(batch, 3)
243+
# Let's verify the data and class weights match for each element
244+
# of the batch (2 elements in each batch)
245+
for sub_elem in range(2):
246+
self.assertTrue(
247+
np.array_equal(batch[0][sub_elem], x[index * 2 + sub_elem])
248+
)
249+
self.assertEqual(batch[1][sub_elem], y[index * 2 + sub_elem])
250+
class_key = np.int32(batch[1][sub_elem])
251+
self.assertEqual(batch[2][sub_elem], class_w[class_key])
252+
253+
self.assertEqual(index, 1) # 2 batches
254+
224255
def test_speedup(self):
225256
x = np.random.random((40, 4))
226257
y = np.random.random((40, 2))

keras/src/trainers/trainer_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,17 +522,37 @@ def test_fit_flow(self, run_eagerly, jit_compile, use_steps_per_epoch):
522522
"testcase_name": "py_dataset",
523523
"dataset_type": "py_dataset",
524524
},
525+
{
526+
"testcase_name": "py_dataset_cw",
527+
"dataset_type": "py_dataset",
528+
"fit_kwargs": {"class_weight": {0: 1, 1: 2}},
529+
},
525530
{
526531
"testcase_name": "py_dataset_infinite",
527532
"dataset_type": "py_dataset",
528533
"dataset_kwargs": {"infinite": True},
529534
"fit_kwargs": {"steps_per_epoch": 20},
530535
},
536+
{
537+
"testcase_name": "py_dataset_infinite_cw",
538+
"dataset_type": "py_dataset",
539+
"dataset_kwargs": {"infinite": True},
540+
"fit_kwargs": {
541+
"steps_per_epoch": 20,
542+
"class_weight": {0: 1, 1: 2},
543+
},
544+
},
531545
{
532546
"testcase_name": "py_dataset_multithreading",
533547
"dataset_type": "py_dataset",
534548
"dataset_kwargs": {"workers": 2},
535549
},
550+
{
551+
"testcase_name": "py_dataset_multithreading_cw",
552+
"dataset_type": "py_dataset",
553+
"dataset_kwargs": {"workers": 2},
554+
"fit_kwargs": {"class_weight": {0: 1, 1: 2}},
555+
},
536556
{
537557
"testcase_name": "py_dataset_multithreading_infinite",
538558
"dataset_type": "py_dataset",
@@ -544,6 +564,12 @@ def test_fit_flow(self, run_eagerly, jit_compile, use_steps_per_epoch):
544564
"dataset_type": "py_dataset",
545565
"dataset_kwargs": {"workers": 2, "use_multiprocessing": True},
546566
},
567+
{
568+
"testcase_name": "py_dataset_multiprocessing_cw",
569+
"dataset_type": "py_dataset",
570+
"dataset_kwargs": {"workers": 2, "use_multiprocessing": True},
571+
"fit_kwargs": {"class_weight": {0: 1, 1: 2}},
572+
},
547573
{
548574
"testcase_name": "py_dataset_multiprocessing_infinite",
549575
"dataset_type": "py_dataset",

0 commit comments

Comments
 (0)