Skip to content

Commit 46dc2ec

Browse files
authored
update integration tests metatensor (#4501)
* update integration Signed-off-by: Wenqi Li <[email protected]> * update integration tests Signed-off-by: Wenqi Li <[email protected]>
1 parent 36dc126 commit 46dc2ec

File tree

3 files changed

+12
-21
lines changed

3 files changed

+12
-21
lines changed

tests/test_integration_classification_2d.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
RandRotate,
3434
RandZoom,
3535
ScaleIntensity,
36-
ToTensor,
3736
Transpose,
3837
)
3938
from monai.utils import set_determinism
@@ -69,15 +68,12 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0",
6968
RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True, dtype=np.float64),
7069
RandFlip(spatial_axis=0, prob=0.5),
7170
RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
72-
ToTensor(),
7371
]
7472
)
7573
train_transforms.set_random_state(1234)
76-
val_transforms = Compose(
77-
[LoadImage(image_only=True), AddChannel(), Transpose(indices=[0, 2, 1]), ScaleIntensity(), ToTensor()]
78-
)
79-
y_pred_trans = Compose([ToTensor(), Activations(softmax=True)])
80-
y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=len(np.unique(train_y)))])
74+
val_transforms = Compose([LoadImage(image_only=True), AddChannel(), Transpose(indices=[0, 2, 1]), ScaleIntensity()])
75+
y_pred_trans = Compose([Activations(softmax=True)])
76+
y_trans = AsDiscrete(to_onehot=len(np.unique(train_y)))
8177
auc_metric = ROCAUCMetric()
8278

8379
# create train, val data loaders
@@ -132,7 +128,7 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0",
132128
acc_metric = acc_value.sum().item() / len(acc_value)
133129
# decollate prediction and label and execute post processing
134130
y_pred = [y_pred_trans(i) for i in decollate_batch(y_pred)]
135-
y = [y_trans(i) for i in decollate_batch(y)]
131+
y = [y_trans(i) for i in decollate_batch(y, detach=False)]
136132
# compute AUC
137133
auc_metric(y_pred, y)
138134
auc_value = auc_metric.aggregate()
@@ -153,7 +149,7 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0",
153149

154150
def run_inference_test(root_dir, test_x, test_y, device="cuda:0", num_workers=10):
155151
# define transforms for image and classification
156-
val_transforms = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), ToTensor()])
152+
val_transforms = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity()])
157153
val_ds = MedNISTDataset(test_x, test_y, val_transforms)
158154
val_loader = DataLoader(val_ds, batch_size=300, num_workers=num_workers)
159155

tests/test_integration_segmentation_3d.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from torch.utils.tensorboard import SummaryWriter
2222

2323
import monai
24-
from monai.data import MetaTensor, create_test_image_3d, decollate_batch
24+
from monai.data import create_test_image_3d, decollate_batch
2525
from monai.inferers import sliding_window_inference
2626
from monai.metrics import DiceMetric
2727
from monai.networks import eval_mode
@@ -31,7 +31,6 @@
3131
AsDiscrete,
3232
Compose,
3333
EnsureChannelFirstd,
34-
FromMetaTensord,
3534
LoadImaged,
3635
RandCropByPosNegLabeld,
3736
RandRotate90d,
@@ -40,7 +39,6 @@
4039
Spacingd,
4140
)
4241
from monai.utils import set_determinism
43-
from monai.utils.enums import PostFix
4442
from monai.visualize import plot_2d_or_3d_image
4543
from tests.testing_data.integration_answers import test_integration_value
4644
from tests.utils import DistTestCase, TimedCall, skip_if_quick
@@ -187,7 +185,6 @@ def run_inference_test(root_dir, device="cuda:0"):
187185
# resampling with align_corners=True or dtype=float64 will generate
188186
# slight different results between PyTorch 1.5 an 1.6
189187
Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
190-
FromMetaTensord(["img", "seg"]),
191188
ScaleIntensityd(keys="img"),
192189
]
193190
)
@@ -225,11 +222,10 @@ def run_inference_test(root_dir, device="cuda:0"):
225222
val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
226223
# decollate prediction into a list
227224
val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)]
228-
val_meta = decollate_batch(val_data[PostFix.meta("img")])
229225
# compute metrics
230226
dice_metric(y_pred=val_outputs, y=val_labels)
231-
for img, meta in zip(val_outputs, val_meta): # save a decollated batch of files
232-
saver(MetaTensor(img, meta=meta))
227+
for img in val_outputs: # save a decollated batch of files
228+
saver(img)
233229

234230
return dice_metric.aggregate().item()
235231

tests/test_integration_sliding_window.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ignite.engine import Engine, Events
2020
from torch.utils.data import DataLoader
2121

22-
from monai.data import ImageDataset, MetaTensor, create_test_image_3d, decollate_batch
22+
from monai.data import ImageDataset, create_test_image_3d
2323
from monai.inferers import sliding_window_inference
2424
from monai.networks import eval_mode, predict_segmentation
2525
from monai.networks.nets import UNet
@@ -29,7 +29,7 @@
2929

3030

3131
def run_test(batch_size, img_name, seg_name, output_dir, device="cuda:0"):
32-
ds = ImageDataset([img_name], [seg_name], transform=AddChannel(), seg_transform=AddChannel(), image_only=False)
32+
ds = ImageDataset([img_name], [seg_name], transform=AddChannel(), seg_transform=AddChannel(), image_only=True)
3333
loader = DataLoader(ds, batch_size=1, pin_memory=torch.cuda.is_available())
3434

3535
net = UNet(
@@ -47,9 +47,8 @@ def _sliding_window_processor(_engine, batch):
4747
return predict_segmentation(seg_probs)
4848

4949
def save_func(engine):
50-
meta_data = decollate_batch(engine.state.batch[2])
51-
for m, o in zip(meta_data, engine.state.output):
52-
saver(MetaTensor(o, meta=m))
50+
for m in engine.state.output:
51+
saver(m)
5352

5453
infer_engine = Engine(_sliding_window_processor)
5554
infer_engine.add_event_handler(Events.ITERATION_COMPLETED, save_func)

0 commit comments

Comments
 (0)