Skip to content

Commit b12f0f4

Browse files
rryancopybara-github
authored andcommitted
Fall back to np.stack if *any* array is not an ndarray, not just the first.
PiperOrigin-RevId: 872706726
1 parent 85eabc2 commit b12f0f4

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

grain/_src/python/dataset/transformations/batch.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,8 @@ def __call__(self, values: Sequence[T]) -> T:
8686
def _batch_fn(*xs: Sequence[T]) -> T:
8787
# If the thread pool is not available or the elements are not NumPy
8888
# arrays, fall back to the standard serial `np.stack` operation.
89-
if (self._parallel_batch_executor is None) or not isinstance(
90-
xs[0], np.ndarray
91-
):
89+
all_ndarray = all(isinstance(x, np.ndarray) for x in xs)
90+
if (self._parallel_batch_executor is None) or not all_ndarray:
9291
return np.stack(xs)
9392
xs = cast(Sequence[np.ndarray], xs)
9493
# Fall back to the standard serial `np.stack` operation if the size of

grain/_src/python/dataset/transformations/batch_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,13 @@ def test_batch_single_value_parallel_batch_enabled_success(self):
6969
batched_values = make_batch_parallel(values)
7070
self.assertEqual(batched_values.shape, (1, 3))
7171

72+
def test_batch_non_numpy_values(self):
73+
values = [np.asarray([1, 2, 3]), [4, 5, 6]]
74+
make_batch_parallel = batch._MakeBatchParallel()
75+
batched_values = make_batch_parallel(values)
76+
self.assertIsInstance(batched_values, np.ndarray)
77+
self.assertEqual(batched_values.shape, (2, 3))
78+
7279
def test_batch_two_values_success(self):
7380
values = [np.asarray([1, 2, 3]), np.asarray([4, 5, 6])]
7481
batched_values = batch.make_batch(values)

0 commit comments

Comments
 (0)