Skip to content

Commit 6114a21

Browse files
Force computation and data on CPU when in grain data pipeline. (#21553)
1 parent 0c6c363 commit 6114a21

File tree

6 files changed

+82
-3
lines changed

6 files changed

+82
-3
lines changed

keras/src/backend/jax/distribution_lib.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,11 @@ def process_id():
201201
def _to_backend_device(device_name):
202202
if isinstance(device_name, jax.Device):
203203
return device_name
204-
device_type, device_id = device_name.split(":")
204+
device_name = str(device_name)
205+
if ":" not in device_name:
206+
device_type, device_id = device_name, 0
207+
else:
208+
device_type, device_id = device_name.split(":")
205209

206210
devices = jax.devices(backend=device_type)
207211
for device in devices:

keras/src/backend/torch/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def device_scope(device_name):
6262
current_device = _parse_device_input(device_name)
6363
global_state.set_global_attribute("torch_device", current_device)
6464
try:
65-
yield
65+
yield torch.device(current_device)
6666
finally:
6767
global_state.set_global_attribute("torch_device", previous_device)
6868

keras/src/layers/preprocessing/image_preprocessing/resizing_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import grain
12
import numpy as np
23
import pytest
34
from absl.testing import parameterized
@@ -158,6 +159,34 @@ def test_tf_data_compatibility(self):
158159
output = next(iter(ds)).numpy()
159160
self.assertEqual(tuple(output.shape), output_shape)
160161

162+
def test_grain_compatibility(self):
    # Run a `Resizing` layer as the map function of a Grain pipeline and
    # check the shape, tensor type and device placement of the first batch.
    channels_last = backend.config.image_data_format() == "channels_last"
    if channels_last:
        input_shape, output_shape = (2, 10, 12, 3), (2, 8, 9, 3)
    else:
        input_shape, output_shape = (2, 3, 10, 12), (2, 3, 8, 9)

    layer = layers.Resizing(8, 9)
    input_data = np.random.random(input_shape)

    # Build the pipeline step by step: source -> iterable -> batch -> layer.
    source = grain.MapDataset.source(input_data)
    batched = source.to_iter_dataset().batch(2)
    ds = batched.map(layer)

    output = next(iter(ds))
    output_np = backend.convert_to_numpy(output)

    self.assertEqual(tuple(output_np.shape), output_shape)
    self.assertTrue(backend.is_tensor(output))

    # Ensure the device of the data is on CPU. `output.device` is only read
    # for the backends that are checked explicitly below.
    backend_name = backend.backend()
    if backend_name in ("tensorflow", "jax"):
        self.assertIn("CPU", str(output.device))
    elif backend_name == "torch":
        self.assertEqual("cpu", str(output.device))
189+
161190
@pytest.mark.skipif(
162191
backend.backend() != "tensorflow",
163192
reason="Sequential + tf.data only works with TF backend",

keras/src/layers/preprocessing/rescaling_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import grain
12
import numpy as np
23
import pytest
34
from tensorflow import data as tf_data
@@ -74,6 +75,21 @@ def test_tf_data_compatibility(self):
7475
ds = tf_data.Dataset.from_tensor_slices(x).batch(3).map(layer)
7576
next(iter(ds)).numpy()
7677

78+
def test_grain_compatibility(self):
    # Run a `Rescaling` layer as the map function of a Grain pipeline and
    # check the tensor type and device placement of the first batch.
    rescaling = layers.Rescaling(scale=1.0 / 255, offset=0.5)
    images = np.random.random((3, 10, 10, 3)) * 255

    source = grain.MapDataset.source(images)
    pipeline = source.to_iter_dataset().batch(3).map(rescaling)
    output = next(iter(pipeline))

    self.assertTrue(backend.is_tensor(output))

    # Ensure the device of the data is on CPU. `output.device` is only read
    # for the backends that are checked explicitly below.
    backend_name = backend.backend()
    if backend_name in ("tensorflow", "jax"):
        self.assertIn("CPU", str(output.device))
    elif backend_name == "torch":
        self.assertEqual("cpu", str(output.device))
92+
7793
def test_rescaling_with_channels_first_and_vector_scale(self):
7894
config = backend.image_data_format()
7995
backend.set_image_data_format("channels_first")

keras/src/layers/preprocessing/tf_data_layer.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,16 @@ def __call__(self, inputs, **kwargs):
4646
if switch_convert_input_args:
4747
self._convert_input_args = True
4848
return outputs
49-
return super().__call__(inputs, **kwargs)
49+
elif (
50+
not isinstance(sample_input, keras.KerasTensor)
51+
and backend_utils.in_grain_data_pipeline()
52+
):
53+
# We're in a Grain data pipeline. Force computation and data
54+
# placement to CPU.
55+
with keras.src.backend.device_scope("cpu"):
56+
return super().__call__(inputs, **kwargs)
57+
else:
58+
return super().__call__(inputs, **kwargs)
5059

5160
@tracking.no_automatic_dependency_tracking
5261
def _get_seed_generator(self, backend=None):

keras/src/utils/backend_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import copy
22
import importlib
3+
import inspect
34
import os
45
import sys
56

@@ -40,6 +41,26 @@ def __exit__(self, *args, **kwargs):
4041
)
4142

4243

44+
def in_grain_data_pipeline():
    """Report whether the current call stack passes through Grain.

    The check walks the Python call stack, starting at this function's own
    frame, and looks at the source filename of every frame. A frame counts
    as Grain code when its filename contains the path fragment
    `grain/_src/python/dataset` or `grain/_src/python/data_loader` (built
    with `os.path.join`, so the separator follows the host OS).

    Returns:
        `True` if any frame on the stack comes from one of those Grain
        source locations, `False` otherwise. Always `False` when the
        `grain` module has not been imported.
    """
    if "grain" not in sys.modules:
        # Fast path: grain was never imported, so no frame can belong to it.
        return False

    # The markers do not depend on the frame being inspected; build them
    # once instead of on every iteration of the stack walk.
    grain_path_markers = (
        os.path.join("grain", "_src", "python", "dataset"),
        os.path.join("grain", "_src", "python", "data_loader"),
    )

    # Following `f_back` links is a lightweight alternative to
    # `inspect.stack`. `inspect.currentframe` may return `None`, in which
    # case the loop is skipped.
    frame = inspect.currentframe()
    while frame is not None:
        filename = frame.f_code.co_filename
        if any(marker in filename for marker in grain_path_markers):
            return True
        frame = frame.f_back
    return False
62+
63+
4364
class DynamicBackend:
4465
"""A class that can be used to switch from one backend to another.
4566

0 commit comments

Comments
 (0)