Skip to content

Commit e621b86

Browse files
add test for MultiMacenkoNormalizer torch
1 parent 6c88302 commit e621b86

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

tests/test_torch.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,40 @@ def test_macenko_torch():
5757
# assess whether the normalized images are identical across backends
5858
np.testing.assert_almost_equal(result_numpy.flatten(), result_torch.flatten(), decimal=2, verbose=True)
5959

60+
def test_multitarget_macenko_torch():
61+
size = 1024
62+
curr_file_path = os.path.dirname(os.path.realpath(__file__))
63+
target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size))
64+
to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size))
65+
66+
# setup preprocessing and preprocess image to be normalized
67+
T = transforms.Compose([
68+
transforms.ToTensor(),
69+
transforms.Lambda(lambda x: x * 255)
70+
])
71+
target = T(target)
72+
t_to_transform = T(to_transform)
73+
74+
# initialize normalizers for each backend and fit to target image
75+
single_normalizer = torchstain.normalizers.MacenkoNormalizer(backend="torch")
76+
single_normalizer.fit(target)
77+
78+
multi_normalizer = torchstain.normalizers.MultiMacenkoNormalizer(backend="torch", norm_mode="avg-post")
79+
multi_normalizer.fit([target, target, target])
80+
81+
82+
# transform
83+
result_single, _, _ = single_normalizer.normalize(I=t_to_transform, stains=True)
84+
result_multi, _, _ = multi_normalizer.normalize(I=t_to_transform, stains=True)
85+
86+
# convert to numpy and set dtype
87+
result_single = result_single.numpy().astype("float32") / 255.
88+
result_multi = result_multi.numpy().astype("float32") / 255.
89+
90+
# assess whether the normalized images are identical across backends
91+
np.testing.assert_almost_equal(result_single.flatten(), result_multi.flatten(), decimal=2, verbose=True)
92+
93+
6094
def test_reinhard_torch():
6195
size = 1024
6296
curr_file_path = os.path.dirname(os.path.realpath(__file__))

0 commit comments

Comments
 (0)