Skip to content

Commit 81ea7db

Browse files
committed
clean up test case creation
and respect fixed batch sizes
1 parent 9297bbf commit 81ea7db

File tree

1 file changed

+25
-17
lines changed

1 file changed

+25
-17
lines changed

bioimageio/core/_resource_tests.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import traceback
22
import warnings
3+
from itertools import product
34
from typing import Dict, Hashable, List, Literal, Optional, Set, Tuple, Union
45

56
import numpy as np
@@ -179,31 +180,38 @@ def _test_model_inference_parametrized(
179180
model: v0_5.ModelDescr,
180181
weight_format: Optional[WeightsFormat],
181182
devices: Optional[List[str]],
182-
test_cases: Set[Tuple[v0_5.ParameterizedSize.N, BatchSize]] = {
183-
(0, 2),
184-
(1, 3),
185-
(2, 1),
186-
(3, 2),
187-
},
188183
) -> None:
189-
if not test_cases:
190-
return
191-
192-
logger.info(
193-
"Testing inference with {} different input tensor sizes", len(test_cases)
194-
)
195-
196184
if not any(
197185
isinstance(a.size, v0_5.ParameterizedSize)
198186
for ipt in model.inputs
199187
for a in ipt.axes
200188
):
201189
# no parameterized sizes => set n=0
202-
test_cases = {(0, b) for _n, b in test_cases}
190+
ns: Set[v0_5.ParameterizedSize.N] = {0}
191+
else:
192+
ns = {0, 1, 2}
203193

204-
if not any(isinstance(a, v0_5.BatchAxis) for ipt in model.inputs for a in ipt.axes):
205-
# no batch axis => set b=1
206-
test_cases = {(n, 1) for n, _b in test_cases}
194+
given_batch_sizes = {
195+
a.size
196+
for ipt in model.inputs
197+
for a in ipt.axes
198+
if isinstance(a, v0_5.BatchAxis)
199+
}
200+
if given_batch_sizes:
201+
batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
202+
if not batch_sizes:
203+
# only arbitrary batch sizes
204+
batch_sizes = {1, 2}
205+
else:
206+
# no batch axis
207+
batch_sizes = {1}
208+
209+
test_cases: Set[Tuple[v0_5.ParameterizedSize.N, BatchSize]] = {
210+
(n, b) for n, b in product(sorted(ns), sorted(batch_sizes))
211+
}
212+
logger.info(
213+
"Testing inference with {} different input tensor sizes", len(test_cases)
214+
)
207215

208216
def generate_test_cases():
209217
tested: Set[Hashable] = set()

0 commit comments

Comments
 (0)