@@ -119,7 +119,12 @@ def _build_model(self) -> nn.Module:
119119 """
120120
121121 def _customize_inputs (self , entity : SegBatchDataEntity ) -> dict [str , Any ]:
122- mode = "loss" if self .training else "predict"
122+ if self .training :
123+ mode = "loss"
124+ elif self .explain_mode :
125+ mode = "explain"
126+ else :
127+ mode = "predict"
123128
124129 if self .train_type == OTXTrainType .SEMI_SUPERVISED and mode == "loss" :
125130 if not isinstance (entity , dict ):
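For reference, the new branching reads as a small pure function. A minimal sketch, assuming `explain_mode` is a boolean attribute toggled on the model before inference (the helper name `_select_mode` is illustrative and not part of the actual class):

def _select_mode(training: bool, explain_mode: bool) -> str:
    # Training always wins and asks the wrapped model for its loss dict.
    if training:
        return "loss"
    # Explain mode requests predictions plus a feature vector for XAI.
    if explain_mode:
        return "explain"
    # Default inference path: plain predictions.
    return "predict"

# Illustrative checks of the dispatch order:
assert _select_mode(training=True, explain_mode=True) == "loss"
assert _select_mode(training=False, explain_mode=True) == "explain"
assert _select_mode(training=False, explain_mode=False) == "predict"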
@@ -155,6 +160,16 @@ def _customize_outputs(
                 losses[k] = v
             return losses
 
+        if self.explain_mode:
+            return SegBatchPredEntity(
+                batch_size=len(outputs["preds"]),
+                images=inputs.images,
+                imgs_info=inputs.imgs_info,
+                scores=[],
+                masks=outputs["preds"],
+                feature_vector=outputs["feature_vector"],
+            )
+
         return SegBatchPredEntity(
             batch_size=len(outputs),
             images=inputs.images,
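For orientation, a rough sketch of the output the new explain branch expects from the wrapped model: a dict carrying per-image mask predictions under "preds" and a pooled "feature_vector". The tensor shapes and dtypes below are assumptions used only to mock the structure:

import torch

batch, num_classes, h, w, feat_dim = 2, 19, 64, 64, 256  # illustrative sizes
explain_outputs: dict[str, torch.Tensor] = {
    "preds": torch.randint(0, num_classes, (batch, h, w)),  # one mask per image (mocked)
    "feature_vector": torch.randn(batch, feat_dim),         # pooled backbone features for XAI (mocked)
}

# The branch forwards both entries unchanged; batch_size is taken from len(preds).
assert len(explain_outputs["preds"]) == batch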
@@ -199,14 +214,24 @@ def _exporter(self) -> OTXModelExporter:
             swap_rgb=False,
             via_onnx=False,
             onnx_export_configuration=None,
-            output_names=None,
+            output_names=["preds", "feature_vector"] if self.explain_mode else None,
         )
 
     def _convert_pred_entity_to_compute_metric(
         self,
         preds: SegBatchPredEntity,
         inputs: SegBatchDataEntity,
     ) -> MetricInput:
+        """Convert prediction and input entities to a format suitable for metric computation.
+
+        Args:
+            preds (SegBatchPredEntity): The predicted segmentation batch entity containing predicted masks.
+            inputs (SegBatchDataEntity): The input segmentation batch entity containing ground truth masks.
+
+        Returns:
+            MetricInput: A list of dictionaries where each dictionary contains 'preds' and 'target' keys
+                corresponding to the predicted and target masks for metric evaluation.
+        """
         return [
             {
                 "preds": pred_mask,
@@ -228,8 +253,26 @@ def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo:
 
     def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
         """Model forward function used for the model tracing during model exportation."""
-        raw_outputs = self.model(inputs=image, mode="tensor")
-        return torch.softmax(raw_outputs, dim=1)
+        if self.explain_mode:
+            outputs = self.model(inputs=image, mode="explain")
+            outputs["preds"] = torch.softmax(outputs["preds"], dim=1)
+            return outputs
+
+        outputs = self.model(inputs=image, mode="tensor")
+        return torch.softmax(outputs, dim=1)
+
+    def forward_explain(self, inputs: SegBatchDataEntity) -> SegBatchPredEntity:
+        """Model forward explain function."""
+        outputs = self.model(inputs=inputs.images, mode="explain")
+
+        return SegBatchPredEntity(
+            batch_size=len(outputs["preds"]),
+            images=inputs.images,
+            imgs_info=inputs.imgs_info,
+            scores=[],
+            masks=outputs["preds"],
+            feature_vector=outputs["feature_vector"],
+        )
 
     def get_dummy_input(self, batch_size: int = 1) -> SegBatchDataEntity:
         """Returns a dummy input for semantic segmentation model."""
@@ -308,32 +351,34 @@ def _customize_outputs(
         outputs: list[ImageResultWithSoftPrediction],
         inputs: SegBatchDataEntity,
     ) -> SegBatchPredEntity | OTXBatchLossEntity:
-        if outputs and outputs[0].saliency_map.size != 1:
-            predicted_s_maps = [out.saliency_map for out in outputs]
-            predicted_f_vectors = [out.feature_vector for out in outputs]
-            return SegBatchPredEntity(
-                batch_size=len(outputs),
-                images=inputs.images,
-                imgs_info=inputs.imgs_info,
-                scores=[],
-                masks=[tv_tensors.Mask(mask.resultImage, device=self.device) for mask in outputs],
-                saliency_map=predicted_s_maps,
-                feature_vector=predicted_f_vectors,
-            )
-
+        masks = [tv_tensors.Mask(mask.resultImage, device=self.device) for mask in outputs]
+        predicted_f_vectors = (
+            [out.feature_vector for out in outputs] if outputs and outputs[0].feature_vector.size != 1 else []
+        )
         return SegBatchPredEntity(
             batch_size=len(outputs),
             images=inputs.images,
             imgs_info=inputs.imgs_info,
             scores=[],
-            masks=[tv_tensors.Mask(mask.resultImage, device=self.device) for mask in outputs],
+            masks=masks,
+            feature_vector=predicted_f_vectors,
         )
 
     def _convert_pred_entity_to_compute_metric(
         self,
         preds: SegBatchPredEntity,
         inputs: SegBatchDataEntity,
     ) -> MetricInput:
+        """Convert prediction and input entities to a format suitable for metric computation.
+
+        Args:
+            preds (SegBatchPredEntity): The predicted segmentation batch entity containing predicted masks.
+            inputs (SegBatchDataEntity): The input segmentation batch entity containing ground truth masks.
+
+        Returns:
+            MetricInput: A list of dictionaries where each dictionary contains 'preds' and 'target' keys
+                corresponding to the predicted and target masks for metric evaluation.
+        """
         return [
             {
                 "preds": pred_mask,