11import typing
22from types import FunctionType
3- from typing import Any , Callable , Sequence , TypedDict
3+ from typing import Any , Callable , Sequence , TypedDict , overload
44
55
66class DefaultConfig (TypedDict , total = False ):
@@ -26,6 +26,14 @@ class DefaultConfig(TypedDict, total=False):
2626registry : dict [str , T_kv3_handler ] = {}
2727
2828
29+ @overload
30+ def register (cls : type ) -> type : ...
31+
32+
33+ @overload
34+ def register (cls : str ) -> Callable [[T_kv3_handler ], T_kv3_handler ]: ...
35+
36+
2937def register (cls : str | type ):
3038 """Decorator to register a handler for a specific layer class. Suggested to decorate the `KerasV3LayerHandler` class.
3139
@@ -51,11 +59,13 @@ def my_layer_handler(layer, inp_tensors, out_tensors):
5159 ```
5260 """
5361
54- def deco (func : T_kv3_handler ):
62+ def deco (func ):
5563 if isinstance (cls , str ):
5664 registry [cls ] = func
5765 for k in getattr (func , 'handles' , ()):
5866 registry [k ] = func
67+ if isinstance (cls , type ):
68+ return cls
5969 return func
6070
6171 if isinstance (cls , type ):
@@ -79,7 +89,7 @@ def __call__(
7989 layer : 'keras.Layer' ,
8090 in_tensors : Sequence ['KerasTensor' ],
8191 out_tensors : Sequence ['KerasTensor' ],
82- ):
92+ ) -> tuple [ dict [ str , Any ], ...] :
8393 """Handle a keras layer. Return a tuple of dictionaries, each
8494 dictionary representing a layer (module) in the HLS model. One
8595 layer may correspond one or more dictionaries (e.g., layers with
@@ -114,8 +124,7 @@ def __call__(
114124 dict[str, Any] | tuple[dict[str, Any], ...]
115125 layer configuration(s) for the HLS model to be consumed by
116126 the ModelGraph constructor
117- """ # noqa: E501
118- import keras
127+ """
119128
120129 name = layer .name
121130 class_name = layer .__class__ .__name__
@@ -150,12 +159,23 @@ def __call__(
150159 ret = (config ,)
151160
152161 # If activation exists, append it
162+
163+ act_config , intermediate_tensor_name = self .maybe_get_activation_config (layer , out_tensors )
164+ if act_config is not None :
165+ ret [0 ]['output_keras_tensor_names' ] = [intermediate_tensor_name ]
166+ ret = * ret , act_config
167+
168+ return ret
169+
170+ def maybe_get_activation_config (self , layer , out_tensors ):
171+ import keras
172+
153173 activation = getattr (layer , 'activation' , None )
174+ name = layer .name
154175 if activation not in (keras .activations .linear , None ):
155176 assert len (out_tensors ) == 1 , f"Layer { name } has more than one output, but has an activation function"
156177 assert isinstance (activation , FunctionType ), f"Activation function for layer { name } is not a function"
157178 intermediate_tensor_name = f'{ out_tensors [0 ].name } _activation'
158- ret [0 ]['output_keras_tensor_names' ] = [intermediate_tensor_name ]
159179 act_cls_name = activation .__name__
160180 act_config = {
161181 'class_name' : 'Activation' ,
@@ -164,9 +184,8 @@ def __call__(
164184 'input_keras_tensor_names' : [intermediate_tensor_name ],
165185 'output_keras_tensor_names' : [out_tensors [0 ].name ],
166186 }
167- ret = * ret , act_config
168-
169- return ret
187+ return act_config , intermediate_tensor_name
188+ return None , None
170189
171190 def handle (
172191 self ,
@@ -175,3 +194,22 @@ def handle(
175194 out_tensors : Sequence ['KerasTensor' ],
176195 ) -> dict [str , Any ] | tuple [dict [str , Any ], ...]:
177196 return {}
197+
198+ def load_weight (self , layer : 'keras.Layer' , key : str ):
199+ """Load a weight from a layer.
200+
201+ Parameters
202+ ----------
203+ layer : keras.Layer
204+ The layer to load the weight from.
205+ key : str
206+ The key of the weight to load.
207+
208+ Returns
209+ -------
210+ np.ndarray
211+ The weight.
212+ """
213+ import keras
214+
215+ return keras .ops .convert_to_numpy (getattr (layer , key ))