@@ -47,6 +47,9 @@ def __init__(self, hparams: dict):
4747
4848 self .ClassificationMetrics = ClassificationMetrics (self .n_classes , self .ignore_index )
4949
50+ self .training_step_outputs = []
51+ self .validation_step_outputs = []
52+
5053 def getLoss (self , out : ME .TensorField , past_labels : list ):
5154 loss = self .MOSLoss .compute_loss (out , past_labels )
5255 return loss
@@ -70,20 +73,20 @@ def training_step(self, batch: tuple, batch_idx, dataloader_index=0):
7073 self .get_step_confusion_matrix (out , past_labels , s ).detach ().cpu ()
7174 )
7275
76+ self .training_step_outputs .append (dict_confusion_matrix )
7377 torch .cuda .empty_cache ()
74- return {"loss" : loss , "dict_confusion_matrix" : dict_confusion_matrix }
7578
76- def training_epoch_end (self , training_step_outputs ):
77- list_dict_confusion_matrix = [
78- output ["dict_confusion_matrix" ] for output in training_step_outputs
79- ]
79+ return loss
80+
81+ def on_train_epoch_end (self ):
8082 for s in range (self .n_past_steps ):
8183 agg_confusion_matrix = torch .zeros (self .n_classes , self .n_classes )
82- for dict_confusion_matrix in list_dict_confusion_matrix :
84+ for dict_confusion_matrix in self . training_step_outputs :
8385 agg_confusion_matrix = agg_confusion_matrix .add (dict_confusion_matrix [s ])
8486 iou = self .ClassificationMetrics .getIoU (agg_confusion_matrix )
8587 self .log ("train_moving_iou_step{}" .format (s ), iou [2 ].item ())
8688
89+ self .training_step_outputs .clear ()
8790 torch .cuda .empty_cache ()
8891
8992 def validation_step (self , batch : tuple , batch_idx ):
@@ -101,17 +104,18 @@ def validation_step(self, batch: tuple, batch_idx):
101104 self .get_step_confusion_matrix (out , past_labels , s ).detach ().cpu ()
102105 )
103106
107+ self .validation_step_outputs .append (dict_confusion_matrix )
104108 torch .cuda .empty_cache ()
105- return dict_confusion_matrix
106109
107- def validation_epoch_end (self , validation_step_outputs ):
110+ def on_validation_epoch_end (self ):
108111 for s in range (self .n_past_steps ):
109112 agg_confusion_matrix = torch .zeros (self .n_classes , self .n_classes )
110- for dict_confusion_matrix in validation_step_outputs :
113+ for dict_confusion_matrix in self . validation_step_outputs :
111114 agg_confusion_matrix = agg_confusion_matrix .add (dict_confusion_matrix [s ])
112115 iou = self .ClassificationMetrics .getIoU (agg_confusion_matrix )
113116 self .log ("val_moving_iou_step{}" .format (s ), iou [2 ].item ())
114117
118+ self .validation_step_outputs .clear ()
115119 torch .cuda .empty_cache ()
116120
117121 def predict_step (self , batch : tuple , batch_idx : int , dataloader_idx : int = None ):
@@ -163,8 +167,8 @@ def get_step_confusion_matrix(self, out, past_labels, step):
163167 t = round (- step * self .dt_prediction , 3 )
164168 mask = out .coordinates [:, - 1 ].isclose (torch .tensor (t ))
165169 pred_logits = out .features [mask ].detach ().cpu ()
166- gt_labels = torch .cat (past_labels , dim = 0 ). detach (). cpu ()
167- gt_labels = gt_labels [mask ][:, 0 ]
170+ gt_labels = torch .cat (past_labels , dim = 0 )
171+ gt_labels = gt_labels [mask ][:, 0 ]. detach (). cpu ()
168172 confusion_matrix = self .ClassificationMetrics .compute_confusion_matrix (
169173 pred_logits , gt_labels
170174 )
0 commit comments