Skip to content

Commit c35d4b6

Browse files
committed
Sequential and I/O tensor name parsing fix
1 parent 9dacf3e commit c35d4b6

File tree

3 files changed

+80
-27
lines changed

3 files changed

+80
-27
lines changed

hls4ml/converters/keras_v3/_base.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
11
import typing
2-
from typing import Any, Callable, Sequence
2+
from types import FunctionType
3+
from typing import Any, Callable, Sequence, TypedDict
4+
5+
6+
class DefaultConfig(TypedDict, total=False):
7+
name: str
8+
class_name: str
9+
module: str
10+
input_keras_tensor_names: list[str]
11+
input_shape: list[list[int]]
12+
output_keras_tensor_names: list[str]
13+
epsilon: float
14+
use_bias: bool
15+
data_format: str
16+
317

418
if typing.TYPE_CHECKING:
519
import keras
@@ -49,7 +63,7 @@ def deco(func: T_kv3_handler):
4963
return deco
5064

5165

52-
def maybe_add_attrs(config: dict[str, Any], obj: Any, *attrs: str):
66+
def maybe_add_attrs(config: dict[str, Any] | DefaultConfig, obj: Any, *attrs: str):
5367
for attr in attrs:
5468
if attr not in config and hasattr(obj, attr):
5569
config[attr] = getattr(obj, attr)
@@ -103,36 +117,55 @@ def __call__(
103117
""" # noqa: E501
104118
import keras
105119

106-
config0 = self.handle(layer, in_tensors, out_tensors)
107-
if isinstance(config0, tuple):
108-
return config0
109-
110120
name = layer.name
111121
class_name = layer.__class__.__name__
112122
module = layer.__module__
113-
config1 = {
123+
124+
default_config: DefaultConfig = {
114125
'name': name,
115126
'class_name': class_name,
116127
'module': module,
117128
'input_keras_tensor_names': [t.name for t in in_tensors],
118-
'input_shape': [list(t.shape[1:]) for t in in_tensors],
129+
'input_shape': [list(t.shape[1:]) for t in in_tensors], # type: ignore
119130
'output_keras_tensor_names': [t.name for t in out_tensors],
120131
}
121132

122-
maybe_add_attrs(config1, layer, 'epsilon', 'use_bias', 'data_format')
133+
maybe_add_attrs(default_config, layer, 'epsilon', 'use_bias', 'data_format')
123134

124-
config1.update(config0)
125-
ret = (config1,)
135+
mandatory_keys = ['name', 'class_name', 'output_keras_tensor_names', 'input_keras_tensor_names']
126136

137+
self.default_config = default_config
138+
config0 = self.handle(layer, in_tensors, out_tensors)
139+
del self.default_config
140+
141+
if isinstance(config0, tuple):
142+
for conf in config0:
143+
for key in mandatory_keys:
144+
assert key in conf, f"Key {key} missing from layer {name} handled by {self.__class__.__name__}"
145+
return config0
146+
147+
config = {}
148+
config.update(default_config)
149+
config.update(config0)
150+
ret = (config,)
151+
152+
# If activation exists, append it
127153
activation = getattr(layer, 'activation', None)
128154
if activation not in (keras.activations.linear, None):
129-
act_cls_name = activation.__class__.__name__
155+
assert len(out_tensors) == 1, f"Layer {name} has more than one output, but has an activation function"
156+
assert isinstance(activation, FunctionType), f"Activation function for layer {name} is not a function"
157+
intermediate_tensor_name = f'{out_tensors[0].name}_activation'
158+
ret[0]['output_keras_tensor_names'] = [intermediate_tensor_name]
159+
act_cls_name = activation.__name__
130160
act_config = {
131161
'class_name': 'Activation',
132162
'activation': act_cls_name,
133163
'name': f'{name}_{act_cls_name}',
164+
'input_keras_tensor_names': [intermediate_tensor_name],
165+
'output_keras_tensor_names': [out_tensors[0].name],
134166
}
135167
ret = *ret, act_config
168+
136169
return ret
137170

138171
def handle(

hls4ml/converters/keras_v3_to_hls.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import typing
22
from itertools import chain
3+
from types import FunctionType
34
from typing import Any, Callable, Sequence
45

56
if typing.TYPE_CHECKING:
@@ -154,7 +155,10 @@ def v2_call(
154155
self, layer: 'keras.layers.Layer', inp_tensors: Sequence['KerasTensor'], out_tensors: Sequence['KerasTensor']
155156
):
156157
# keras v2 handlers fallback
157-
print("v2 handler")
158+
print(f"v2 handler used for layer {layer.name}")
159+
160+
import keras
161+
158162
config = layer.get_config()
159163
layer_dict = {'config': config, 'class_name': layer.__class__.__name__}
160164

@@ -176,16 +180,22 @@ def get_weights_data(self, layer_name, var_name):
176180
return None
177181

178182
ret, _ = handler(layer_dict, input_names, input_shapes, reader)
179-
ret['outputs'] = output_names
183+
ret['output_keras_tensor_names'] = output_names
184+
ret['input_keras_tensor_names'] = input_names
180185
ret = (ret,)
181186

182187
activation = getattr(layer, 'activation', None)
183188
if activation not in (keras.activations.linear, None):
184-
act_cls_name = activation.__class__.__name__
189+
assert isinstance(activation, FunctionType), f"Activation function for layer {layer.name} is not a function"
190+
intermediate_tensor_name = f'{output_names[0]}_activation'
191+
ret[0]['output_keras_tensor_names'] = (intermediate_tensor_name,)
192+
act_cls_name = activation.__name__
185193
act_config = {
186194
'class_name': 'Activation',
187195
'activation': act_cls_name,
188196
'name': f'{layer.name}_{act_cls_name}',
197+
'input_keras_tensor_names': (intermediate_tensor_name,),
198+
'output_keras_tensor_names': output_names,
189199
}
190200
ret = *ret, act_config
191201
return ret
@@ -212,19 +222,26 @@ def parse_keras_v3_model(model: 'keras.Model'):
212222
If a circular dependency is detected.
213223
"""
214224

225+
assert model.built, "Model must be built before parsing"
226+
227+
import keras
228+
229+
if isinstance(model, keras.Sequential):
230+
model = model._functional # everything is functional under the hood lol
231+
215232
from .keras_to_hls import layer_handlers as v2_layer_handlers # Delayed import to avoid circular import
216233

217234
keras_v3_dispatcher = KerasV3HandlerDispatcher(v3_layer_handlers, v2_layer_handlers)
218235

219236
model_inputs, model_outputs, dependency, tensors = resolve_dependency_relation(model)
220237

221238
satisfied = set()
222-
total = len(tensors)
223239

224240
unique_name = UniqueName()
225241

226242
layer_list: list[dict[str, Any]] = []
227-
while len(satisfied) < total:
243+
244+
while any(t not in satisfied for t in model_outputs):
228245
# Until all tensors in the model are satisfied
229246
for i, (layer_name, in_tensor_names, out_tensor_names) in enumerate(dependency):
230247
if not all(t in satisfied for t in in_tensor_names):
@@ -237,13 +254,10 @@ def parse_keras_v3_model(model: 'keras.Model'):
237254
out_tensors = [tensors[t] for t in out_tensor_names]
238255

239256
_configs = keras_v3_dispatcher(layer, inp_tensors, out_tensors)
240-
# Dispatch to v3 handler if available, else fallback to v2
241-
# handler
257+
# Dispatch to v3 handler if available, else fallback to v2 handler
242258

243-
# Prevent name conflicts. If a layer is used multiple times,
244-
# add a suffix to the name At this stage, connections
245-
# between modules are recorded by i/o keras tensor names
246-
# (guaranteed unique), thus we can safely rename the layers
259+
# Prevent name conflicts. If a layer is used multiple times, add a suffix to the name.
260+
# At this stage connections between modules are recorded by i/o keras tensor names
247261
for _conf in _configs:
248262
_conf['name'] = unique_name(_conf['name'])
249263

hls4ml/utils/config.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22

33
import hls4ml
4+
import hls4ml.converters.keras_v3_to_hls
45

56

67
def create_config(output_dir='my-hls-test', project_name='myproject', backend='Vivado', version='1.0.0', **kwargs):
@@ -157,12 +158,17 @@ def config_from_keras_model(
157158

158159
if isinstance(model, dict):
159160
model_arch = model
161+
reader = hls4ml.converters.KerasModelReader(model)
162+
layer_list, _, _, _ = hls4ml.converters.parse_keras_model(model_arch, reader)
160163
else:
161-
model_arch = json.loads(model.to_json())
164+
import keras
162165

163-
reader = hls4ml.converters.KerasModelReader(model)
164-
165-
layer_list, _, _, _ = hls4ml.converters.parse_keras_model(model_arch, reader)
166+
if keras.__version__ > '3.0':
167+
layer_list, *_ = hls4ml.converters.parse_keras_v3_model(model)
168+
else:
169+
model_arch = json.loads(model.to_json())
170+
reader = hls4ml.converters.KerasModelReader(model)
171+
layer_list, _, _, _ = hls4ml.converters.parse_keras_model(model_arch, reader)
166172

167173
def make_layer_config(layer):
168174
cls_name = layer['class_name']

0 commit comments

Comments (0)