11import typing
22from itertools import chain
3+ from types import FunctionType
34from typing import Any , Callable , Sequence
45
56if typing .TYPE_CHECKING :
@@ -154,7 +155,10 @@ def v2_call(
154155 self , layer : 'keras.layers.Layer' , inp_tensors : Sequence ['KerasTensor' ], out_tensors : Sequence ['KerasTensor' ]
155156 ):
156157 # keras v2 handlers fallback
157- print ("v2 handler" )
158+ print (f"v2 handler used for layer { layer .name } " )
159+
160+ import keras
161+
158162 config = layer .get_config ()
159163 layer_dict = {'config' : config , 'class_name' : layer .__class__ .__name__ }
160164
@@ -176,16 +180,22 @@ def get_weights_data(self, layer_name, var_name):
176180 return None
177181
178182 ret , _ = handler (layer_dict , input_names , input_shapes , reader )
179- ret ['outputs' ] = output_names
183+ ret ['output_keras_tensor_names' ] = output_names
184+ ret ['input_keras_tensor_names' ] = input_names
180185 ret = (ret ,)
181186
182187 activation = getattr (layer , 'activation' , None )
183188 if activation not in (keras .activations .linear , None ):
184- act_cls_name = activation .__class__ .__name__
189+ assert isinstance (activation , FunctionType ), f"Activation function for layer { layer .name } is not a function"
190+ intermediate_tensor_name = f'{ output_names [0 ]} _activation'
191+ ret [0 ]['output_keras_tensor_names' ] = (intermediate_tensor_name ,)
192+ act_cls_name = activation .__name__
185193 act_config = {
186194 'class_name' : 'Activation' ,
187195 'activation' : act_cls_name ,
188196 'name' : f'{ layer .name } _{ act_cls_name } ' ,
197+ 'input_keras_tensor_names' : (intermediate_tensor_name ,),
198+ 'output_keras_tensor_names' : output_names ,
189199 }
190200 ret = * ret , act_config
191201 return ret
@@ -212,19 +222,26 @@ def parse_keras_v3_model(model: 'keras.Model'):
212222 If a circular dependency is detected.
213223 """
214224
225+ assert model .built , "Model must be built before parsing"
226+
227+ import keras
228+
229+ if isinstance (model , keras .Sequential ):
230+ model = model ._functional # everything is functional under the hood lol
231+
215232 from .keras_to_hls import layer_handlers as v2_layer_handlers # Delayed import to avoid circular import
216233
217234 keras_v3_dispatcher = KerasV3HandlerDispatcher (v3_layer_handlers , v2_layer_handlers )
218235
219236 model_inputs , model_outputs , dependency , tensors = resolve_dependency_relation (model )
220237
221238 satisfied = set ()
222- total = len (tensors )
223239
224240 unique_name = UniqueName ()
225241
226242 layer_list : list [dict [str , Any ]] = []
227- while len (satisfied ) < total :
243+
244+ while any (t not in satisfied for t in model_outputs ):
228245 # Until all tensors in the model are satisfied
229246 for i , (layer_name , in_tensor_names , out_tensor_names ) in enumerate (dependency ):
230247 if not all (t in satisfied for t in in_tensor_names ):
@@ -237,13 +254,10 @@ def parse_keras_v3_model(model: 'keras.Model'):
237254 out_tensors = [tensors [t ] for t in out_tensor_names ]
238255
239256 _configs = keras_v3_dispatcher (layer , inp_tensors , out_tensors )
240- # Dispatch to v3 handler if available, else fallback to v2
241- # handler
257+ # Dispatch to v3 handler if available, else fallback to v2 handler
242258
243- # Prevent name conflicts. If a layer is used multiple times,
244- # add a suffix to the name At this stage, connections
245- # between modules are recorded by i/o keras tensor names
246- # (guaranteed unique), thus we can safely rename the layers
259+ # Prevent name conflicts. If a layer is used multiple times, add a suffix to the name.
260+ # At this stage connections between modules are recorded by i/o keras tensor names
247261 for _conf in _configs :
248262 _conf ['name' ] = unique_name (_conf ['name' ])
249263
0 commit comments