|
8 | 8 | tensorflow1_models = ["FruNet_model"] |
9 | 9 | tensorflow2_models = [] |
10 | 10 | keras_models = ["FruNet_model"] |
| 11 | +tensorflow_js_models = ["FruNet_model"] |
11 | 12 |
|
12 | 13 | model_sources = { |
13 | 14 | "unet2d_nuclei_broad_model": ( |
@@ -61,11 +62,13 @@ def pytest_configure(): |
61 | 62 | if not pytest.skip_onnx: |
62 | 63 | load_packages |= set(onnx_models) |
63 | 64 |
|
64 | | - if not pytest.skip_tensorflow and pytest.tf_major_version == 1: |
65 | | - load_packages |= set(tensorflow1_models) |
66 | | - |
67 | | - if not pytest.skip_tensorflow and pytest.tf_major_version == 2: |
68 | | - load_packages |= set(tensorflow2_models) |
| 65 | + if not pytest.skip_tensorflow: |
| 66 | + load_packages |= set(keras_models) |
| 67 | + load_packages |= set(tensorflow_js_models) |
| 68 | + if pytest.tf_major_version == 1: |
| 69 | + load_packages |= set(tensorflow1_models) |
| 70 | + elif pytest.tf_major_version == 2: |
| 71 | + load_packages |= set(tensorflow2_models) |
69 | 72 |
|
70 | 73 | pytest.model_packages = {name: export_resource_package(model_sources[name]) for name in load_packages} |
71 | 74 |
|
@@ -105,18 +108,22 @@ def any_onnx_model(request): |
105 | 108 | return pytest.model_packages[request.param] |
106 | 109 |
|
107 | 110 |
|
108 | | -@pytest.fixture(params=tensorflow1_models) |
109 | | -def any_tensorflow1_model(request): |
110 | | - return pytest.model_packages[request.param] |
| 111 | +@pytest.fixture(params=set(tensorflow1_models) | set(tensorflow2_models)) |
| 112 | +def any_tensorflow_model(request): |
| 113 | + name = request.param |
| 114 | + if (pytest.tf_major_version == 1 and name in tensorflow1_models) or ( |
| 115 | + pytest.tf_major_version == 2 and name in tensorflow2_models |
| 116 | + ): |
| 117 | + return pytest.model_packages[name] |
111 | 118 |
|
112 | 119 |
|
113 | | -@pytest.fixture(params=tensorflow2_models) |
114 | | -def any_tensorflow2_model(request): |
| 120 | +@pytest.fixture(params=keras_models) |
| 121 | +def any_keras_model(request): |
115 | 122 | return pytest.model_packages[request.param] |
116 | 123 |
|
117 | 124 |
|
118 | | -@pytest.fixture(params=keras_models) |
119 | | -def any_keras_model(request): |
| 125 | +@pytest.fixture(params=tensorflow_js_models) |
| 126 | +def any_tensorflow_js_model(request): |
120 | 127 | return pytest.model_packages[request.param] |
121 | 128 |
|
122 | 129 |
|
|
0 commit comments