Skip to content

Commit d3c8881

Browse files
committed
Add hint on import failure
1 parent 014c1db commit d3c8881

File tree

8 files changed

+34
-39
lines changed

8 files changed

+34
-39
lines changed

hls4ml/converters/__init__.py

Lines changed: 15 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import importlib
22
import os
3-
import warnings
43

54
import yaml
65

@@ -10,33 +9,19 @@
109
from hls4ml.converters.keras_to_hls import get_supported_keras_layers # noqa: F401
1110
from hls4ml.converters.keras_to_hls import parse_keras_model # noqa: F401
1211
from hls4ml.converters.keras_to_hls import keras_to_hls, register_keras_layer_handler
12+
from hls4ml.converters.onnx_to_hls import get_supported_onnx_layers # noqa: F401
1313
from hls4ml.converters.onnx_to_hls import parse_onnx_model # noqa: F401
14+
from hls4ml.converters.onnx_to_hls import onnx_to_hls, register_onnx_layer_handler
15+
from hls4ml.converters.pytorch_to_hls import ( # noqa: F401
16+
get_supported_pytorch_layers,
17+
pytorch_to_hls,
18+
register_pytorch_layer_handler,
19+
)
1420
from hls4ml.model import ModelGraph
1521
from hls4ml.utils.config import create_config
22+
from hls4ml.utils.dependency import requires
1623
from hls4ml.utils.symbolic_utils import LUTFunction
1724

18-
# ----------Make converters available if the libraries can be imported----------#
19-
try:
20-
from hls4ml.converters.pytorch_to_hls import ( # noqa: F401
21-
get_supported_pytorch_layers,
22-
pytorch_to_hls,
23-
register_pytorch_layer_handler,
24-
)
25-
26-
__pytorch_enabled__ = True
27-
except ImportError:
28-
warnings.warn("WARNING: Pytorch converter is not enabled!", stacklevel=1)
29-
__pytorch_enabled__ = False
30-
31-
try:
32-
from hls4ml.converters.onnx_to_hls import get_supported_onnx_layers # noqa: F401
33-
from hls4ml.converters.onnx_to_hls import onnx_to_hls, register_onnx_layer_handler
34-
35-
__onnx_enabled__ = True
36-
except ImportError:
37-
warnings.warn("WARNING: ONNX converter is not enabled!", stacklevel=1)
38-
__onnx_enabled__ = False
39-
4025
# ----------Layer handling register----------#
4126
model_types = ['keras', 'pytorch', 'onnx']
4227

@@ -51,7 +36,7 @@
5136
# and has 'handles' attribute
5237
# and is defined in this module (i.e., not imported)
5338
if callable(func) and hasattr(func, 'handles') and func.__module__ == lib.__name__:
54-
for layer in func.handles:
39+
for layer in func.handles: # type: ignore
5540
if model_type == 'keras':
5641
register_keras_layer_handler(layer, func)
5742
elif model_type == 'pytorch':
@@ -124,15 +109,9 @@ def convert_from_config(config):
124109

125110
model = None
126111
if 'OnnxModel' in yamlConfig:
127-
if __onnx_enabled__:
128-
model = onnx_to_hls(yamlConfig)
129-
else:
130-
raise Exception("ONNX not found. Please install ONNX.")
112+
model = onnx_to_hls(yamlConfig)
131113
elif 'PytorchModel' in yamlConfig:
132-
if __pytorch_enabled__:
133-
model = pytorch_to_hls(yamlConfig)
134-
else:
135-
raise Exception("PyTorch not found. Please install PyTorch.")
114+
model = pytorch_to_hls(yamlConfig)
136115
else:
137116
model = keras_to_hls(yamlConfig)
138117

@@ -174,6 +153,7 @@ def _check_model_config(model_config):
174153
return model_config
175154

176155

156+
@requires('_keras')
177157
def convert_from_keras_model(
178158
model,
179159
output_dir='my-hls-test',
@@ -237,6 +217,7 @@ def convert_from_keras_model(
237217
return keras_to_hls(config)
238218

239219

220+
@requires('_torch')
240221
def convert_from_pytorch_model(
241222
model,
242223
output_dir='my-hls-test',
@@ -308,6 +289,7 @@ def convert_from_pytorch_model(
308289
return pytorch_to_hls(config)
309290

310291

292+
@requires('onnx')
311293
def convert_from_onnx_model(
312294
model,
313295
output_dir='my-hls-test',
@@ -371,6 +353,7 @@ def convert_from_onnx_model(
371353
return onnx_to_hls(config)
372354

373355

356+
@requires('sr')
374357
def convert_from_symbolic_expression(
375358
expr,
376359
n_symbols=None,

hls4ml/converters/onnx_to_hls.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from hls4ml.model import ModelGraph
2+
from hls4ml.utils.dependency import requires
23

34

45
# ----------------------Helpers---------------------
@@ -17,6 +18,7 @@ def replace_char_inconsitency(name):
1718
return name.replace('.', '_')
1819

1920

21+
@requires('onnx')
2022
def get_onnx_attribute(operation, name, default=None):
2123
from onnx import helper
2224

@@ -73,6 +75,7 @@ def get_input_shape(graph, node):
7375
return rv
7476

7577

78+
@requires('onnx')
7679
def get_constant_value(graph, constant_name):
7780
tensor = next((x for x in graph.initializer if x.name == constant_name), None)
7881
from onnx import numpy_helper
@@ -258,6 +261,7 @@ def parse_onnx_model(onnx_model):
258261
return layer_list, input_layers, output_layers
259262

260263

264+
@requires('onnx')
261265
def onnx_to_hls(config):
262266
"""Convert onnx model to hls model from configuration.
263267

hls4ml/converters/pytorch_to_hls.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from hls4ml.model import ModelGraph
2+
from hls4ml.utils.dependency import requires
23

34

45
class PyTorchModelReader:
@@ -22,6 +23,7 @@ def get_weights_data(self, layer_name, var_name):
2223
return data
2324

2425

26+
@requires('_torch')
2527
class PyTorchFileReader(PyTorchModelReader): # Inherit get_weights_data method
2628
def __init__(self, config):
2729
import torch
@@ -103,6 +105,7 @@ def decorator(function):
103105
# ----------------------------------------------------------------
104106

105107

108+
@requires('_torch')
106109
def parse_pytorch_model(config, verbose=True):
107110
"""Convert PyTorch model to hls4ml ModelGraph.
108111
@@ -368,6 +371,7 @@ def parse_pytorch_model(config, verbose=True):
368371
return layer_list, input_layers
369372

370373

374+
@requires('_torch')
371375
def pytorch_to_hls(config):
372376
layer_list, input_layers = parse_pytorch_model(config)
373377
print('Creating HLS model')

hls4ml/model/__init__.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1 @@
11
from hls4ml.model.graph import HLSConfig, ModelGraph # noqa: F401
2-
3-
try:
4-
from hls4ml.model import profiling # noqa: F401
5-
6-
__profiling_enabled__ = True
7-
except ImportError:
8-
__profiling_enabled__ = False

hls4ml/model/quantizers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
SaturationMode,
1515
XnorPrecisionType,
1616
)
17+
from hls4ml.utils.dependency import requires
1718

1819

1920
class Quantizer:
@@ -84,6 +85,7 @@ class QKerasQuantizer(Quantizer):
8485
config (dict): Config of the QKeras quantizer to wrap.
8586
"""
8687

88+
@requires('qkeras')
8789
def __init__(self, config):
8890
from qkeras.quantizers import get_quantizer
8991

@@ -131,6 +133,7 @@ class QKerasBinaryQuantizer(Quantizer):
131133
config (dict): Config of the QKeras quantizer to wrap.
132134
"""
133135

136+
@requires('qkeras')
134137
def __init__(self, config, xnor=False):
135138
from qkeras.quantizers import get_quantizer
136139

@@ -155,6 +158,7 @@ class QKerasPO2Quantizer(Quantizer):
155158
config (dict): Config of the QKeras quantizer to wrap.
156159
"""
157160

161+
@requires('qkeras')
158162
def __init__(self, config):
159163
from qkeras.quantizers import get_quantizer
160164

hls4ml/report/quartus_report.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import webbrowser
33
from ast import literal_eval
44

5+
from hls4ml.utils.dependency import requires
6+
57

68
def parse_quartus_report(hls_dir, write_to_file=True):
79
'''
@@ -39,6 +41,7 @@ def parse_quartus_report(hls_dir, write_to_file=True):
3941
return results
4042

4143

44+
@requires('quantus-report')
4245
def read_quartus_report(hls_dir, open_browser=False):
4346
'''
4447
Parse and print the Quartus report to print the report. Optionally open a browser.
@@ -89,6 +92,7 @@ def _find_project_dir(hls_dir):
8992
return top_func_name + '-fpga.prj'
9093

9194

95+
@requires('quantus-report')
9296
def read_js_object(js_script):
9397
'''
9498
Reads the JavaScript file and return a dictionary of variables definded in the script.

hls4ml/utils/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22

33
import hls4ml
4+
from hls4ml.utils.dependency import requires
45

56

67
def create_config(output_dir='my-hls-test', project_name='myproject', backend='Vivado', version='1.0.0', **kwargs):
@@ -44,6 +45,7 @@ def create_config(output_dir='my-hls-test', project_name='myproject', backend='V
4445
return config
4546

4647

48+
@requires('qkeras')
4749
def _get_precision_from_quantizer(quantizer):
4850
if isinstance(quantizer, str):
4951
import qkeras

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ optional-dependencies.doc = [
3434
"sphinx-rtd-theme",
3535
]
3636
optional-dependencies.HGQ = [ "hgq~=0.2.0" ]
37+
optional-dependencies.onnx = [ "onnx>=1.4" ]
3738
optional-dependencies.optimization = [
3839
"keras-tuner==1.1.3",
3940
"ortools==9.4.1874",

0 commit comments

Comments
 (0)