4646 tf .data .Dataset ,
4747 tuple [tf .data .Dataset , tf .data .Dataset ],
4848 tuple [tf .data .Dataset , Mapping [str , tf .data .Dataset ]],
49- iterator .DatasetIterator ,
50- tuple [iterator .DatasetIterator , iterator .DatasetIterator ],
51- tuple [iterator .DatasetIterator , Mapping [str , iterator .DatasetIterator ]],
49+ iterator .Iterator ,
50+ tuple [iterator .Iterator , iterator .Iterator ],
51+ tuple [iterator .Iterator , Mapping [str , iterator .Iterator ]],
5252)
5353MetaT = TypeVar ("MetaT" )
5454Logs = Any # Any metric logs returned by the training or evaluation task.
@@ -120,9 +120,9 @@ def run_experiment(
120120
121121def get_iterators (
122122 datasets : DatasetT ,
123- ) -> tuple [iterator .DatasetIterator , Mapping [str , iterator .DatasetIterator ]]:
123+ ) -> tuple [iterator .Iterator , Mapping [str , iterator .Iterator ]]:
124124 """Creates and unpacks the datasets returned by the task."""
125- if isinstance (datasets , (iterator .DatasetIterator , tf .data .Dataset )):
125+ if isinstance (datasets , (iterator .Iterator , tf .data .Dataset )):
126126 if isinstance (datasets , tf .data .Dataset ):
127127 datasets = iterator .TFDatasetIterator (datasets )
128128 return datasets , {}
@@ -133,7 +133,7 @@ def get_iterators(
133133 )
134134
135135 train_dataset , eval_datasets = datasets
136- if isinstance (train_dataset , (iterator .DatasetIterator , tf .data .Dataset )):
136+ if isinstance (train_dataset , (iterator .Iterator , tf .data .Dataset )):
137137 if isinstance (train_dataset , tf .data .Dataset ):
138138 train_dataset = iterator .TFDatasetIterator (train_dataset )
139139 else :
@@ -143,7 +143,7 @@ def get_iterators(
143143 f" { type (train_dataset )} ."
144144 )
145145
146- if isinstance (eval_datasets , (iterator .DatasetIterator , tf .data .Dataset )):
146+ if isinstance (eval_datasets , (iterator .Iterator , tf .data .Dataset )):
147147 if isinstance (eval_datasets , tf .data .Dataset ):
148148 eval_datasets = iterator .TFDatasetIterator (eval_datasets )
149149 return train_dataset , {"" : eval_datasets }
@@ -162,7 +162,7 @@ def get_iterators(
162162 }
163163
164164 if not all (
165- isinstance (v , iterator .DatasetIterator ) for v in eval_datasets .values ()
165+ isinstance (v , iterator .Iterator ) for v in eval_datasets .values ()
166166 ):
167167 raise ValueError (
168168 "Expected all values in the evaluation datasets mapping to be either"
0 commit comments