Skip to content

Commit 86843fc

Browse files
committed
make tf and qkeras optionl, stop assuming keras is tf.keras
1 parent 0ccda4d commit 86843fc

File tree

8 files changed

+35
-20
lines changed

8 files changed

+35
-20
lines changed

hls4ml/converters/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ def parse_yaml_config(config_file):
9393
"""
9494

9595
def construct_keras_model(loader, node):
96-
from tensorflow.keras.models import load_model
97-
9896
model_str = loader.construct_scalar(node)
99-
return load_model(model_str)
97+
import keras
98+
99+
return keras.models.load_model(model_str)
100100

101101
yaml.add_constructor('!keras_model', construct_keras_model, Loader=yaml.SafeLoader)
102102

hls4ml/model/optimizer/passes/qkeras.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
import tensorflow as tf
32

43
from hls4ml.model.layers import ApplyAlpha
54
from hls4ml.model.optimizer import ConfigurableOptimizerPass, OptimizerPass, register_pass
@@ -113,6 +112,8 @@ def match(self, node):
113112
def transform(self, model, node):
114113
# The quantizer has to be applied to set the scale attribute
115114
# This must be applied to the _unquantized_ weights to obtain the correct scale
115+
import tensorflow as tf
116+
116117
quantizer = node.weights['weight'].quantizer.quantizer_fn # get QKeras quantizer
117118
weights = node.weights['weight'].data_unquantized # get weights
118119
qweights = quantizer(tf.convert_to_tensor(weights))

hls4ml/model/profiling.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,11 @@
1313
from hls4ml.model.layers import GRU, LSTM, SeparableConv1D, SeparableConv2D
1414

1515
try:
16-
import qkeras
17-
from tensorflow import keras
16+
import keras
1817

19-
__tf_profiling_enabled__ = True
18+
__keras_profiling_enabled__ = True
2019
except ImportError:
21-
__tf_profiling_enabled__ = False
20+
__keras_profiling_enabled__ = False
2221

2322
try:
2423
import torch
@@ -27,6 +26,19 @@
2726
except ImportError:
2827
__torch_profiling_enabled__ = False
2928

29+
try:
30+
import qkeras
31+
32+
__qkeras_profiling_enabled__ = True
33+
except ImportError:
34+
__qkeras_profiling_enabled__ = False
35+
36+
_activations = list()
37+
if __keras_profiling_enabled__:
38+
_activations.append(keras.layers.Activation)
39+
if __qkeras_profiling_enabled__:
40+
_activations.append(qkeras.qactivations)
41+
3042

3143
def get_unoptimized_hlsmodel(model):
3244
from hls4ml.converters import convert_from_config
@@ -482,7 +494,7 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'):
482494
if hls_model_present:
483495
data = weights_hlsmodel(hls_model_unoptimized, fmt='summary', plot=plot)
484496
elif model_present:
485-
if __tf_profiling_enabled__ and isinstance(model, keras.Model):
497+
if __keras_profiling_enabled__ and isinstance(model, keras.Model):
486498
data = weights_keras(model, fmt='summary', plot=plot)
487499
elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Sequential):
488500
data = weights_torch(model, fmt='summary', plot=plot)
@@ -520,7 +532,7 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'):
520532
if X is not None:
521533
print("Profiling activations" + before)
522534
data = None
523-
if __tf_profiling_enabled__ and isinstance(model, keras.Model):
535+
if __keras_profiling_enabled__ and isinstance(model, keras.Model):
524536
data = activations_keras(model, X, fmt='summary', plot=plot)
525537
elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Sequential):
526538
data = activations_torch(model, X, fmt='summary', plot=plot)
@@ -590,7 +602,7 @@ def get_ymodel_keras(keras_model, X):
590602
if (
591603
hasattr(layer, 'activation')
592604
and layer.activation is not None
593-
and not isinstance(layer, (keras.layers.Activation, qkeras.qlayers.QActivation))
605+
and not isinstance(layer, _activations)
594606
and layer.activation.__name__ != 'linear'
595607
):
596608
tmp_activation = layer.activation

hls4ml/optimization/dsp_aware_pruning/keras/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
import numpy as np
55
import tensorflow as tf
66

7-
# Enables printing of loss tensors during custom training loop
8-
from tensorflow.python.ops.numpy_ops import np_config
9-
107
import hls4ml.optimization.dsp_aware_pruning.keras.utils as utils
118
from hls4ml.optimization.dsp_aware_pruning.config import SUPPORTED_STRUCTURES
129
from hls4ml.optimization.dsp_aware_pruning.keras.builder import build_optimizable_model, remove_custom_regularizers
@@ -15,7 +12,6 @@
1512
from hls4ml.optimization.dsp_aware_pruning.keras.reduction import reduce_model
1613
from hls4ml.optimization.dsp_aware_pruning.scheduler import OptimizationScheduler
1714

18-
np_config.enable_numpy_behavior()
1915
default_regularization_range = np.logspace(-6, -2, num=16).tolist()
2016

2117

hls4ml/utils/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import json
22

3-
import qkeras
4-
53
import hls4ml
64

75

@@ -48,6 +46,8 @@ def create_config(output_dir='my-hls-test', project_name='myproject', backend='V
4846

4947
def _get_precision_from_quantizer(quantizer):
5048
if isinstance(quantizer, str):
49+
import qkeras
50+
5151
quantizer_obj = qkeras.get_quantizer(quantizer)
5252
quantizer = {}
5353
# Some activations are classes with get_config method

hls4ml/writer/catapult_writer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -889,7 +889,9 @@ def keras_model_representer(dumper, keras_model):
889889
return dumper.represent_scalar('!keras_model', model_path)
890890

891891
try:
892-
from tensorflow.keras import Model as KerasModel
892+
import keras
893+
894+
KerasModel = keras.models.Model
893895

894896
yaml.add_multi_representer(KerasModel, keras_model_representer)
895897
except Exception:

hls4ml/writer/quartus_writer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1327,7 +1327,9 @@ def keras_model_representer(dumper, keras_model):
13271327
return dumper.represent_scalar('!keras_model', model_path)
13281328

13291329
try:
1330-
from tensorflow.keras import Model as KerasModel
1330+
import keras
1331+
1332+
KerasModel = keras.models.Model
13311333

13321334
yaml.add_multi_representer(KerasModel, keras_model_representer)
13331335
except Exception:

hls4ml/writer/vivado_writer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -815,7 +815,9 @@ def keras_model_representer(dumper, keras_model):
815815
return dumper.represent_scalar('!keras_model', model_path)
816816

817817
try:
818-
from tensorflow.keras import Model as KerasModel
818+
import keras
819+
820+
KerasModel = keras.models.Model
819821

820822
yaml.add_multi_representer(KerasModel, keras_model_representer)
821823
except Exception:

0 commit comments

Comments
 (0)