|
1 | 1 | import collections |
2 | 2 | import os |
3 | | -import traceback |
4 | | -import warnings |
5 | 3 | from copy import deepcopy |
6 | 4 | from itertools import product |
7 | 5 | from pathlib import Path |
8 | | -from typing import Dict, List, Optional, OrderedDict, Sequence, Tuple, Union |
| 6 | +from typing import Dict, List, OrderedDict, Sequence, Tuple, Union |
9 | 7 |
|
10 | 8 | import imageio |
11 | 9 | import numpy as np |
|
14 | 12 |
|
15 | 13 | from bioimageio.core import load_resource_description |
16 | 14 | from bioimageio.core.prediction_pipeline import PredictionPipeline, create_prediction_pipeline |
| 15 | +from bioimageio.core.resource_io.nodes import ImplicitOutputShape, InputTensor, Model, OutputTensor |
| 16 | + |
17 | 17 |
|
18 | 18 | # |
19 | 19 | # utility functions for prediction |
20 | 20 | # |
21 | | -from bioimageio.core.resource_io.nodes import ( |
22 | | - ImplicitOutputShape, |
23 | | - InputTensor, |
24 | | - Model, |
25 | | - OutputTensor, |
26 | | - ResourceDescription, |
27 | | - URI, |
28 | | -) |
29 | | -from bioimageio.spec.model.raw_nodes import WeightsFormat |
30 | | -from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription |
31 | | - |
32 | | - |
33 | 21 | def require_axes(im, axes): |
34 | 22 | is_volume = "z" in axes |
35 | 23 | # we assume images / volumes are loaded as one of |
@@ -482,73 +470,3 @@ def predict_images( |
482 | 470 | outp = [outp] |
483 | 471 |
|
484 | 472 | _predict_sample(prediction_pipeline, inp, outp, padding, tiling) |
485 | | - |
486 | | - |
487 | | -def test_model( |
488 | | - model_rdf: Union[URI, Path, str], |
489 | | - weight_format: Optional[WeightsFormat] = None, |
490 | | - devices: Optional[List[str]] = None, |
491 | | - decimal: int = 4, |
492 | | -) -> bool: |
493 | | - """Test whether the test output(s) of a model can be reproduced. |
494 | | -
|
495 | | - Returns True if the test passes, otherwise returns False and issues a warning. |
496 | | - """ |
497 | | - model = load_resource_description(model_rdf) |
498 | | - assert isinstance(model, Model) |
499 | | - summary = test_resource(model, weight_format=weight_format, devices=devices, decimal=decimal) |
500 | | - if summary["error"] is None: |
501 | | - return True |
502 | | - else: |
503 | | - warnings.warn(summary["error"]) |
504 | | - return False |
505 | | - |
506 | | - |
507 | | -def test_resource( |
508 | | - model_rdf: Union[RawResourceDescription, ResourceDescription, URI, Path, str], |
509 | | - *, |
510 | | - weight_format: Optional[WeightsFormat] = None, |
511 | | - devices: Optional[List[str]] = None, |
512 | | - decimal: int = 4, |
513 | | -): |
514 | | - """Test RDF dynamically |
515 | | -
|
516 | | - Returns summary dict with "error" and "traceback" key; summary["error"] is None if no errors were encountered. |
517 | | - """ |
518 | | - error: Optional[str] = None |
519 | | - tb: Optional = None |
520 | | - |
521 | | - try: |
522 | | - model = load_resource_description(model_rdf) |
523 | | - except Exception as e: |
524 | | - error = str(e) |
525 | | - tb = traceback.format_tb(e.__traceback__) |
526 | | - else: |
527 | | - if isinstance(model, Model): |
528 | | - try: |
529 | | - prediction_pipeline = create_prediction_pipeline( |
530 | | - bioimageio_model=model, devices=devices, weight_format=weight_format |
531 | | - ) |
532 | | - inputs = [np.load(str(in_path)) for in_path in model.test_inputs] |
533 | | - results = predict(prediction_pipeline, inputs) |
534 | | - if isinstance(results, (np.ndarray, xr.DataArray)): |
535 | | - results = [results] |
536 | | - |
537 | | - expected = [np.load(str(out_path)) for out_path in model.test_outputs] |
538 | | - if len(results) != len(expected): |
539 | | - error = ( |
540 | | - f"Number of outputs and number of expected outputs disagree: {len(results)} != {len(expected)}" |
541 | | - ) |
542 | | - else: |
543 | | - for res, exp in zip(results, expected): |
544 | | - try: |
545 | | - np.testing.assert_array_almost_equal(res, exp, decimal=decimal) |
546 | | - except AssertionError as e: |
547 | | - error = f"Output and expected output disagree:\n {e}" |
548 | | - except Exception as e: |
549 | | - error = str(e) |
550 | | - tb = traceback.format_tb(e.__traceback__) |
551 | | - |
552 | | - # todo: add tests for non-model resources |
553 | | - |
554 | | - return {"error": error, "traceback": tb} |
0 commit comments