Skip to content

Commit 9a3e173

Browse files
authored
Fix PyDatasetAdapterTest::test_class_weight test with Torch on GPU. (#20665)
The test was failing because arrays on device and on cpu were compared.
1 parent e3cf043 commit 9a3e173

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

keras/src/trainers/data_adapters/py_dataset_adapter_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,12 +240,11 @@ def test_class_weight(self):
240240
for index, batch in enumerate(gen):
241241
# Batch is a tuple of (x, y, class_weight)
242242
self.assertLen(batch, 3)
243+
batch = [backend.convert_to_numpy(x) for x in batch]
243244
# Let's verify the data and class weights match for each element
244245
# of the batch (2 elements in each batch)
245246
for sub_elem in range(2):
246-
self.assertTrue(
247-
np.array_equal(batch[0][sub_elem], x[index * 2 + sub_elem])
248-
)
247+
self.assertAllEqual(batch[0][sub_elem], x[index * 2 + sub_elem])
249248
self.assertEqual(batch[1][sub_elem], y[index * 2 + sub_elem])
250249
class_key = np.int32(batch[1][sub_elem])
251250
self.assertEqual(batch[2][sub_elem], class_w[class_key])

0 commit comments

Comments
 (0)