 
 from keras.src import tree
 from keras.src.api_export import keras_export
+from keras.src.utils import file_utils
 from keras.src.utils import io_utils
+from keras.src.utils.module_utils import grain
 from keras.src.utils.module_utils import tensorflow as tf
 
 
@@ -299,6 +301,17 @@ def is_torch_dataset(dataset):
     return False
 
 
+def is_grain_dataset(dataset):
+    if hasattr(dataset, "__class__"):
+        for parent in dataset.__class__.__mro__:
+            if parent.__name__ in (
+                "MapDataset",
+                "IterDataset",
+            ) and str(parent.__module__).startswith("grain._src.python"):
+                return True
+    return False
+
+
 def _rescale_dataset_split_sizes(left_size, right_size, total_length):
     """Rescale the dataset split sizes.
 
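A minimal usage sketch of the new MRO-based check (illustrative only, not part of the diff; assumes `grain` is installed and `is_grain_dataset` is imported from `keras.src.utils.dataset_utils`):

import grain

# Sketch: a small grain pipeline built from the public grain API.
ds = grain.MapDataset.source([0, 1, 2]).map(lambda x: x * 2)
is_grain_dataset(ds)         # True: a class from grain._src.python appears in the MRO
is_grain_dataset([0, 1, 2])  # False: a plain list has no grain ancestry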
@@ -476,6 +489,10 @@ def _get_type_spec(dataset):
         from torch.utils.data import Dataset as TorchDataset
 
         return TorchDataset
+    elif is_grain_dataset(dataset):
+        from grain import MapDataset
+
+        return MapDataset
     else:
         return None
 
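Sketch of the resulting dispatch (illustrative; only the grain branch is new, and the private helper is assumed to be called from within `keras.src.utils.dataset_utils`):

import grain

# The new elif branch returns the public grain class itself.
spec = _get_type_spec(grain.MapDataset.source([1, 2, 3]))
assert spec is grain.MapDataset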
@@ -525,10 +542,17 @@ def index_directory(
         - class_names: names of the classes corresponding to these labels, in
           order.
     """
+    if file_utils.is_remote_path(directory):
+        os_module = tf.io.gfile
+        path_module = tf.io.gfile
+    else:
+        os_module = os
+        path_module = os.path
+
     if labels == "inferred":
         subdirs = []
-        for subdir in sorted(tf.io.gfile.listdir(directory)):
-            if tf.io.gfile.isdir(tf.io.gfile.join(directory, subdir)):
+        for subdir in sorted(os_module.listdir(directory)):
+            if path_module.isdir(path_module.join(directory, subdir)):
                 if not subdir.startswith("."):
                     if subdir.endswith("/"):
                         subdir = subdir[:-1]
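The remote/local module selection above can be pictured in isolation (a sketch; the bucket path is a made-up example, assuming `file_utils.is_remote_path` treats `gs://` URIs as remote):

from keras.src.utils import file_utils

file_utils.is_remote_path("gs://some-bucket/images")  # True  -> listing/joining goes through tf.io.gfile
file_utils.is_remote_path("/tmp/images")              # False -> plain os / os.path is used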
@@ -566,7 +590,7 @@ def index_directory(
     results = []
     filenames = []
 
-    for dirpath in (tf.io.gfile.join(directory, subdir) for subdir in subdirs):
+    for dirpath in (path_module.join(directory, subdir) for subdir in subdirs):
         results.append(
             pool.apply_async(
                 index_subdirectory,
@@ -608,7 +632,7 @@ def index_directory(
     )
     pool.close()
     pool.join()
-    file_paths = [tf.io.gfile.join(directory, fname) for fname in filenames]
+    file_paths = [path_module.join(directory, fname) for fname in filenames]
 
     if shuffle:
         # Shuffle globally to erase macro-structure
@@ -623,8 +647,10 @@ def index_directory(
 
 
 def iter_valid_files(directory, follow_links, formats):
+    io_module = tf.io.gfile if file_utils.is_remote_path(directory) else os
+
     if not follow_links:
-        walk = tf.io.gfile.walk(directory)
+        walk = io_module.walk(directory)
     else:
         walk = os.walk(directory, followlinks=follow_links)
     for root, _, files in sorted(walk, key=lambda x: x[0]):
@@ -648,14 +674,18 @@ def index_subdirectory(directory, class_indices, follow_links, formats):
         paths, and `labels` is a list of integer labels corresponding
         to these files.
     """
+    path_module = (
+        tf.io.gfile if file_utils.is_remote_path(directory) else os.path
+    )
+
     dirname = os.path.basename(directory)
     valid_files = iter_valid_files(directory, follow_links, formats)
     labels = []
     filenames = []
     for root, fname in valid_files:
         labels.append(class_indices[dirname])
-        absolute_path = tf.io.gfile.join(root, fname)
-        relative_path = tf.io.gfile.join(
+        absolute_path = path_module.join(root, fname)
+        relative_path = path_module.join(
             dirname, os.path.relpath(absolute_path, directory)
         )
         filenames.append(relative_path)
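A worked example of the relative-path construction above, using hypothetical local paths:

import os

directory = "/data/flowers/daisy"                 # the class subdirectory being indexed
dirname = os.path.basename(directory)             # "daisy"
absolute_path = os.path.join(directory, "1.jpg")  # "/data/flowers/daisy/1.jpg"
os.path.join(dirname, os.path.relpath(absolute_path, directory))  # "daisy/1.jpg"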
@@ -700,7 +730,7 @@ def get_training_or_validation_split(samples, labels, validation_split, subset):
     return samples, labels
 
 
-def labels_to_dataset(labels, label_mode, num_classes):
+def labels_to_dataset_tf(labels, label_mode, num_classes):
     """Create a `tf.data.Dataset` from the list/tuple of labels.
 
     Args:
@@ -730,6 +760,51 @@ def labels_to_dataset(labels, label_mode, num_classes):
     return label_ds
 
 
+def labels_to_dataset_grain(labels, label_mode, num_classes):
+    """Create a `grain.MapDataset` from the list/tuple of labels.
+
+    Args:
+        labels: list/tuple of labels to be converted into a `grain.MapDataset`.
+        label_mode: String describing the encoding of `labels`. Options are:
+        - `"binary"` indicates that the labels (there can be only 2) are encoded
+          as `float32` scalars with values 0 or 1
+          (e.g. for `binary_crossentropy`).
+        - `"categorical"` means that the labels are mapped into a categorical
+          vector. (e.g. for `categorical_crossentropy` loss).
+        num_classes: number of classes of labels.
+
+    Returns:
+        A `grain.MapDataset` instance.
+    """
+    from keras.src import backend
+    from keras.src import ops
+
+    if label_mode not in ("binary", "categorical", "int"):
+        raise ValueError(
+            f"Invalid `label_mode`: {label_mode}. "
+            "Expected one of: 'binary', 'categorical', 'int'."
+        )
+
+    def preprocess_labels_in_cpu(label_mode, x, num_classes):
+        with backend.device_scope("cpu"):
+            if label_mode == "binary":
+                return ops.expand_dims(
+                    ops.convert_to_tensor(x, dtype="float32"), axis=-1
+                )
+            elif label_mode == "categorical":
+                return ops.one_hot(
+                    ops.convert_to_tensor(x, dtype="int32"), num_classes
+                )
+            else:
+                return ops.convert_to_tensor(x, dtype="int32")
+
+    label_ds = grain.MapDataset.source(labels)
+    label_ds = label_ds.map(
+        lambda x: preprocess_labels_in_cpu(label_mode, x, num_classes),
+    )
+    return label_ds
+
+
 def check_validation_split_arg(validation_split, subset, shuffle, seed):
     """Raise errors in case of invalid argument values.
 
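A minimal usage sketch of the grain-backed helper (illustrative only; assumes `grain` is installed, that `labels_to_dataset_grain` is imported from `keras.src.utils.dataset_utils`, and that elements are accessed by index as with any `grain.MapDataset`):

labels = [0, 2, 1]
label_ds = labels_to_dataset_grain(labels, label_mode="categorical", num_classes=3)
len(label_ds)  # 3: one element per input label
label_ds[0]    # a length-3 one-hot vector for label 0, computed on CPU via backend ops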