Skip to content

Commit dd28cde

Browse files
authored
Merge branch 'develop' into add-extra-foundation-models
2 parents a42cd0d + 95e70fa commit dd28cde

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

tests/models/test_arch_mapde.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,12 @@ def test_functionality(remote_sample: Callable) -> None:
4848
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
4949
output = model.postproc(output[0])
5050
assert np.all(output[0:2] == [[19, 171], [53, 89]])
51+
52+
53+
def test_multiclass_output() -> None:
54+
"""Test the architecture for multi-class output."""
55+
multiclass_model = MapDe(num_input_channels=3, num_classes=3)
56+
test_input = torch.rand((1, 3, 252, 252))
57+
58+
output = multiclass_model(test_input)
59+
assert output.shape == (1, 3, 252, 252)

tiatoolbox/models/architecture/mapde.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,11 @@ def __init__(
199199
dtype=np.float32,
200200
)
201201

202-
dist_filter = np.expand_dims(dist_filter, axis=(0, 1)) # NCHW
202+
# For conv2d, filter shape = (out_channels, in_channels//groups, H, W)
203+
dist_filter = np.expand_dims(dist_filter, axis=(0, 1))
203204
dist_filter = np.repeat(dist_filter, repeats=num_classes * 2, axis=1)
205+
# Need to repeat for out_channels
206+
dist_filter = np.repeat(dist_filter, repeats=num_classes, axis=0)
204207

205208
self.min_distance = min_distance
206209
self.threshold_abs = threshold_abs

0 commit comments

Comments
 (0)