Commit 672c731

Add Grain integration. (#21494)

* Add Grain integration.
* Fix CI and resolve comments from gemini-code-assist.
* Add ragged attrs to `KerasTensor` and update `data_adapter_utils.py`.
* Simplify the impl for getting `batch_size` and `output_signature`.
* Add `builtin_prefetch` to `TFDatasetAdapter` and `TorchDataLoaderAdapter`.
* Use `.map` instead of the internal grain functions.
Parent: f2b8b92 · Commit: 672c731

13 files changed (+678 / -19 lines)

keras/src/backend/common/keras_tensor.py

Lines changed: 38 additions & 0 deletions
@@ -35,9 +35,17 @@ def __init__(
         ragged=False,
         record_history=True,
         name=None,
+        **kwargs,
     ):
         from keras.src import backend

+        ragged_rank = kwargs.pop("ragged_rank", None)
+        row_splits_dtype = kwargs.pop("row_splits_dtype", None)
+        if kwargs:
+            raise TypeError(
+                f"Unexpected keyword arguments: {', '.join(kwargs.keys())}"
+            )
+
         self._shape = backend.standardize_shape(shape)
         self._dtype = backend.standardize_dtype(dtype)
         self._sparse = bool(sparse)

@@ -47,6 +55,14 @@ def __init__(
                 "KerasTensor cannot have `sparse=True` and `ragged=True` at "
                 "the same time."
             )
+        self._ragged_rank = (
+            int(ragged_rank) if ragged_rank is not None else None
+        )
+        self._row_splits_dtype = (
+            backend.standardize_dtype(row_splits_dtype)
+            if row_splits_dtype is not None
+            else None
+        )
         self.name = name or auto_name(self.__class__.__name__)
         self.record_history = record_history

@@ -83,6 +99,28 @@ def sparse(self, value):
             "create a new instance of KerasTensor for this."
         )

+    @property
+    def ragged_rank(self):
+        return self._ragged_rank
+
+    @ragged_rank.setter
+    def ragged_rank(self, value):
+        raise AttributeError(
+            "The `ragged_rank` attribute of KerasTensor is immutable. One "
+            "should create a new instance of KerasTensor for this."
+        )
+
+    @property
+    def row_splits_dtype(self):
+        return self._row_splits_dtype
+
+    @row_splits_dtype.setter
+    def row_splits_dtype(self, value):
+        raise AttributeError(
+            "The `row_splits_dtype` attribute of KerasTensor is immutable. One "
+            "should create a new instance of KerasTensor for this."
+        )
+
     @property
     def ragged(self):
         return self._ragged
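In effect, `KerasTensor` now accepts the extra ragged metadata through `**kwargs` and exposes it as read-only properties. A minimal sketch of the resulting behavior (the shape and dtypes here are illustrative, not taken from the commit):

import keras

# ragged_rank and row_splits_dtype travel through **kwargs and are
# stored as immutable attributes alongside the existing ragged flag.
x = keras.KerasTensor(
    shape=(None, None, 3),
    dtype="float32",
    ragged=True,
    ragged_rank=1,             # number of ragged dimensions
    row_splits_dtype="int64",  # dtype of the row-partition tensor
)
print(x.ragged_rank)       # 1
print(x.row_splits_dtype)  # "int64"

try:
    x.ragged_rank = 2  # the setter always raises
except AttributeError as err:
    print(err)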

keras/src/backend/jax/trainer.py

Lines changed: 6 additions & 4 deletions
@@ -1026,10 +1026,12 @@ def _get_iterator(self):
         distribution = distribution_lib.distribution()
         if distribution is not None:
             return self._get_distributed_iterator(distribution)
-
-        return self._prefetch_numpy_iterator(
-            self.data_adapter.get_jax_iterator()
-        )
+        if self.data_adapter.builtin_prefetch:
+            return self.data_adapter.get_jax_iterator()
+        else:
+            return self._prefetch_numpy_iterator(
+                self.data_adapter.get_jax_iterator()
+            )

     def _get_distributed_iterator(self, distribution):
         """Lazily compute layouts to reduce host to device transfer latency."""

keras/src/trainers/data_adapters/__init__.py

Lines changed: 41 additions & 0 deletions
@@ -8,6 +8,9 @@
 from keras.src.trainers.data_adapters.generator_data_adapter import (
     GeneratorDataAdapter,
 )
+from keras.src.trainers.data_adapters.grain_dataset_adapter import (
+    GrainDatasetAdapter,
+)
 from keras.src.trainers.data_adapters.py_dataset_adapter import PyDatasetAdapter
 from keras.src.trainers.data_adapters.tf_dataset_adapter import TFDatasetAdapter
 from keras.src.trainers.data_adapters.torch_data_loader_adapter import (

@@ -111,6 +114,32 @@ def get_data_adapter(
         #     "data `x` was provided as a torch DataLoader. The DataLoader "
         #     "is expected to already be shuffled."
         # )
+    elif is_grain_dataset(x):
+        if y is not None:
+            raise_unsupported_arg(
+                "y", "the targets", "grain.Dataset and grain.DataLoader"
+            )
+        if sample_weight is not None:
+            raise_unsupported_arg(
+                "sample_weights",
+                "the sample weights",
+                "grain.Dataset and grain.DataLoader",
+            )
+        if class_weight is not None:
+            raise ValueError(
+                "Argument `class_weight` is not supported for grain.Dataset "
+                "and grain.DataLoader inputs. You can modify your "
+                "`__getitem__` method to return input tensor, label and "
+                "class_weight. "
+                f"Received: class_weight={class_weight}"
+            )
+        return GrainDatasetAdapter(x)
+        # TODO: should we warn or not?
+        # warnings.warn(
+        #     "`shuffle=True` was passed, but will be ignored since the "
+        #     "data `x` was provided as a grain dataset. The grain dataset "
+        #     "is expected to already be shuffled."
+        # )
     elif isinstance(x, types.GeneratorType):
         if y is not None:
             raise_unsupported_arg("y", "the targets", "PyDataset")

@@ -162,3 +191,15 @@ def is_torch_dataloader(x):
     ):
         return True
     return False
+
+
+def is_grain_dataset(x):
+    if hasattr(x, "__class__"):
+        for parent in x.__class__.__mro__:
+            if parent.__name__ in (
+                "MapDataset",
+                "IterDataset",
+                "DataLoader",
+            ) and "grain" in str(parent.__module__):
+                return True
+    return False
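Note that `is_grain_dataset` never imports `grain`: it walks the object's MRO and matches class names against the defining module, so the check costs nothing when grain is not installed. Assuming a recent grain release that exposes the `grain.MapDataset` dataset API (and using the internal `keras.src` path), usage would look roughly like:

import grain  # assumption: the dataset API is available at top level

from keras.src.trainers.data_adapters import is_grain_dataset

# Build a small grain pipeline; `.map` matches the commit's stated usage.
ds = grain.MapDataset.source([1, 2, 3]).map(lambda x: x * 2)

print(is_grain_dataset(ds))         # True: MapDataset comes from a grain module
print(is_grain_dataset([1, 2, 3]))  # False: plain lists have no grain ancestry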

keras/src/trainers/data_adapters/data_adapter.py

Lines changed: 15 additions & 0 deletions
@@ -46,6 +46,21 @@ def get_torch_dataloader(self):
         """
         raise NotImplementedError

+    @property
+    def builtin_prefetch(self):
+        """Whether the DataAdapter has built-in prefetching capabilities.
+
+        Prefetching is an optimization technique where data is loaded and
+        prepared in advance while the model is processing the current batch,
+        reducing training time by overlapping data loading with computation.
+
+        Returns:
+            bool: True if the DataAdapter implements its own prefetching
+            mechanism and handles data loading asynchronously. False if the
+            caller should implement prefetching externally.
+        """
+        return False
+
     @property
     def num_batches(self):
         """Return the size (number of batches) for the dataset created.

keras/src/trainers/data_adapters/data_adapter_utils.py

Lines changed: 65 additions & 14 deletions
@@ -133,18 +133,25 @@ def class_weight_to_sample_weights(y, class_weight):
     return sample_weight


-def get_tensor_spec(batches):
-    """Return the common tensor spec for a list of batches.
+def get_keras_tensor_spec(batches):
+    """Return the KerasTensor spec for a list of batches.
+
+    The spec is represented using `KerasTensor`, which can handle dense,
+    sparse or ragged tensors.

     Args:
         batches: list of structures of tensors. The structures must be
             identical, but the shape at each leaf may be different.
-    Returns: the common tensor spec for all the batches.
+
+    Returns:
+        A nested structure of `KerasTensor`.
     """
-    from keras.src.utils.module_utils import tensorflow as tf

     def get_single_tensor_spec(*tensors):
         x = tensors[0]
+        if not hasattr(x, "shape"):
+            # Try to convert to a numpy array.
+            x = np.array(x)
         rank = len(x.shape)
         if rank < 1:
             raise ValueError(

@@ -164,28 +171,72 @@ def get_single_tensor_spec(*tensors):
         for dims in zip(*[list(x.shape) for x in tensors]):
             dims_set = set(dims)
             shape.append(dims_set.pop() if len(dims_set) == 1 else None)
-        shape[0] = None  # batch size may not be static

         dtype = backend.standardize_dtype(x.dtype)
-        if isinstance(x, tf.RaggedTensor):
-            return tf.RaggedTensorSpec(
+        if is_tensorflow_ragged(x):
+            return backend.KerasTensor(
                 shape=shape,
                 dtype=dtype,
+                ragged=True,
                 ragged_rank=x.ragged_rank,
                 row_splits_dtype=x.row_splits.dtype,
             )
-        if (
-            isinstance(x, tf.SparseTensor)
-            or is_scipy_sparse(x)
-            or is_jax_sparse(x)
-        ):
-            return tf.SparseTensorSpec(shape=shape, dtype=dtype)
+        if is_tensorflow_sparse(x) or is_scipy_sparse(x) or is_jax_sparse(x):
+            return backend.KerasTensor(shape=shape, dtype=dtype, sparse=True)
         else:
-            return tf.TensorSpec(shape=shape, dtype=dtype)
+            return backend.KerasTensor(shape=shape, dtype=dtype)

     return tree.map_structure(get_single_tensor_spec, *batches)


+def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True):
+    """Convert a KerasTensor to a TensorSpec.
+
+    Args:
+        keras_tensor: A KerasTensor instance.
+        batch_axis_to_none: If `True`, the batch axis of the returned
+            tensor spec will be set to None. Defaults to `True`.
+    """
+    from keras.src.utils.module_utils import tensorflow as tf
+
+    if not isinstance(keras_tensor, backend.KerasTensor):
+        raise TypeError(
+            f"Expected a KerasTensor, but got {keras_tensor} of type "
+            f"{type(keras_tensor)}."
+        )
+    shape = list(keras_tensor.shape)
+    if batch_axis_to_none:
+        shape[0] = None
+    if keras_tensor.ragged:
+        return tf.RaggedTensorSpec(
+            shape=shape,
+            dtype=keras_tensor.dtype,
+            ragged_rank=keras_tensor.ragged_rank,
+            row_splits_dtype=keras_tensor.row_splits_dtype,
+        )
+    elif keras_tensor.sparse:
+        return tf.SparseTensorSpec(shape=shape, dtype=keras_tensor.dtype)
+    else:
+        return tf.TensorSpec(shape=shape, dtype=keras_tensor.dtype)
+
+
+def get_tensor_spec(batches):
+    """Return the common tensor spec for a list of batches.
+
+    The spec is represented using `tf.TensorSpec`, `tf.SparseTensorSpec` and
+    `tf.RaggedTensorSpec`.
+
+    Args:
+        batches: list of structures of tensors. The structures must be
+            identical, but the shape at each leaf may be different.
+
+    Returns:
+        A common tensor spec.
+    """
+    tensor_specs = get_keras_tensor_spec(batches)
+    return tree.map_structure(convert_to_tf_tensor_spec, tensor_specs)
+
+
 def get_jax_iterator(iterable):
     import jax
     import jax.experimental.sparse as jax_sparse