@@ -212,7 +212,7 @@ def __init__(self, generate_examples_fn: Callable[..., tuple[Key, dict]], kwargs
212212 self .kwargs = kwargs
213213
214214 def _init_state_dict (self ) -> dict :
215- self ._state_dict = {"shard_idx" : 0 , "shard_example_idx" : 0 }
215+ self ._state_dict = {"shard_idx" : 0 , "shard_example_idx" : 0 , "type" : self . __class__ . __name__ }
216216 return self ._state_dict
217217
218218 def __iter__ (self ):
@@ -250,7 +250,7 @@ def __init__(
250250 self .generator = deepcopy (generator )
251251
252252 def _init_state_dict (self ) -> dict :
253- self ._state_dict = {"shard_idx" : 0 , "shard_example_idx" : 0 }
253+ self ._state_dict = {"shard_idx" : 0 , "shard_example_idx" : 0 , "type" : self . __class__ . __name__ }
254254 return self ._state_dict
255255
256256 def __iter__ (self ):
@@ -290,7 +290,7 @@ def iter_arrow(self):
290290 return self ._iter_arrow
291291
292292 def _init_state_dict (self ) -> dict :
293- self ._state_dict = {"shard_idx" : 0 , "shard_example_idx" : 0 }
293+ self ._state_dict = {"shard_idx" : 0 , "shard_example_idx" : 0 , "type" : self . __class__ . __name__ }
294294 return self ._state_dict
295295
296296 def __iter__ (self ):
@@ -357,7 +357,7 @@ def __init__(
357357 self .generator = deepcopy (generator )
358358
359359 def _init_state_dict (self ) -> dict :
360- self ._state_dict = {"shard_idx" : 0 , "shard_example_idx" : 0 }
360+ self ._state_dict = {"shard_idx" : 0 , "shard_example_idx" : 0 , "type" : self . __class__ . __name__ }
361361 return self ._state_dict
362362
363363 def __iter__ (self ):
@@ -437,11 +437,12 @@ def features(self):
437437
438438 def _init_state_dict (self ) -> dict :
439439 self ._state_dict = {
440- "ex_iterable " : self .ex_iterable ._init_state_dict (),
440+ "examples_iterable " : self .ex_iterable ._init_state_dict (),
441441 "previous_state" : None ,
442442 "batch_idx" : 0 ,
443443 "num_chunks_since_previous_state" : 0 ,
444444 "cropped_chunk_length" : 0 ,
445+ "type" : self .__class__ .__name__ ,
445446 }
446447 return self ._state_dict
447448
@@ -680,6 +681,7 @@ def _init_state_dict(self) -> dict:
680681 "ex_iterables" : [ex_iterable ._init_state_dict () for ex_iterable in self .ex_iterables ],
681682 "previous_states" : [None ] * len (self .ex_iterables ),
682683 "is_exhausted" : [False ] * len (self .ex_iterables ),
684+ "type" : self .__class__ .__name__ ,
683685 }
684686 return self ._state_dict
685687
@@ -778,6 +780,7 @@ def _init_state_dict(self) -> dict:
778780 self ._state_dict = {
779781 "ex_iterable_idx" : 0 ,
780782 "ex_iterables" : [ex_iterable ._init_state_dict () for ex_iterable in self .ex_iterables ],
783+ "type" : self .__class__ .__name__ ,
781784 }
782785 return self ._state_dict
783786
@@ -858,7 +861,10 @@ def features(self):
858861 return self .ex_iterables [0 ].features
859862
860863 def _init_state_dict (self ) -> dict :
861- self ._state_dict = {"ex_iterables" : [ex_iterable ._init_state_dict () for ex_iterable in self .ex_iterables ]}
864+ self ._state_dict = {
865+ "ex_iterables" : [ex_iterable ._init_state_dict () for ex_iterable in self .ex_iterables ],
866+ "type" : self .__class__ .__name__ ,
867+ }
862868 return self ._state_dict
863869
864870 def __iter__ (self ):
@@ -960,6 +966,7 @@ def _init_state_dict(self) -> dict:
960966 "ex_iterables" : [ex_iterable ._init_state_dict () for ex_iterable in self .ex_iterables ],
961967 "previous_states" : [None ] * len (self .ex_iterables ),
962968 "is_exhausted" : [False ] * len (self .ex_iterables ),
969+ "type" : self .__class__ .__name__ ,
963970 }
964971 return self ._state_dict
965972
@@ -1060,10 +1067,11 @@ def features(self):
10601067
10611068 def _init_state_dict (self ) -> dict :
10621069 self ._state_dict = {
1063- "ex_iterable " : self .ex_iterable ._init_state_dict (),
1070+ "examples_iterable " : self .ex_iterable ._init_state_dict (),
10641071 "previous_state" : None ,
10651072 "num_examples_since_previous_state" : 0 ,
10661073 "previous_state_example_idx" : 0 ,
1074+ "type" : self .__class__ .__name__ ,
10671075 }
10681076 return self ._state_dict
10691077
@@ -1578,7 +1586,11 @@ def features(self):
15781586 return self .ex_iterable .features
15791587
15801588 def _init_state_dict (self ) -> dict :
1581- self ._state_dict = {"skipped" : False , "ex_iterable" : self .ex_iterable ._init_state_dict ()}
1589+ self ._state_dict = {
1590+ "skipped" : False ,
1591+ "examples_iterable" : self .ex_iterable ._init_state_dict (),
1592+ "type" : self .__class__ .__name__ ,
1593+ }
15821594 return self ._state_dict
15831595
15841596 def __iter__ (self ):
@@ -1642,7 +1654,8 @@ def __init__(
16421654 def _init_state_dict (self ) -> dict :
16431655 self ._state_dict = {
16441656 "repeat_index" : 0 ,
1645- "ex_iterable" : self .ex_iterable ._init_state_dict (),
1657+ "examples_iterable" : self .ex_iterable ._init_state_dict (),
1658+ "type" : self .__class__ .__name__ ,
16461659 }
16471660 return self ._state_dict
16481661
@@ -1655,7 +1668,7 @@ def __iter__(self):
16551668 repeat_index += 1
16561669 if self ._state_dict :
16571670 self ._state_dict ["repeat_index" ] = repeat_index
1658- self ._state_dict ["ex_iterable " ] = self .ex_iterable ._init_state_dict ()
1671+ self ._state_dict ["examples_iterable " ] = self .ex_iterable ._init_state_dict ()
16591672
16601673 def shuffle_data_sources (self , generator : np .random .Generator ) -> "RepeatExamplesIterable" :
16611674 """Shuffle the underlying iterable, then repeat."""
@@ -1697,7 +1710,11 @@ def features(self):
16971710 return self .ex_iterable .features
16981711
16991712 def _init_state_dict (self ) -> dict :
1700- self ._state_dict = {"num_taken" : 0 , "ex_iterable" : self .ex_iterable ._init_state_dict ()}
1713+ self ._state_dict = {
1714+ "num_taken" : 0 ,
1715+ "examples_iterable" : self .ex_iterable ._init_state_dict (),
1716+ "type" : self .__class__ .__name__ ,
1717+ }
17011718 return self ._state_dict
17021719
17031720 def __iter__ (self ):
@@ -1956,9 +1973,8 @@ def __init__(
19561973 self ._token_per_repo_id : dict [str , Union [str , bool , None ]] = token_per_repo_id or {}
19571974 self ._epoch : Union [int , "torch.Tensor" ] = _maybe_share_with_torch_persistent_workers (0 )
19581975 self ._starting_state_dict : Optional [dict ] = None
1959- self ._prepared_ex_iterable = self ._prepare_ex_iterable_for_iteration ()
1960- self ._state_dict = self ._prepared_ex_iterable ._init_state_dict ()
1961- _maybe_add_torch_iterable_dataset_parent_class (self .__class__ )
1976+ self ._prepare_ex_iterable_for_iteration () # set state_dict
1977+ _maybe_add_torch_iterable_dataset_parent_class (self .__class__ ) # subclass of torch IterableDataset
19621978
19631979 def state_dict (self ) -> dict :
19641980 """Get the current state_dict of the dataset.
@@ -2061,7 +2077,6 @@ def load_state_dict(self, state_dict: dict) -> None:
20612077 >>> dataloader.load_state_dict(state_dict) # uses ds.load_state_dict() under the hood
20622078 ```
20632079 """
2064- self ._prepared_ex_iterable .load_state_dict (state_dict )
20652080 self ._starting_state_dict = state_dict
20662081
20672082 def __repr__ (self ):
@@ -2136,9 +2151,12 @@ def _iter_pytorch(self):
21362151 ex_iterable = ex_iterable .shard_data_sources (
21372152 num_shards = worker_info .num_workers , index = worker_info .id , contiguous = False
21382153 )
2139- self ._state_dict = ex_iterable ._init_state_dict ()
2140- if self ._starting_state_dict :
2141- ex_iterable .load_state_dict (self ._starting_state_dict )
2154+ self ._state_dict = {
2155+ "examples_iterable" : ex_iterable ._init_state_dict (),
2156+ "epoch" : self .epoch ,
2157+ }
2158+ if self ._starting_state_dict and self .epoch == self ._starting_state_dict ["epoch" ]:
2159+ ex_iterable .load_state_dict (self ._starting_state_dict ["examples_iterable" ])
21422160
21432161 if self ._formatting and (ex_iterable .iter_arrow or self ._formatting .is_table ):
21442162 formatter = get_formatter (self ._formatting .format_type , features = self .features )
@@ -2216,9 +2234,12 @@ def _prepare_ex_iterable_for_iteration(
22162234 token_per_repo_id = self ._token_per_repo_id ,
22172235 )
22182236
2219- self ._state_dict = ex_iterable ._init_state_dict ()
2220- if self ._starting_state_dict :
2221- ex_iterable .load_state_dict (self ._starting_state_dict )
2237+ self ._state_dict = {
2238+ "examples_iterable" : ex_iterable ._init_state_dict (),
2239+ "epoch" : self .epoch ,
2240+ }
2241+ if self ._starting_state_dict and self .epoch == self ._starting_state_dict ["epoch" ]:
2242+ ex_iterable .load_state_dict (self ._starting_state_dict ["examples_iterable" ])
22222243 return ex_iterable
22232244
22242245 def __iter__ (self ):
0 commit comments