Skip to content

Commit a1bcf94

Browse files
Refactor TFDataLayer to be more generic for Grain. (#21598)
1 parent 67bcd88 commit a1bcf94

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+394
-163
lines changed

keras/src/backend/torch/core.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -191,21 +191,18 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
191191
raise ValueError("`sparse=True` is not supported with torch backend")
192192
if ragged:
193193
raise ValueError("`ragged=True` is not supported with torch backend")
194-
if isinstance(x, Variable):
195-
if dtype is None:
196-
return x.value
197-
x = x.value
198-
return x.to(to_torch_dtype(dtype))
199-
if is_tensor(x):
194+
if isinstance(x, Variable) or is_tensor(x):
195+
if isinstance(x, Variable):
196+
x = x.value
200197
device = get_device()
201198
if x.device != device:
202199
if x.is_meta:
203200
x = torch.empty_like(x, device=device)
204201
else:
205202
x = x.to(device)
206-
if dtype is None:
207-
return x
208-
return x.to(to_torch_dtype(dtype))
203+
if dtype is not None:
204+
x = x.to(to_torch_dtype(dtype))
205+
return x
209206
if dtype is None:
210207
if isinstance(x, bool):
211208
return torch.as_tensor(x, dtype=torch.bool, device=get_device())

keras/src/layers/preprocessing/category_encoding.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from keras.src.api_export import keras_export
22
from keras.src.backend import KerasTensor
3-
from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
3+
from keras.src.layers.preprocessing.data_layer import DataLayer
44
from keras.src.utils import backend_utils
55
from keras.src.utils import numerical_utils
66

77

88
@keras_export("keras.layers.CategoryEncoding")
9-
class CategoryEncoding(TFDataLayer):
9+
class CategoryEncoding(DataLayer):
1010
"""A preprocessing layer which encodes integer features.
1111
1212
This layer provides options for condensing data into a categorical encoding
@@ -15,7 +15,7 @@ class CategoryEncoding(TFDataLayer):
1515
inputs. For integer inputs where the total number of tokens is not known,
1616
use `keras.layers.IntegerLookup` instead.
1717
18-
**Note:** This layer is safe to use inside a `tf.data` pipeline
18+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
1919
(independently of which backend you're using).
2020
2121
Examples:
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import keras.src.backend
2+
from keras.src import tree
3+
from keras.src.layers.layer import Layer
4+
from keras.src.random.seed_generator import SeedGenerator
5+
from keras.src.utils import backend_utils
6+
from keras.src.utils import jax_utils
7+
from keras.src.utils import tracking
8+
9+
10+
class DataLayer(Layer):
11+
"""Layer designed for safe use in `tf.data` or `grain` pipeline.
12+
13+
This layer overrides the `__call__` method to ensure that the correct
14+
backend is used and that computation is performed on the CPU.
15+
16+
The `call()` method in subclasses should use `self.backend` ops. If
17+
randomness is needed, define both `seed` and `generator` in `__init__` and
18+
retrieve the running seed using `self._get_seed_generator()`. If the layer
19+
has weights in `__init__` or `build()`, use `convert_weight()` to ensure
20+
they are in the correct backend.
21+
22+
**Note:** This layer and its subclasses only support a single input tensor.
23+
24+
Examples:
25+
26+
**Custom `DataLayer` subclass:**
27+
28+
```python
29+
from keras.src.layers.preprocessing.data_layer import DataLayer
30+
from keras.src.random import SeedGenerator
31+
32+
33+
class BiasedRandomRGBToHSVLayer(DataLayer):
34+
def __init__(self, seed=None, **kwargs):
35+
super().__init__(**kwargs)
36+
self.probability_bias = ops.convert_to_tensor(0.01)
37+
self.seed = seed
38+
self.generator = SeedGenerator(seed)
39+
40+
def call(self, inputs):
41+
images_shape = self.backend.shape(inputs)
42+
batch_size = 1 if len(images_shape) == 3 else images_shape[0]
43+
seed = self._get_seed_generator(self.backend._backend)
44+
45+
probability = self.backend.random.uniform(
46+
shape=(batch_size,),
47+
minval=0.0,
48+
maxval=1.0,
49+
seed=seed,
50+
)
51+
probability = self.backend.numpy.add(
52+
probability, self.convert_weight(self.probability_bias)
53+
)
54+
hsv_images = self.backend.image.rgb_to_hsv(inputs)
55+
return self.backend.numpy.where(
56+
probability[:, None, None, None] > 0.5,
57+
hsv_images,
58+
inputs,
59+
)
60+
61+
def compute_output_shape(self, input_shape):
62+
return input_shape
63+
```
64+
65+
**Using as a regular Keras layer:**
66+
67+
```python
68+
import numpy as np
69+
70+
x = np.random.uniform(size=(1, 16, 16, 3)).astype("float32")
71+
print(BiasedRandomRGBToHSVLayer()(x).shape) # (1, 16, 16, 3)
72+
```
73+
74+
**Using in a `tf.data` pipeline:**
75+
76+
```python
77+
import tensorflow as tf
78+
79+
tf_ds = tf.data.Dataset.from_tensors(x)
80+
tf_ds = tf_ds.map(BiasedRandomRGBToHSVLayer())
81+
print([x.shape for x in tf_ds]) # [(1, 16, 16, 3)]
82+
```
83+
84+
**Using in a `grain` pipeline:**
85+
86+
```python
87+
import grain
88+
89+
grain_ds = grain.MapDataset.source([x])
90+
grain_ds = grain_ds.map(BiasedRandomRGBToHSVLayer())
91+
print([x.shape for x in grain_ds]) # [(1, 16, 16, 3)]
92+
"""
93+
94+
def __init__(self, **kwargs):
95+
super().__init__(**kwargs)
96+
self.backend = backend_utils.DynamicBackend()
97+
self._allow_non_tensor_positional_args = True
98+
99+
def __call__(self, inputs, **kwargs):
100+
sample_input = tree.flatten(inputs)[0]
101+
if (
102+
not isinstance(sample_input, keras.KerasTensor)
103+
and backend_utils.in_tf_graph()
104+
and not jax_utils.is_in_jax_tracing_scope(sample_input)
105+
):
106+
# We're in a TF graph, e.g. a tf.data pipeline.
107+
self.backend.set_backend("tensorflow")
108+
inputs = tree.map_structure(
109+
lambda x: self.backend.convert_to_tensor(
110+
x, dtype=self.compute_dtype
111+
),
112+
inputs,
113+
)
114+
switch_convert_input_args = False
115+
if self._convert_input_args:
116+
self._convert_input_args = False
117+
switch_convert_input_args = True
118+
try:
119+
outputs = super().__call__(inputs, **kwargs)
120+
finally:
121+
self.backend.reset()
122+
if switch_convert_input_args:
123+
self._convert_input_args = True
124+
return outputs
125+
elif (
126+
not isinstance(sample_input, keras.KerasTensor)
127+
and backend_utils.in_grain_data_pipeline()
128+
):
129+
# We're in a Grain data pipeline. Force computation and data
130+
# placement to CPU.
131+
with keras.src.backend.device_scope("cpu"):
132+
return super().__call__(inputs, **kwargs)
133+
else:
134+
return super().__call__(inputs, **kwargs)
135+
136+
@tracking.no_automatic_dependency_tracking
137+
def _get_seed_generator(self, backend=None):
138+
if not hasattr(self, "seed") or not hasattr(self, "generator"):
139+
raise ValueError(
140+
"The `seed` and `generator` variable must be set in the "
141+
"`__init__` method before calling `_get_seed_generator()`."
142+
)
143+
if backend is None or backend == keras.backend.backend():
144+
return self.generator
145+
if not hasattr(self, "_backend_generators"):
146+
self._backend_generators = {}
147+
if backend in self._backend_generators:
148+
return self._backend_generators[backend]
149+
seed_generator = SeedGenerator(self.seed, backend=self.backend)
150+
self._backend_generators[backend] = seed_generator
151+
return seed_generator
152+
153+
def convert_weight(self, weight):
154+
"""Convert the weight if it is from the a different backend."""
155+
if self.backend.name == keras.backend.backend():
156+
return weight
157+
else:
158+
weight = keras.ops.convert_to_numpy(weight)
159+
return self.backend.convert_to_tensor(weight)
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import grain
2+
import numpy as np
3+
import pytest
4+
from tensorflow import data as tf_data
5+
6+
from keras.src import backend
7+
from keras.src import testing
8+
from keras.src.layers.preprocessing.data_layer import DataLayer
9+
from keras.src.random import SeedGenerator
10+
11+
12+
class RandomRGBToHSVLayer(DataLayer):
13+
def __init__(self, data_format=None, seed=None, **kwargs):
14+
super().__init__(**kwargs)
15+
self.data_format = backend.standardize_data_format(data_format)
16+
self.seed = seed
17+
self.generator = SeedGenerator(seed)
18+
19+
def call(self, inputs):
20+
images_shape = self.backend.shape(inputs)
21+
batch_size = 1 if len(images_shape) == 3 else images_shape[0]
22+
seed = self._get_seed_generator(self.backend._backend)
23+
24+
probability = self.backend.random.uniform(
25+
shape=(batch_size,),
26+
minval=0.0,
27+
maxval=1.0,
28+
seed=seed,
29+
)
30+
hsv_images = self.backend.image.rgb_to_hsv(
31+
inputs, data_format=self.data_format
32+
)
33+
return self.backend.numpy.where(
34+
probability[:, None, None, None] > 0.5, hsv_images, inputs
35+
)
36+
37+
def compute_output_shape(self, input_shape):
38+
return input_shape
39+
40+
41+
class DataLayerTest(testing.TestCase):
42+
@pytest.mark.requires_trainable_backend
43+
def test_layer(self):
44+
self.run_layer_test(
45+
RandomRGBToHSVLayer,
46+
init_kwargs={
47+
"seed": 1337,
48+
"data_format": "channels_last",
49+
},
50+
input_shape=(1, 2, 2, 3),
51+
supports_masking=False,
52+
expected_output_shape=(1, 2, 2, 3),
53+
)
54+
55+
self.run_layer_test(
56+
RandomRGBToHSVLayer,
57+
init_kwargs={
58+
"seed": 1337,
59+
"data_format": "channels_first",
60+
},
61+
input_shape=(1, 3, 2, 2),
62+
supports_masking=False,
63+
expected_output_shape=(1, 3, 2, 2),
64+
)
65+
66+
def test_tf_data_compatibility(self):
67+
data_format = backend.config.image_data_format()
68+
if data_format == "channels_last":
69+
input_data = np.random.random((2, 8, 8, 3)).astype("float32")
70+
else:
71+
input_data = np.random.random((2, 3, 8, 8)).astype("float32")
72+
layer = RandomRGBToHSVLayer(data_format=data_format, seed=1337)
73+
74+
ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
75+
for output in ds.take(1):
76+
self.assertDType(output, "float32")
77+
self.assertEqual(list(output.shape), list(input_data.shape))
78+
79+
def test_grain_compatibility(self):
80+
data_format = backend.config.image_data_format()
81+
if data_format == "channels_last":
82+
input_data = np.random.random((2, 8, 8, 3)).astype("float32")
83+
else:
84+
input_data = np.random.random((2, 3, 8, 8)).astype("float32")
85+
layer = RandomRGBToHSVLayer(data_format=data_format, seed=1337)
86+
87+
ds = grain.MapDataset.source(input_data).batch(2).map(layer)
88+
for output in ds[:1]:
89+
self.assertDType(output, "float32")
90+
self.assertEqual(list(output.shape), list(input_data.shape))

keras/src/layers/preprocessing/discretization.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,21 @@
22

33
from keras.src import backend
44
from keras.src.api_export import keras_export
5-
from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
5+
from keras.src.layers.preprocessing.data_layer import DataLayer
66
from keras.src.utils import argument_validation
77
from keras.src.utils import numerical_utils
88
from keras.src.utils.module_utils import tensorflow as tf
99

1010

1111
@keras_export("keras.layers.Discretization")
12-
class Discretization(TFDataLayer):
12+
class Discretization(DataLayer):
1313
"""A preprocessing layer which buckets continuous features by ranges.
1414
1515
This layer will place each element of its input data into one of several
1616
contiguous ranges and output an integer index indicating which range each
1717
element was placed in.
1818
19-
**Note:** This layer is safe to use inside a `tf.data` pipeline
19+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
2020
(independently of which backend you're using).
2121
2222
Input shape:

keras/src/layers/preprocessing/feature_space.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from keras.src import tree
44
from keras.src.api_export import keras_export
55
from keras.src.layers.layer import Layer
6-
from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
6+
from keras.src.layers.preprocessing.data_layer import DataLayer
77
from keras.src.saving import saving_lib
88
from keras.src.saving import serialization_lib
99
from keras.src.saving.keras_saveable import KerasSaveable
@@ -723,7 +723,7 @@ def __call__(self, data):
723723
data[name] = tf.expand_dims(x, -1)
724724

725725
with backend_utils.TFGraphScope():
726-
# This scope is to make sure that inner TFDataLayers
726+
# This scope is to make sure that inner DataLayers
727727
# will not convert outputs back to backend-native --
728728
# they should be TF tensors throughout
729729
preprocessed_data = self._preprocess_features(data)
@@ -808,7 +808,7 @@ def load_own_variables(self, store):
808808
return
809809

810810

811-
class TFDConcat(TFDataLayer):
811+
class TFDConcat(DataLayer):
812812
def __init__(self, axis, **kwargs):
813813
super().__init__(**kwargs)
814814
self.axis = axis
@@ -817,6 +817,6 @@ def call(self, xs):
817817
return self.backend.numpy.concatenate(xs, axis=self.axis)
818818

819819

820-
class TFDIdentity(TFDataLayer):
820+
class TFDIdentity(DataLayer):
821821
def call(self, x):
822822
return x

keras/src/layers/preprocessing/image_preprocessing/aug_mix.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,13 @@ class AugMix(BaseImagePreprocessingLayer):
4343
in num_chains different ways, with each chain consisting of
4444
chain_depth augmentations.
4545
46+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
47+
(independently of which backend you're using).
48+
49+
References:
50+
- [AugMix paper](https://arxiv.org/pdf/1912.02781)
51+
- [Official Code](https://github.com/google-research/augmix)
52+
4653
Args:
4754
value_range: the range of values the incoming images will have.
4855
Represented as a two number tuple written (low, high).
@@ -64,10 +71,6 @@ class AugMix(BaseImagePreprocessingLayer):
6471
interpolation: The interpolation method to use for resizing operations.
6572
Options include `"nearest"`, `"bilinear"`. Default is `"bilinear"`.
6673
seed: Integer. Used to create a random seed.
67-
68-
References:
69-
- [AugMix paper](https://arxiv.org/pdf/1912.02781)
70-
- [Official Code](https://github.com/google-research/augmix)
7174
"""
7275

7376
_USE_BASE_FACTOR = False

keras/src/layers/preprocessing/image_preprocessing/auto_contrast.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ class AutoContrast(BaseImagePreprocessingLayer):
1717
1818
This layer is active at both training and inference time.
1919
20+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
21+
(independently of which backend you're using).
22+
2023
Args:
2124
value_range: Range of values the incoming images will have.
2225
Represented as a two number tuple written `(low, high)`.

0 commit comments

Comments
 (0)