Skip to content

Commit 6757a22

Browse files
finalelementpre-commit-ci[bot]SachidanandAlle
authored
Dimension mismatch fixed to use MONAI Variance Metric function (#995)
* Dimension mismatch fixed to use MONAI Variance Metric function Signed-off-by: Vishwesh Nath <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vishwesh Nath <[email protected]> Co-authored-by: Vishwesh Nath <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: SACHIDANAND ALLE <[email protected]>
1 parent c6d8dcd commit 6757a22

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

monailabel/tasks/scoring/epistemic_v2.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def variance_volume(self, vol_input, ignore_nans=True):
9191
variance = np.sum(vari, axis=0)
9292
else:
9393
variance_metric = VarianceMetric(threshold=0.0005, spatial_map=True, scalar_reduction="sum")
94-
variance = variance_metric(vol_input).cpu().detach().numpy()
94+
variance = variance_metric(vol_input)
9595

9696
if self.dimension == 3:
9797
variance = np.expand_dims(variance, axis=0)
@@ -181,10 +181,13 @@ def run_scoring(self, image_id, simulation_size, model_ts, datastore):
181181

182182
accum = torch.stack(accum_unl_outputs)
183183
accum = torch.squeeze(accum)
184-
if self.dimension == 3:
185-
accum = accum[:, 1:, :, :, :] if len(accum.shape) > 4 else accum
186-
else:
187-
accum = accum[:, 1:, :, :] if len(accum.shape) > 3 else accum
184+
185+
# Accum Expected shape for 2D images is (N, C, H, W) for 3D (N, C, H, W, D)
186+
# To handle cases where only a single class of segmentation is present, an extra dimension is added
187+
if self.dimension == 2 and len(accum.shape) == 3:
188+
accum = torch.unsqueeze(accum, dim=1)
189+
elif self.dimension == 3 and len(accum.shape) == 4:
190+
accum = torch.unsqueeze(accum, dim=1)
188191

189192
entropy = self.variance_volume(accum) if self.use_variance else self.entropy_volume(accum)
190193
entropy = float(np.nanmean(entropy))

0 commit comments

Comments
 (0)