1212import pandas as pd
1313
1414from ..dataset import DatasetTemplate
15- from .argo2_utils .so3 import yaw_to_quat
15+ from .argo2_utils .so3 import yaw_to_quat , quat_to_yaw
1616from .argo2_utils .constants import LABEL_ATTR
1717
1818
1919def process_single_segment (segment_path , split , info_list , ts2idx , output_dir , save_bin ):
2020 test_mode = 'test' in split
2121 if not test_mode :
22- segment_anno = read_feather (osp .join (segment_path , 'annotations.feather' ))
22+ segment_anno = read_feather (Path ( osp .join (segment_path , 'annotations.feather' ) ))
2323 segname = segment_path .split ('/' )[- 1 ]
2424
2525 frame_path_list = os .listdir (osp .join (segment_path , 'sensors/lidar/' ))
@@ -70,17 +70,7 @@ def process_and_save_frame(frame_path, frame_anno, ts2idx, segname, output_dir,
7070 cuboid_params = torch .from_numpy (cuboid_params )
7171 yaw = quat_to_yaw (cuboid_params [:, - 4 :])
7272 xyz = cuboid_params [:, :3 ]
73- wlh = cuboid_params [:, [4 , 3 , 5 ]]
74-
75- yaw = - yaw - 0.5 * np .pi
76-
77- while (yaw < - np .pi ).any ():
78- yaw [yaw < - np .pi ] += 2 * np .pi
79-
80- while (yaw > np .pi ).any ():
81- yaw [yaw > np .pi ] -= 2 * np .pi
82-
83- # bbox = torch.cat([xyz, wlh, yaw.unsqueeze(1)], dim=1).numpy()
73+ lwh = cuboid_params [:, [3 , 4 , 5 ]]
8474
8575 cat = frame_anno ['category' ].to_numpy ().tolist ()
8676 cat = [c .lower ().capitalize () for c in cat ]
@@ -93,7 +83,7 @@ def process_and_save_frame(frame_path, frame_anno, ts2idx, segname, output_dir,
9383 annos ['truncated' ] = np .zeros (num_obj , dtype = np .float64 )
9484 annos ['occluded' ] = np .zeros (num_obj , dtype = np .int64 )
9585 annos ['alpha' ] = - 10 * np .ones (num_obj , dtype = np .float64 )
96- annos ['dimensions' ] = wlh .numpy ().astype (np .float64 )
86+ annos ['dimensions' ] = lwh .numpy ().astype (np .float64 )
9787 annos ['location' ] = xyz .numpy ().astype (np .float64 )
9888 annos ['rotation_y' ] = yaw .numpy ().astype (np .float64 )
9989 annos ['index' ] = np .arange (num_obj , dtype = np .int32 )
@@ -111,7 +101,7 @@ def process_and_save_frame(frame_path, frame_anno, ts2idx, segname, output_dir,
111101
112102
def save_point_cloud(frame_path, save_path):
    """Convert one lidar sweep from feather format to a flat float32 binary file.

    Args:
        frame_path: path (str) to a `lidar.feather` sweep file.
        save_path: destination path for the raw `.bin` point dump.
    """
    # Keep only the geometric coordinates plus intensity; drop any other columns
    # (timestamps, laser ids, ...) the feather file may carry.
    columns = ['x', 'y', 'z', 'intensity']
    points = read_feather(Path(frame_path)).loc[:, columns]
    # Row-major float32 dump — consumers read it back with np.fromfile(...).reshape(-1, 4).
    points.to_numpy().astype(np.float32).tofile(save_path)
@@ -375,9 +365,9 @@ def format_results(self,
375365 assert len (self .argo2_infos ) == len (outputs )
376366 num_samples = len (outputs )
377367 print ('\n Got {} samples' .format (num_samples ))
378-
368+
379369 serialized_dts_list = []
380-
370+
381371 print ('\n Convert predictions to Argoverse 2 format' )
382372 for i in range (num_samples ):
383373 out_i = outputs [i ]
@@ -394,7 +384,7 @@ def format_results(self,
394384 serialized_dts ["timestamp_ns" ] = int (ts )
395385 serialized_dts ["category" ] = category
396386 serialized_dts_list .append (serialized_dts )
397-
387+
398388 dts = (
399389 pd .concat (serialized_dts_list )
400390 .set_index (["log_id" , "timestamp_ns" ])
@@ -411,19 +401,13 @@ def format_results(self,
411401
412402 dts = dts .set_index (["log_id" , "timestamp_ns" ]).sort_index ()
413403
414- return dts
415-
404+ return dts
405+
416406 def lidar_box_to_argo2 (self , boxes ):
417407 boxes = torch .Tensor (boxes )
418408 cnt_xyz = boxes [:, :3 ]
419- lwh = boxes [:, [4 , 3 , 5 ]]
420- yaw = boxes [:, 6 ] #- np.pi/2
421-
422- yaw = - yaw - 0.5 * np .pi
423- while (yaw < - np .pi ).any ():
424- yaw [yaw < - np .pi ] += 2 * np .pi
425- while (yaw > np .pi ).any ():
426- yaw [yaw > np .pi ] -= 2 * np .pi
409+ lwh = boxes [:, [3 , 4 , 5 ]]
410+ yaw = boxes [:, 6 ]
427411
428412 quat = yaw_to_quat (yaw )
429413 argo_cuboid = torch .cat ([cnt_xyz , lwh , quat ], dim = 1 )
@@ -470,7 +454,7 @@ def evaluation(self,
470454 dts = self .format_results (results , class_names , pklfile_prefix , submission_prefix )
471455 argo2_root = self .root_path
472456 val_anno_path = osp .join (argo2_root , 'val_anno.feather' )
473- gts = read_feather (val_anno_path )
457+ gts = read_feather (Path ( val_anno_path ) )
474458 gts = gts .set_index (["log_id" , "timestamp_ns" ]).sort_values ("category" )
475459
476460 valid_uuids_gts = gts .index .tolist ()
@@ -508,6 +492,13 @@ def parse_config():
508492 args = parser .parse_args ()
509493 return args
510494
def main(seg_path_list, seg_split_list, info_list, ts2idx, output_dir, save_bin, token, num_process):
    """Worker entry point: convert every `num_process`-th segment, offset by `token`.

    Each of the `num_process` parallel workers is given a distinct `token` in
    [0, num_process) and processes only the segments whose index is congruent
    to it, so the segment list is partitioned round-robin across workers.
    """
    for idx, seg_path in enumerate(seg_path_list):
        # Skip segments assigned to other workers.
        if idx % num_process == token:
            print(f'processing segment: {idx}/{len(seg_path_list)}')
            process_single_segment(seg_path, seg_split_list[idx], info_list,
                                   ts2idx, output_dir, save_bin)
511502
512503if __name__ == '__main__' :
513504 args = parse_config ()
@@ -559,4 +550,5 @@ def parse_config():
559550 seg_anno_list .append (seg_anno )
560551
561552 gts = pd .concat (seg_anno_list ).reset_index ()
562- gts .to_feather (val_seg_path_list )
553+ gts .to_feather (save_feather_path )
554+
0 commit comments