Skip to content

Commit 7da416d

Browse files
Add Grain support to image_dataset_from_directory and text_dataset_from_directory (#21593)
* Add Grain support to `image_dataset_from_directory` and `text_dataset_from_directory`. * Fix channels_first bug. * Refine the docstrings.
1 parent 693764a commit 7da416d

8 files changed

+868
-185
lines changed

keras/src/utils/audio_dataset_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ def paths_and_labels_to_dataset(
411411
"""Constructs a fixed-size dataset of audio and labels."""
412412
path_ds = tf.data.Dataset.from_tensor_slices(file_paths)
413413
if label_mode:
414-
label_ds = dataset_utils.labels_to_dataset(
414+
label_ds = dataset_utils.labels_to_dataset_tf(
415415
labels, label_mode, num_classes
416416
)
417417
ds = tf.data.Dataset.zip((path_ds, label_ds))

keras/src/utils/dataset_utils.py

Lines changed: 83 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
from keras.src import tree
1010
from keras.src.api_export import keras_export
11+
from keras.src.utils import file_utils
1112
from keras.src.utils import io_utils
13+
from keras.src.utils.module_utils import grain
1214
from keras.src.utils.module_utils import tensorflow as tf
1315

1416

@@ -299,6 +301,17 @@ def is_torch_dataset(dataset):
299301
return False
300302

301303

304+
def is_grain_dataset(dataset):
305+
if hasattr(dataset, "__class__"):
306+
for parent in dataset.__class__.__mro__:
307+
if parent.__name__ in (
308+
"MapDataset",
309+
"IterDataset",
310+
) and str(parent.__module__).startswith("grain._src.python"):
311+
return True
312+
return False
313+
314+
302315
def _rescale_dataset_split_sizes(left_size, right_size, total_length):
303316
"""Rescale the dataset split sizes.
304317
@@ -476,6 +489,10 @@ def _get_type_spec(dataset):
476489
from torch.utils.data import Dataset as TorchDataset
477490

478491
return TorchDataset
492+
elif is_grain_dataset(dataset):
493+
from grain import MapDataset
494+
495+
return MapDataset
479496
else:
480497
return None
481498

@@ -525,10 +542,17 @@ def index_directory(
525542
- class_names: names of the classes corresponding to these labels, in
526543
order.
527544
"""
545+
if file_utils.is_remote_path(directory):
546+
os_module = tf.io.gfile
547+
path_module = tf.io.gfile
548+
else:
549+
os_module = os
550+
path_module = os.path
551+
528552
if labels == "inferred":
529553
subdirs = []
530-
for subdir in sorted(tf.io.gfile.listdir(directory)):
531-
if tf.io.gfile.isdir(tf.io.gfile.join(directory, subdir)):
554+
for subdir in sorted(os_module.listdir(directory)):
555+
if path_module.isdir(path_module.join(directory, subdir)):
532556
if not subdir.startswith("."):
533557
if subdir.endswith("/"):
534558
subdir = subdir[:-1]
@@ -566,7 +590,7 @@ def index_directory(
566590
results = []
567591
filenames = []
568592

569-
for dirpath in (tf.io.gfile.join(directory, subdir) for subdir in subdirs):
593+
for dirpath in (path_module.join(directory, subdir) for subdir in subdirs):
570594
results.append(
571595
pool.apply_async(
572596
index_subdirectory,
@@ -608,7 +632,7 @@ def index_directory(
608632
)
609633
pool.close()
610634
pool.join()
611-
file_paths = [tf.io.gfile.join(directory, fname) for fname in filenames]
635+
file_paths = [path_module.join(directory, fname) for fname in filenames]
612636

613637
if shuffle:
614638
# Shuffle globally to erase macro-structure
@@ -623,8 +647,10 @@ def index_directory(
623647

624648

625649
def iter_valid_files(directory, follow_links, formats):
650+
io_module = tf.io.gfile if file_utils.is_remote_path(directory) else os
651+
626652
if not follow_links:
627-
walk = tf.io.gfile.walk(directory)
653+
walk = io_module.walk(directory)
628654
else:
629655
walk = os.walk(directory, followlinks=follow_links)
630656
for root, _, files in sorted(walk, key=lambda x: x[0]):
@@ -648,14 +674,18 @@ def index_subdirectory(directory, class_indices, follow_links, formats):
648674
paths, and `labels` is a list of integer labels corresponding
649675
to these files.
650676
"""
677+
path_module = (
678+
tf.io.gfile if file_utils.is_remote_path(directory) else os.path
679+
)
680+
651681
dirname = os.path.basename(directory)
652682
valid_files = iter_valid_files(directory, follow_links, formats)
653683
labels = []
654684
filenames = []
655685
for root, fname in valid_files:
656686
labels.append(class_indices[dirname])
657-
absolute_path = tf.io.gfile.join(root, fname)
658-
relative_path = tf.io.gfile.join(
687+
absolute_path = path_module.join(root, fname)
688+
relative_path = path_module.join(
659689
dirname, os.path.relpath(absolute_path, directory)
660690
)
661691
filenames.append(relative_path)
@@ -700,7 +730,7 @@ def get_training_or_validation_split(samples, labels, validation_split, subset):
700730
return samples, labels
701731

702732

703-
def labels_to_dataset(labels, label_mode, num_classes):
733+
def labels_to_dataset_tf(labels, label_mode, num_classes):
704734
"""Create a `tf.data.Dataset` from the list/tuple of labels.
705735
706736
Args:
@@ -730,6 +760,51 @@ def labels_to_dataset(labels, label_mode, num_classes):
730760
return label_ds
731761

732762

763+
def labels_to_dataset_grain(labels, label_mode, num_classes):
764+
"""Create a `grain.MapDataset` from the list/tuple of labels.
765+
766+
Args:
767+
labels: list/tuple of labels to be converted into a `grain.MapDataset`.
768+
label_mode: String describing the encoding of `labels`. Options are:
769+
- `"binary"` indicates that the labels (there can be only 2) are encoded
770+
as `float32` scalars with values 0 or 1
771+
(e.g. for `binary_crossentropy`).
772+
- `"categorical"` means that the labels are mapped into a categorical
773+
vector. (e.g. for `categorical_crossentropy` loss).
774+
num_classes: number of classes of labels.
775+
776+
Returns:
777+
A `grain.MapDataset` instance.
778+
"""
779+
from keras.src import backend
780+
from keras.src import ops
781+
782+
if label_mode not in ("binary", "categorical", "int"):
783+
raise ValueError(
784+
f"Invalid `label_mode`: {label_mode}. "
785+
"Expected one of: 'binary', 'categorical', 'int'."
786+
)
787+
788+
def preprocess_labels_in_cpu(label_mode, x, num_classes):
789+
with backend.device_scope("cpu"):
790+
if label_mode == "binary":
791+
return ops.expand_dims(
792+
ops.convert_to_tensor(x, dtype="float32"), axis=-1
793+
)
794+
elif label_mode == "categorical":
795+
return ops.one_hot(
796+
ops.convert_to_tensor(x, dtype="int32"), num_classes
797+
)
798+
else:
799+
return ops.convert_to_tensor(x, dtype="int32")
800+
801+
label_ds = grain.MapDataset.source(labels)
802+
label_ds = label_ds.map(
803+
lambda x: preprocess_labels_in_cpu(label_mode, x, num_classes),
804+
)
805+
return label_ds
806+
807+
733808
def check_validation_split_arg(validation_split, subset, shuffle, seed):
734809
"""Raise errors in case of invalid argument values.
735810

keras/src/utils/grain_utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from keras.src import backend
2+
from keras.src import tree
3+
4+
5+
def make_batch(values):
6+
from keras.src import ops
7+
8+
if not values:
9+
raise ValueError("Cannot batch 0 values. Please file a bug.")
10+
11+
with backend.device_scope("cpu"):
12+
return tree.map_structure(lambda *xs: ops.stack(xs), *values)
13+
14+
15+
def make_string_batch(values):
16+
from keras.src import ops
17+
18+
if not values:
19+
raise ValueError("Cannot batch 0 values. Please file a bug.")
20+
21+
def batch_fn(*xs):
22+
if isinstance(xs[0], str):
23+
if backend.backend() == "tensorflow":
24+
import tensorflow as tf
25+
26+
xs = [tf.convert_to_tensor(x, dtype=tf.string) for x in xs]
27+
xs = tf.stack(xs)
28+
return xs
29+
else:
30+
return ops.stack(xs)
31+
32+
with backend.device_scope("cpu"):
33+
return tree.map_structure(batch_fn, *values)

0 commit comments

Comments
 (0)