|
1 | 1 | import traceback |
2 | 2 | import warnings |
| 3 | +from itertools import product |
3 | 4 | from typing import Dict, Hashable, List, Literal, Optional, Set, Tuple, Union |
4 | 5 |
|
5 | 6 | import numpy as np |
@@ -179,31 +180,38 @@ def _test_model_inference_parametrized( |
179 | 180 | model: v0_5.ModelDescr, |
180 | 181 | weight_format: Optional[WeightsFormat], |
181 | 182 | devices: Optional[List[str]], |
182 | | - test_cases: Set[Tuple[v0_5.ParameterizedSize.N, BatchSize]] = { |
183 | | - (0, 2), |
184 | | - (1, 3), |
185 | | - (2, 1), |
186 | | - (3, 2), |
187 | | - }, |
188 | 183 | ) -> None: |
189 | | - if not test_cases: |
190 | | - return |
191 | | - |
192 | | - logger.info( |
193 | | - "Testing inference with {} different input tensor sizes", len(test_cases) |
194 | | - ) |
195 | | - |
196 | 184 | if not any( |
197 | 185 | isinstance(a.size, v0_5.ParameterizedSize) |
198 | 186 | for ipt in model.inputs |
199 | 187 | for a in ipt.axes |
200 | 188 | ): |
201 | 189 | # no parameterized sizes => set n=0 |
202 | | - test_cases = {(0, b) for _n, b in test_cases} |
| 190 | + ns: Set[v0_5.ParameterizedSize.N] = {0} |
| 191 | + else: |
| 192 | + ns = {0, 1, 2} |
203 | 193 |
|
204 | | - if not any(isinstance(a, v0_5.BatchAxis) for ipt in model.inputs for a in ipt.axes): |
205 | | - # no batch axis => set b=1 |
206 | | - test_cases = {(n, 1) for n, _b in test_cases} |
| 194 | + given_batch_sizes = { |
| 195 | + a.size |
| 196 | + for ipt in model.inputs |
| 197 | + for a in ipt.axes |
| 198 | + if isinstance(a, v0_5.BatchAxis) |
| 199 | + } |
| 200 | + if given_batch_sizes: |
| 201 | + batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None} |
| 202 | + if not batch_sizes: |
| 203 | + # only arbitrary batch sizes |
| 204 | + batch_sizes = {1, 2} |
| 205 | + else: |
| 206 | + # no batch axis |
| 207 | + batch_sizes = {1} |
| 208 | + |
| 209 | + test_cases: Set[Tuple[v0_5.ParameterizedSize.N, BatchSize]] = { |
| 210 | + (n, b) for n, b in product(sorted(ns), sorted(batch_sizes)) |
| 211 | + } |
| 212 | + logger.info( |
| 213 | + "Testing inference with {} different input tensor sizes", len(test_cases) |
| 214 | + ) |
207 | 215 |
|
208 | 216 | def generate_test_cases(): |
209 | 217 | tested: Set[Hashable] = set() |
|
0 commit comments