|
1 | 1 | import importlib |
2 | 2 | import os |
3 | | -import warnings |
4 | 3 |
|
5 | 4 | import yaml |
6 | 5 |
|
|
10 | 9 | from hls4ml.converters.keras_to_hls import get_supported_keras_layers # noqa: F401 |
11 | 10 | from hls4ml.converters.keras_to_hls import parse_keras_model # noqa: F401 |
12 | 11 | from hls4ml.converters.keras_to_hls import keras_to_hls, register_keras_layer_handler |
| 12 | +from hls4ml.converters.onnx_to_hls import get_supported_onnx_layers # noqa: F401 |
13 | 13 | from hls4ml.converters.onnx_to_hls import parse_onnx_model # noqa: F401 |
| 14 | +from hls4ml.converters.onnx_to_hls import onnx_to_hls, register_onnx_layer_handler |
| 15 | +from hls4ml.converters.pytorch_to_hls import ( # noqa: F401 |
| 16 | + get_supported_pytorch_layers, |
| 17 | + pytorch_to_hls, |
| 18 | + register_pytorch_layer_handler, |
| 19 | +) |
14 | 20 | from hls4ml.model import ModelGraph |
15 | 21 | from hls4ml.utils.config import create_config |
| 22 | +from hls4ml.utils.dependency import requires |
16 | 23 | from hls4ml.utils.symbolic_utils import LUTFunction |
17 | 24 |
|
18 | | -# ----------Make converters available if the libraries can be imported----------# |
19 | | -try: |
20 | | - from hls4ml.converters.pytorch_to_hls import ( # noqa: F401 |
21 | | - get_supported_pytorch_layers, |
22 | | - pytorch_to_hls, |
23 | | - register_pytorch_layer_handler, |
24 | | - ) |
25 | | - |
26 | | - __pytorch_enabled__ = True |
27 | | -except ImportError: |
28 | | - warnings.warn("WARNING: Pytorch converter is not enabled!", stacklevel=1) |
29 | | - __pytorch_enabled__ = False |
30 | | - |
31 | | -try: |
32 | | - from hls4ml.converters.onnx_to_hls import get_supported_onnx_layers # noqa: F401 |
33 | | - from hls4ml.converters.onnx_to_hls import onnx_to_hls, register_onnx_layer_handler |
34 | | - |
35 | | - __onnx_enabled__ = True |
36 | | -except ImportError: |
37 | | - warnings.warn("WARNING: ONNX converter is not enabled!", stacklevel=1) |
38 | | - __onnx_enabled__ = False |
39 | | - |
40 | 25 | # ----------Layer handling register----------# |
41 | 26 | model_types = ['keras', 'pytorch', 'onnx'] |
42 | 27 |
|
|
51 | 36 | # and has 'handles' attribute |
52 | 37 | # and is defined in this module (i.e., not imported) |
53 | 38 | if callable(func) and hasattr(func, 'handles') and func.__module__ == lib.__name__: |
54 | | - for layer in func.handles: |
| 39 | + for layer in func.handles: # type: ignore |
55 | 40 | if model_type == 'keras': |
56 | 41 | register_keras_layer_handler(layer, func) |
57 | 42 | elif model_type == 'pytorch': |
@@ -124,15 +109,9 @@ def convert_from_config(config): |
124 | 109 |
|
125 | 110 | model = None |
126 | 111 | if 'OnnxModel' in yamlConfig: |
127 | | - if __onnx_enabled__: |
128 | | - model = onnx_to_hls(yamlConfig) |
129 | | - else: |
130 | | - raise Exception("ONNX not found. Please install ONNX.") |
| 112 | + model = onnx_to_hls(yamlConfig) |
131 | 113 | elif 'PytorchModel' in yamlConfig: |
132 | | - if __pytorch_enabled__: |
133 | | - model = pytorch_to_hls(yamlConfig) |
134 | | - else: |
135 | | - raise Exception("PyTorch not found. Please install PyTorch.") |
| 114 | + model = pytorch_to_hls(yamlConfig) |
136 | 115 | else: |
137 | 116 | model = keras_to_hls(yamlConfig) |
138 | 117 |
|
@@ -174,6 +153,7 @@ def _check_model_config(model_config): |
174 | 153 | return model_config |
175 | 154 |
|
176 | 155 |
|
| 156 | +@requires('_keras') |
177 | 157 | def convert_from_keras_model( |
178 | 158 | model, |
179 | 159 | output_dir='my-hls-test', |
@@ -237,6 +217,7 @@ def convert_from_keras_model( |
237 | 217 | return keras_to_hls(config) |
238 | 218 |
|
239 | 219 |
|
| 220 | +@requires('_torch') |
240 | 221 | def convert_from_pytorch_model( |
241 | 222 | model, |
242 | 223 | output_dir='my-hls-test', |
@@ -308,6 +289,7 @@ def convert_from_pytorch_model( |
308 | 289 | return pytorch_to_hls(config) |
309 | 290 |
|
310 | 291 |
|
| 292 | +@requires('onnx') |
311 | 293 | def convert_from_onnx_model( |
312 | 294 | model, |
313 | 295 | output_dir='my-hls-test', |
@@ -371,6 +353,7 @@ def convert_from_onnx_model( |
371 | 353 | return onnx_to_hls(config) |
372 | 354 |
|
373 | 355 |
|
| 356 | +@requires('sr') |
374 | 357 | def convert_from_symbolic_expression( |
375 | 358 | expr, |
376 | 359 | n_symbols=None, |
|
0 commit comments