@@ -1199,3 +1199,96 @@ def test_dataset_resume_recording(tmp_path, empty_lerobot_dataset_factory):
11991199        expected_to  =  (ep_idx  +  1 ) *  frames_per_episode 
12001200        assert  ep_metadata ["dataset_from_index" ] ==  expected_from 
12011201        assert  ep_metadata ["dataset_to_index" ] ==  expected_to 
1202+ 
1203+ 
def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_factory):
    """Verify that consecutive short episodes are appended to the same data file.

    Regression test: ``frames_in_current_file`` used to count only the frames of
    the most recent episode rather than every frame already in the current file.
    Three 10-frame episodes are recorded with the data-file size setting at
    100 MB; every episode is expected to be stored in chunk 0 / file 0, with
    back-to-back ``dataset_from_index`` / ``dataset_to_index`` ranges.
    """
    feature_spec = {
        "observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
        "action": {"dtype": "float32", "shape": (2,), "names": ["vx", "vy"]},
    }
    frames_per_episode = 10
    task_labels = ("task_0", "task_1", "task_2")
    num_episodes = len(task_labels)

    dataset = empty_lerobot_dataset_factory(
        root=tmp_path / "test", features=feature_spec, use_videos=False
    )
    dataset.meta.update_chunk_settings(data_files_size_in_mb=100)

    # Before anything is saved there is no current file, hence no start frame.
    assert dataset._current_file_start_frame is None

    for episodes_saved, task_label in enumerate(task_labels, start=1):
        for _ in range(frames_per_episode):
            sample = {
                "observation.state": torch.randn(2),
                "action": torch.randn(2),
                "task": task_label,
            }
            dataset.add_frame(sample)
        dataset.save_episode()

        # The current file must keep starting at frame 0 while the totals grow.
        assert dataset._current_file_start_frame == 0
        assert dataset.meta.total_episodes == episodes_saved
        assert dataset.meta.total_frames == episodes_saved * frames_per_episode

        # Placement via latest_episode is inspected from the second episode onward.
        if episodes_saved > 1:
            latest_chunk = dataset.latest_episode["data/chunk_index"]
            latest_file = dataset.latest_episode["data/file_index"]
            assert latest_chunk == 0
            assert latest_file == 0

    dataset.finalize()

    from lerobot.datasets.utils import load_episodes

    # Re-read the episode metadata from disk and check what was persisted.
    dataset.meta.episodes = load_episodes(dataset.root)
    assert dataset.meta.episodes is not None

    for ep_idx in range(num_episodes):
        persisted = dataset.meta.episodes[ep_idx]
        assert persisted["data/chunk_index"] == 0
        assert persisted["data/file_index"] == 0

        # Episode ranges are contiguous and half-open: [from, to).
        assert persisted["dataset_from_index"] == ep_idx * frames_per_episode
        assert persisted["dataset_to_index"] == (ep_idx + 1) * frames_per_episode

    # Reload the dataset from its root and confirm every frame maps to its episode.
    total_expected = num_episodes * frames_per_episode
    reloaded = LeRobotDataset(dataset.repo_id, root=dataset.root)
    assert len(reloaded) == total_expected
    assert reloaded.meta.total_episodes == num_episodes
    assert reloaded.meta.total_frames == total_expected

    for frame_idx in range(len(reloaded)):
        item = reloaded[frame_idx]
        assert item["episode_index"].item() == frame_idx // frames_per_episode
0 commit comments