@@ -1076,15 +1076,17 @@ def _iter(self):
         num_examples_to_skip = 0
         iterator = iter(self.ex_iterable)
 
+        # We use the same logic as in Dataset.map, but with less features/formatting
+        # since they're handled by FormattedExamplesIterable
+
         if self.formatting:
             formatter = get_formatter(self.formatting.format_type)
-            format_dict = (
-                formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
-            )
+            format_dict = formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else None
         else:
             format_dict = None
 
         def iter_batched_inputs():
+            nonlocal current_idx
             for key, example in iterator:
                 # If `batched`, first build the batch, if `batch_size` is None or <=0, then the batch is the whole dataset
                 iterator_batch = (
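After this change, `format_dict` is only a tensorization hook: tensor formats get `recursive_tensorize`, and everything else passes through untouched, since the python-object casting that `cast_to_python_objects` used to do is now handled by `FormattedExamplesIterable`. A rough sketch of the same dispatch, outside the diff (assumes the `torch` extra is installed; `"values"` is a made-up column):

```python
from datasets.formatting import TensorFormatter, get_formatter

formatter = get_formatter("torch")
# tensor formats get a conversion callable, everything else is a no-op
format_dict = formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else None

batch = {"values": [[1, 2], [3, 4]]}
if format_dict:
    batch = format_dict(batch)  # lists become torch.Tensors, stacked when shapes line up
```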
@@ -1104,17 +1106,21 @@ def iter_batched_inputs():
                 ):  # ignore last batch
                     return
                 batch = _examples_to_batch(examples)
+                # we need to format here in case we need to stack tensors together
                 batch = format_dict(batch) if format_dict else batch
                 indices = [current_idx + i for i in range(len(key_examples_list))]
+                current_idx += len(indices)
                 yield indices, (key, batch)
 
         def iter_inputs():
+            nonlocal current_idx
             for key, example in iterator:
                 # If not batched, we can apply the transform and yield the example directly
                 # first copy the example, since we might drop some keys
                 example = dict(example)
-                example = format_dict(example) if format_dict else example
-                yield current_idx, (key, example)
+                # no need to do formatting here
+                current_idx += 1
+                yield current_idx - 1, (key, example)
 
         def validate_function_output(processed_inputs):
             if self.batched and processed_inputs:
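The index bookkeeping now lives in the input generators themselves: both declare `nonlocal current_idx` and advance it as they yield, so the output loop receives ready-made indices instead of recomputing them. The pattern in isolation (a minimal sketch with made-up data):

```python
def make_iter(items):
    current_idx = 0

    def iter_inputs():
        nonlocal current_idx  # mutate the enclosing counter, as in the diff above
        for item in items:
            current_idx += 1
            yield current_idx - 1, item

    return iter_inputs()

assert list(make_iter(["a", "b"])) == [(0, "a"), (1, "b")]
```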
@@ -1147,17 +1153,7 @@ def prepare_outputs(key_example, inputs, processed_inputs):
                     if processed_inputs is key_example[1] and c in processed_inputs:
                         del processed_inputs[c]
             transformed_inputs = {**inputs, **processed_inputs}
-            if self.features:
-                for c in self.features.keys():
-                    if c not in transformed_inputs:
-                        transformed_inputs[c] = (
-                            [None] * len(transformed_inputs[next(iter(processed_inputs))]) if self.batched else None
-                        )
-                transformed_inputs = (
-                    self.features.decode_batch(transformed_inputs)
-                    if self.batched
-                    else self.features.decode_example(transformed_inputs)
-                )
+            # no need to do features decoding here
             return transformed_inputs
 
         def apply_function(key_example, indices):
@@ -1185,6 +1181,11 @@ def iter_outputs():
             nonlocal tasks, loop
             inputs_iterator = iter_batched_inputs() if self.batched else iter_inputs()
             if inspect.iscoroutinefunction(self.function):
+                if self._state_dict:
+                    previous_state = self.ex_iterable.state_dict()
+                    self._state_dict["previous_state"] = previous_state
+                    previous_state_task = None
+                    previous_state_example_idx = self._state_dict["previous_state_example_idx"]
                 indices: Union[list[int], list[list[int]]] = []
                 for i, key_example in inputs_iterator:
                     indices.append(i)
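For context, this branch runs when a coroutine function is passed to `IterableDataset.map`; the new lines seed the resume state before any tasks are scheduled. A usage sketch (the function and column are made up):

```python
import asyncio
from datasets import Dataset

async def shout(example):
    await asyncio.sleep(0.01)  # stand-in for real async work, e.g. an HTTP call
    return {"text": example["text"].upper()}

ds = Dataset.from_dict({"text": ["a", "b"]}).to_iterable_dataset()
print(list(ds.map(shout)))  # coroutine calls are awaited concurrently under the hood
```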
@@ -1198,42 +1199,57 @@ def iter_outputs():
                         done, pending = loop.run_until_complete(
                             asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                         )
+                        if len(tasks) >= 10 * config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL:
+                            loop.run_until_complete(tasks[0])
                     # yield finished tasks
                     while tasks and tasks[0].done():
-                        yield indices.pop(0), tasks.pop(0).result()
+                        i, task = indices.pop(0), tasks.pop(0)
+                        yield i, task.result()
+                        if self._state_dict and task is previous_state_task:
+                            self._state_dict["previous_state"] = previous_state
+                            self._state_dict["num_examples_since_previous_state"] = 0
+                            self._state_dict["previous_state_example_idx"] = previous_state_example_idx
+                            previous_state, previous_state_task = None, None
+                    # checkpoint
+                    if self._state_dict and previous_state_task is None and tasks:
+                        previous_state = self.ex_iterable.state_dict()
+                        previous_state_task = tasks[-1]
+                        previous_state_example_idx = current_idx
                 while tasks:
                     yield indices[0], loop.run_until_complete(tasks[0])
                     indices.pop(0), tasks.pop(0)
             else:
-                for i, key_example in inputs_iterator:
-                    yield i, apply_function(key_example, i)
-
-        try:
-            if self.batched:
                 if self._state_dict:
-                    self._state_dict["previous_state"] = self.ex_iterable.state_dict()
-                    self._state_dict["num_examples_since_previous_state"] = 0
-                    self._state_dict["previous_state_example_idx"] = current_idx
-                for key, transformed_batch in iter_outputs():
-                    # yield one example at a time from the transformed batch
-                    for example in _batch_to_examples(transformed_batch):
-                        current_idx += 1
-                        if self._state_dict:
-                            self._state_dict["num_examples_since_previous_state"] += 1
-                        if num_examples_to_skip > 0:
-                            num_examples_to_skip -= 1
-                            continue
-                        yield key, example
-                if self._state_dict:
+                    if self.batched:
                         self._state_dict["previous_state"] = self.ex_iterable.state_dict()
                         self._state_dict["num_examples_since_previous_state"] = 0
                         self._state_dict["previous_state_example_idx"] = current_idx
-            else:
-                for key, transformed_example in iter_outputs():
-                    current_idx += 1
+                for i, key_example in inputs_iterator:
                     if self._state_dict:
-                        self._state_dict["previous_state_example_idx"] += 1
-                    yield key, transformed_example
+                        if not self.batched:
+                            self._state_dict["previous_state_example_idx"] = current_idx
+                    yield i, apply_function(key_example, i)
+                    if self._state_dict:
+                        if self.batched:
+                            self._state_dict["previous_state"] = self.ex_iterable.state_dict()
+                            self._state_dict["num_examples_since_previous_state"] = 0
+                            self._state_dict["previous_state_example_idx"] = current_idx
+
+        try:
+            outputs = iter_outputs()
+            if self.batched:
+                outputs = (
+                    (key, transformed_example)
+                    for key, transformed_batch in outputs
+                    for transformed_example in _batch_to_examples(transformed_batch)
+                )
+            for key, transformed_example in outputs:
+                if self._state_dict and self._state_dict["previous_state"] is not None:
+                    self._state_dict["num_examples_since_previous_state"] += 1
+                if num_examples_to_skip > 0:
+                    num_examples_to_skip -= 1
+                    continue
+                yield key, transformed_example
         except (Exception, KeyboardInterrupt):
             if loop:
                 logger.debug(f"Canceling {len(tasks)} async tasks.")
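All of the `previous_state` juggling above exists so that checkpoints stay correct while async tasks are still in flight: a snapshot may only be committed once the task that was pending at snapshot time has actually been yielded. The public API this bookkeeping serves is the documented `state_dict`/`load_state_dict` pair:

```python
from datasets import Dataset

ds = Dataset.from_dict({"a": list(range(6))}).to_iterable_dataset(num_shards=3)
for idx, example in enumerate(ds):
    if idx == 2:
        checkpoint = ds.state_dict()  # snapshot mid-iteration
        break

ds.load_state_dict(checkpoint)  # a fresh iteration resumes after example idx=2
```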
@@ -1800,7 +1816,7 @@ def _init_state_dict(self) -> dict:
 
     def __iter__(self):
         if not self.formatting or self.formatting.is_table:
-            formatter = PythonFormatter()
+            formatter = PythonFormatter(features=self._features if not self.ex_iterable.is_typed else None)
         else:
             formatter = get_formatter(
                 self.formatting.format_type,
@@ -1817,15 +1833,17 @@ def __iter__(self):
         format_dict = (
             formatter.recursive_tensorize
             if isinstance(formatter, TensorFormatter)
-            else cast_to_python_objects  # cast in case features is None
+            else None  # cast in case features is None
         )
         for key, example in self.ex_iterable:
             # don't apply feature types if already applied by ex_iterable (e.g. in case of chained with_format)
             if self.features and not self.ex_iterable.is_typed:
                 example = _apply_feature_types_on_example(
                     example, self.features, token_per_repo_id=self.token_per_repo_id
                 )
-            yield key, format_dict(example)
+            if format_dict:
+                example = format_dict(example)
+            yield key, example
 
     def _iter_arrow(self) -> Iterator[tuple[Key, pa.Table]]:
         if not self.features:
@@ -2049,7 +2067,7 @@ def __setstate__(self, d):
         _maybe_add_torch_iterable_dataset_parent_class(self.__class__)
 
     def _head(self, n=5):
-        return _examples_to_batch(list(self.take(n)))
+        return next(iter(self.iter(batch_size=n)))
 
     @property
     def epoch(self) -> int:
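`_head` now reuses the public batched-iteration path instead of `take`, so the preview batch is built and formatted exactly like any other batch. Its user-visible equivalent, with a throwaway dataset:

```python
from datasets import Dataset

ds = Dataset.from_dict({"x": [1, 2, 3, 4]}).to_iterable_dataset()
head = next(iter(ds.iter(batch_size=2)))  # what _head(n=2) now returns
assert head == {"x": [1, 2]}
```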
@@ -2111,15 +2129,8 @@ def _iter_pytorch(self):
         if self._starting_state_dict:
             ex_iterable.load_state_dict(self._starting_state_dict)
 
-        if self._formatting:
-            formatter = get_formatter(self._formatting.format_type, features=self.features)
-            format_dict = (
-                formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
-            )
-        else:
-            format_dict = None
-
         if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table):
+            formatter = get_formatter(self._formatting.format_type, features=self.features)
             if ex_iterable.iter_arrow:
                 iterator = ex_iterable.iter_arrow()
             else:
@@ -2129,13 +2140,8 @@ def _iter_pytorch(self):
             return
         else:
             for key, example in ex_iterable:
-                if self.features and not ex_iterable.is_typed:
-                    # `IterableDataset` automatically fills missing columns with None.
-                    # This is done with `_apply_feature_types_on_example`.
-                    example = _apply_feature_types_on_example(
-                        example, self.features, token_per_repo_id=self._token_per_repo_id
-                    )
-                yield format_dict(example) if format_dict else example
+                # no need to format thanks to FormattedExamplesIterable
+                yield example
         logger.debug(
             f"{_log_prefix}dataloader worker#{worker_info.id}, ': Finished iterating over {len(shards_indices)}/{ex_iterable.num_shards} shards."
         )
@@ -2191,6 +2197,14 @@ def _prepare_ex_iterable_for_iteration(
             )
             ex_iterable = StepExamplesIterable(ex_iterable, step=world_size, offset=rank)
 
+        if self._formatting or (self.features and ex_iterable.features != self.features):
+            ex_iterable = FormattedExamplesIterable(
+                ex_iterable,
+                formatting=self._formatting,
+                features=self.features,
+                token_per_repo_id=self._token_per_repo_id,
+            )
+
         self._state_dict = ex_iterable._init_state_dict()
         if self._starting_state_dict:
             ex_iterable.load_state_dict(self._starting_state_dict)
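This is the heart of the refactor: formatting and feature decoding are applied once, by wrapping the prepared `ex_iterable` in `FormattedExamplesIterable`, instead of being re-implemented in `__iter__`, `_iter_pytorch`, and `iter`. Downstream behavior is unchanged, e.g. (assumes torch is installed; the column is made up):

```python
from datasets import Dataset

ds = Dataset.from_dict({"x": [[1.0, 2.0], [3.0, 4.0]]}).to_iterable_dataset()
for example in ds.with_format("torch"):
    print(type(example["x"]))  # <class 'torch.Tensor'>, produced by the wrapped iterable
    break
```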
@@ -2207,15 +2221,8 @@ def __iter__(self):
             return
 
         ex_iterable = self._prepare_ex_iterable_for_iteration()
-        if self._formatting:
-            formatter = get_formatter(self._formatting.format_type, features=self.features)
-            format_dict = (
-                formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
-            )
-        else:
-            format_dict = None
-
         if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table):
+            formatter = get_formatter(self._formatting.format_type, features=self.features)
             if ex_iterable.iter_arrow:
                 iterator = ex_iterable.iter_arrow()
             else:
@@ -2225,13 +2232,8 @@ def __iter__(self):
             return
 
         for key, example in ex_iterable:
-            if self.features and not ex_iterable.is_typed:
-                # `IterableDataset` automatically fills missing columns with None.
-                # This is done with `_apply_feature_types_on_example`.
-                example = _apply_feature_types_on_example(
-                    example, self.features, token_per_repo_id=self._token_per_repo_id
-                )
-            yield format_dict(example) if format_dict else example
+            # no need to format thanks to FormattedExamplesIterable
+            yield example
 
     def iter(self, batch_size: int, drop_last_batch: bool = False):
         """Iterate through the batches of size `batch_size`.
@@ -2244,9 +2246,7 @@ def iter(self, batch_size: int, drop_last_batch: bool = False):
 
         if self._formatting:
             formatter = get_formatter(self._formatting.format_type, features=self.features)
-            format_dict = (
-                formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
-            )
+            format_dict = formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else None
         else:
             format_dict = None
@@ -2267,10 +2267,7 @@ def iter(self, batch_size: int, drop_last_batch: bool = False):
             if drop_last_batch and len(examples) < batch_size:  # ignore last batch
                 return
             batch = _examples_to_batch(examples)
-            if self.features and not ex_iterable.is_typed:
-                # `IterableDataset` automatically fills missing columns with None.
-                # This is done with `_apply_feature_types_on_batch`.
-                batch = _apply_feature_types_on_batch(batch, self.features, token_per_repo_id=self._token_per_repo_id)
+            # we need to format here in case we need to stack tensors together
             yield format_dict(batch) if format_dict else batch
 
     @staticmethod
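With feature decoding moved out, the remaining `format_dict(batch)` call in `iter` is purely about tensorization, which is also what stacks per-column lists into batch tensors. A sketch under the torch format (assumes torch is installed and rows have equal length):

```python
from datasets import Dataset

ds = Dataset.from_dict({"x": [[1, 2], [3, 4], [5, 6]]}).to_iterable_dataset()
batch = next(iter(ds.with_format("torch").iter(batch_size=2)))
print(batch["x"].shape)  # torch.Size([2, 2]): rows stacked into a single tensor
```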
@@ -3241,7 +3238,13 @@ def batch(self, batch_size: int, drop_last_batch: bool = False) -> "IterableDataset":
         def batch_fn(unbatched):
             return {k: [v] for k, v in unbatched.items()}
 
-        return self.map(batch_fn, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch)
+        if self.features:
+            features = Features({col: [feature] for col, feature in self.features.items()})
+        else:
+            features = None
+        return self.map(
+            batch_fn, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch, features=features
+        )
 
 
 def _concatenate_iterable_datasets(
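`batch` now passes explicit output features to `map`, wrapping each column type in a list (the `[feature]` shorthand for a sequence), so the batched stream keeps a usable schema instead of falling back to `features=None`. Roughly:

```python
from datasets import Dataset

ds = Dataset.from_dict({"x": [1, 2, 3, 4]}).to_iterable_dataset()
batched = ds.batch(batch_size=2)
print(batched.features)  # roughly {'x': [Value('int64')]}: one list column per original column
print(next(iter(batched)))  # {'x': [1, 2]}
```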