@@ -871,9 +871,25 @@ def test_parallel_dataset_partial_iteration_resume(tmp_path_factory, length, res
871871 assert all (torch .equal (x , y ) for x , y in zip (batch , batches_1 [i ]))
872872 if i == break_at :
873873 break
874+ expected_3 = [
875+ [torch .tensor ([4 ]), torch .tensor ([4 ])],
876+ [torch .tensor ([9 ]), torch .tensor ([9 ])],
877+ [torch .tensor ([0 ]), torch .tensor ([0 ])],
878+ [torch .tensor ([5 ]), torch .tensor ([5 ])],
879+ ]
880+ for i , batch in enumerate (dloader ):
881+ if not shuffle :
882+ assert all (
883+ torch .equal (x , y )
884+ for x , y in zip (batch , (expected_3 if resume and length is not None else expected_1 )[i ])
885+ )
886+ elif not resume and length is not None :
887+ assert all (torch .equal (x , y ) for x , y in zip (batch , batches_1 [i ]))
888+ if i == break_at :
889+ break
874890
875891
876- @pytest .mark .parametrize ("length" , [None , 6 ])
892+ @pytest .mark .parametrize ("length" , [None , 5 ])
877893@pytest .mark .parametrize ("resume" , [False , True ])
878894@pytest .mark .parametrize ("shuffle" , [False , True ])
879895@pytest .mark .skipif (sys .platform in ("win32" , "darwin" ), reason = "too slow in CI" )
@@ -888,26 +904,39 @@ def test_parallel_dataset_complete_iteration_resume(tmp_path_factory, length, re
888904 [torch .tensor ([1 ]), torch .tensor ([1 ])],
889905 [torch .tensor ([3 ]), torch .tensor ([3 ])],
890906 [torch .tensor ([0 ]), torch .tensor ([0 ])],
891- [torch .tensor ([2 ]), torch .tensor ([2 ])],
892907 ]
893908 batches_1 = []
894909 for i , batch in enumerate (dloader ):
895910 if not shuffle :
896911 assert all (torch .equal (x , y ) for x , y in zip (batch , expected_1 [i ]))
897912 batches_1 .append (batch )
898913 expected_2 = [
914+ [torch .tensor ([1 ]), torch .tensor ([1 ])],
915+ [torch .tensor ([2 ]), torch .tensor ([2 ])],
916+ [torch .tensor ([0 ]), torch .tensor ([0 ])],
917+ [torch .tensor ([3 ]), torch .tensor ([3 ])],
918+ [torch .tensor ([1 ]), torch .tensor ([1 ])],
919+ ]
920+ for i , batch in enumerate (dloader ):
921+ if not shuffle :
922+ assert all (
923+ torch .equal (x , y )
924+ for x , y in zip (batch , (expected_2 if resume and length is not None else expected_1 )[i ])
925+ )
926+ elif not resume and length is not None :
927+ assert all (torch .equal (x , y ) for x , y in zip (batch , batches_1 [i ]))
928+ expected_3 = [
899929 [torch .tensor ([1 ]), torch .tensor ([1 ])],
900930 [torch .tensor ([3 ]), torch .tensor ([3 ])],
901931 [torch .tensor ([0 ]), torch .tensor ([0 ])],
902932 [torch .tensor ([2 ]), torch .tensor ([2 ])],
903933 [torch .tensor ([1 ]), torch .tensor ([1 ])],
904- [torch .tensor ([3 ]), torch .tensor ([3 ])],
905934 ]
906935 for i , batch in enumerate (dloader ):
907936 if not shuffle :
908937 assert all (
909938 torch .equal (x , y )
910- for x , y in zip (batch , (expected_2 if resume and length is not None else expected_1 )[i ])
939+ for x , y in zip (batch , (expected_3 if resume and length is not None else expected_1 )[i ])
911940 )
912941 elif not resume and length is not None :
913942 assert all (torch .equal (x , y ) for x , y in zip (batch , batches_1 [i ]))