|
1 | 1 | import traceback |
2 | 2 | from pathlib import Path |
3 | | -from typing import List, Optional, Union |
| 3 | +from typing import List, Optional, Tuple, Union |
4 | 4 |
|
| 5 | +import numpy |
5 | 6 | import numpy as np |
6 | 7 | import xarray as xr |
| 8 | +from marshmallow import ValidationError |
7 | 9 |
|
8 | 10 | from bioimageio.core import load_resource_description |
9 | 11 | from bioimageio.core.prediction import predict |
10 | 12 | from bioimageio.core.prediction_pipeline import create_prediction_pipeline |
11 | | -from bioimageio.core.resource_io.nodes import Model, ResourceDescription, URI |
| 13 | +from bioimageio.core.resource_io.nodes import ( |
| 14 | + ImplicitOutputShape, |
| 15 | + Model, |
| 16 | + ParametrizedInputShape, |
| 17 | + ResourceDescription, |
| 18 | + URI, |
| 19 | +) |
12 | 20 | from bioimageio.spec.model.raw_nodes import WeightsFormat |
13 | 21 | from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription |
14 | 22 |
|
@@ -45,32 +53,93 @@ def test_resource( |
45 | 53 | tb: Optional = None |
46 | 54 |
|
47 | 55 | try: |
48 | | - model = load_resource_description(model_rdf) |
| 56 | + rd = load_resource_description(model_rdf) |
49 | 57 | except Exception as e: |
50 | 58 | error = str(e) |
51 | 59 | tb = traceback.format_tb(e.__traceback__) |
52 | 60 | else: |
53 | | - if isinstance(model, Model): |
| 61 | + if isinstance(rd, Model): |
| 62 | + model = rd |
54 | 63 | try: |
55 | | - prediction_pipeline = create_prediction_pipeline( |
56 | | - bioimageio_model=model, devices=devices, weight_format=weight_format |
57 | | - ) |
58 | 64 | inputs = [np.load(str(in_path)) for in_path in model.test_inputs] |
59 | | - results = predict(prediction_pipeline, inputs) |
60 | | - if isinstance(results, (np.ndarray, xr.DataArray)): |
61 | | - results = [results] |
62 | | - |
63 | 65 | expected = [np.load(str(out_path)) for out_path in model.test_outputs] |
| 66 | + |
| 67 | + # check if test data shapes match their description |
| 68 | + input_shapes = [ipt.shape for ipt in inputs] |
| 69 | + output_shapes = [out.shape for out in expected] |
| 70 | + |
| 71 | + def input_shape_is_valid(shape: Tuple[int, ...], shape_spec) -> bool: |
| 72 | + if isinstance(shape_spec, list): |
| 73 | + if shape != tuple(shape_spec): |
| 74 | + return False |
| 75 | + elif isinstance(shape_spec, ParametrizedInputShape): |
| 76 | + assert len(shape_spec.min) == len(shape_spec.step) |
| 77 | + if len(shape) != len(shape_spec.min): |
| 78 | + return False |
| 79 | + |
| 80 | + valid_shape = numpy.array(shape_spec.min) |
| 81 | + step = numpy.array(shape_spec.step) |
| 82 | + if (step == 0).all(): |
| 83 | + return shape == tuple(valid_shape) |
| 84 | + |
| 85 | + shape = numpy.array(shape) |
| 86 | + while (shape <= valid_shape).all(): |
| 87 | + if (shape == valid_shape).all(): |
| 88 | + break |
| 89 | + |
| 90 | + shape += step |
| 91 | + else: |
| 92 | + return False |
| 93 | + |
| 94 | + else: |
| 95 | + raise TypeError(f"Encountered unexpected shape description of type {type(shape_spec)}") |
| 96 | + |
| 97 | + return True |
| 98 | + |
| 99 | + assert len(inputs) == len(model.inputs) # should be checked by validation |
| 100 | + input_shapes = {} |
| 101 | + for idx, (ipt, ipt_spec) in enumerate(zip(inputs, model.inputs)): |
| 102 | + if not input_shape_is_valid(ipt, ipt_spec.shape): |
| 103 | + raise ValidationError( |
| 104 | + f"Shape of test input {idx} '{ipt_spec.name}' does not match " |
| 105 | + f"input shape description: {ipt_spec.shape}" |
| 106 | + ) |
| 107 | + input_shapes[ipt_spec.name] = ipt.shape |
| 108 | + |
| 109 | + def output_shape_is_valid(shape: Tuple[int, ...], shape_spec) -> bool: |
| 110 | + if isinstance(shape_spec, list): |
| 111 | + return shape == tuple(shape_spec) |
| 112 | + elif isinstance(shape_spec, ImplicitOutputShape): |
| 113 | + ipt_shape = numpy.array(input_shapes[shape_spec.reference_tensor]) |
| 114 | + scale = numpy.array(shape_spec.scale) |
| 115 | + offset = numpy.array(shape_spec.offset) |
| 116 | + exp_shape = numpy.round_(ipt_shape * scale) + 2 * offset |
| 117 | + |
| 118 | + return shape == tuple(exp_shape) |
| 119 | + |
| 120 | + assert len(expected) == len(model.outputs) # should be checked by validation |
| 121 | + for idx, (out, out_spec) in enumerate(zip(expected, model.outputs)): |
| 122 | + if not output_shape_is_valid(out, out_spec.shape): |
| 123 | + error = (error or "") + ( |
| 124 | + f"Shape of test output {idx} '{out_spec.name}' does not match " |
| 125 | + f"output shape description: {out_spec.shape}.\n" |
| 126 | + ) |
| 127 | + |
| 128 | + with create_prediction_pipeline( |
| 129 | + bioimageio_model=model, devices=devices, weight_format=weight_format |
| 130 | + ) as prediction_pipeline: |
| 131 | + results = predict(prediction_pipeline, inputs) |
| 132 | + |
64 | 133 | if len(results) != len(expected): |
65 | | - error = ( |
| 134 | + error = (error or "") + ( |
66 | 135 | f"Number of outputs and number of expected outputs disagree: {len(results)} != {len(expected)}" |
67 | 136 | ) |
68 | 137 | else: |
69 | 138 | for res, exp in zip(results, expected): |
70 | 139 | try: |
71 | 140 | np.testing.assert_array_almost_equal(res, exp, decimal=decimal) |
72 | 141 | except AssertionError as e: |
73 | | - error = f"Output and expected output disagree:\n {e}" |
| 142 | + error = (error or "") + f"Output and expected output disagree:\n {e}" |
74 | 143 | except Exception as e: |
75 | 144 | error = str(e) |
76 | 145 | tb = traceback.format_tb(e.__traceback__) |
|
0 commit comments