@@ -93,6 +93,9 @@ def parse_vision_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndar
9393 self .parse_binary_crossentropy ('lane_lines_prob' , outs )
9494 self .parse_categorical_crossentropy ('desire_pred' , outs , out_shape = (ModelConstants .DESIRE_PRED_LEN ,ModelConstants .DESIRE_PRED_WIDTH ))
9595 self .parse_binary_crossentropy ('meta' , outs )
96+ self .parse_binary_crossentropy ('lead_prob' , outs )
97+ self .parse_mdn ('lead' , outs , in_N = ModelConstants .LEAD_MHP_N , out_N = ModelConstants .LEAD_MHP_SELECTION ,
98+ out_shape = (ModelConstants .LEAD_TRAJ_LEN ,ModelConstants .LEAD_WIDTH ))
9699 return outs
97100
98101 def parse_policy_outputs (self , outs : dict [str , np .ndarray ]) -> dict [str , np .ndarray ]:
@@ -103,9 +106,6 @@ def parse_policy_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndar
103106 if 'desired_curvature' in outs :
104107 self .parse_mdn ('desired_curvature' , outs , in_N = 0 , out_N = 0 , out_shape = (ModelConstants .DESIRED_CURV_WIDTH ,))
105108 self .parse_categorical_crossentropy ('desire_state' , outs , out_shape = (ModelConstants .DESIRE_PRED_WIDTH ,))
106- self .parse_binary_crossentropy ('lead_prob' , outs )
107- self .parse_mdn ('lead' , outs , in_N = ModelConstants .LEAD_MHP_N , out_N = ModelConstants .LEAD_MHP_SELECTION ,
108- out_shape = (ModelConstants .LEAD_TRAJ_LEN ,ModelConstants .LEAD_WIDTH ))
109109 return outs
110110
111111 def parse_outputs (self , outs : dict [str , np .ndarray ]) -> dict [str , np .ndarray ]:
0 commit comments