11from pathlib import Path
22
3-
43import equinox as eqx
54import jax .numpy as jnp
65
7- from amici ._codegen .template import apply_template
86from amici import amiciModulePath
7+ from amici ._codegen .template import apply_template
98
109
1110class Flatten (eqx .Module ):
11+ """Custom implementation of a torch.flatten layer for Equinox."""
12+
1213 start_dim : int
1314 end_dim : int
1415
@@ -27,13 +28,33 @@ def __call__(self, x):
2728
2829
2930def tanhshrink (x : jnp .ndarray ) -> jnp .ndarray :
31+ """Custom implementation of the torch.nn.Tanhshrink activation function for JAX."""
3032 return x - jnp .tanh (x )
3133
3234
33- def generate_equinox (nn_model : "NNModel" , filename : Path | str , frozen_layers : dict = {}): # noqa: F821
35+ def generate_equinox (
36+ nn_model : "NNModel" , # noqa: F821
37+ filename : Path | str ,
38+ frozen_layers : dict [str , bool ] | None = None ,
39+ ) -> None :
40+ """
41+ Generate Equinox model file from petab_sciml neural network object.
42+
43+ :param nn_model:
44+ Neural network model in petab_sciml format
45+ :param filename:
46+ output filename for generated Equinox model
47+ :param frozen_layers:
48+ list of layer names to freeze during training
49+ :return:
50+
51+ """
3452 # TODO: move to top level import and replace forward type definitions
3553 from petab_sciml import Layer
3654
55+ if frozen_layers is None :
56+ frozen_layers = {}
57+
3758 filename = Path (filename )
3859 layer_indent = 12
3960 node_indent = 8
@@ -84,6 +105,9 @@ def generate_equinox(nn_model: "NNModel", filename: Path | str, frozen_layers: d
84105
85106
86107def _process_argval (v ):
108+ """
109+ Process argument value for layer instantiation string
110+ """
87111 if isinstance (v , str ):
88112 return f"'{ v } '"
89113 if isinstance (v , bool ):
@@ -92,11 +116,19 @@ def _process_argval(v):
92116
93117
94118def _generate_layer (layer : "Layer" , indent : int , ilayer : int ) -> str : # noqa: F821
95- layer_map = {
96- "Dropout1d" : "eqx.nn.Dropout" ,
97- "Dropout2d" : "eqx.nn.Dropout" ,
98- "Flatten" : "amici.jax.Flatten" ,
99- }
119+ """
120+ Generate layer definition string for a given layer
121+
122+ :param layer:
123+ petab_sciml Layer object
124+ :param indent:
125+ indentation level for generated string
126+ :param ilayer:
127+ layer index for key generation
128+
129+ :return:
130+ string defining the layer in equinox syntax
131+ """
100132 if layer .layer_type .startswith (
101133 ("BatchNorm" , "AlphaDropout" , "InstanceNorm" )
102134 ):
@@ -110,6 +142,14 @@ def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F
110142 if layer .layer_type == "Bilinear" :
111143 raise NotImplementedError ("Bilinear layers not supported" )
112144
145+ # mapping of layer names in sciml yaml format to equinox/custom amici implementations
146+ layer_map = {
147+ "Dropout1d" : "eqx.nn.Dropout" ,
148+ "Dropout2d" : "eqx.nn.Dropout" ,
149+ "Flatten" : "amici.jax.Flatten" ,
150+ }
151+
152+ # mapping of keyword argument names in sciml yaml format to equinox/custom amici implementations
113153 kwarg_map = {
114154 "Linear" : {
115155 "bias" : "use_bias" ,
@@ -125,10 +165,12 @@ def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F
125165 "normalized_shape" : "shape" ,
126166 },
127167 }
168+ # list of keyword arguments to ignore when generating layer, as they are not supported in equinox (see above)
128169 kwarg_ignore = {
129170 "Dropout1d" : ("inplace" ,),
130171 "Dropout2d" : ("inplace" ,),
131172 }
173+ # construct argument string for layer instantiation
132174 kwargs = [
133175 f"{ kwarg_map .get (layer .layer_type , {}).get (k , k )} ={ _process_argval (v )} "
134176 for k , v in layer .args .items ()
@@ -150,70 +192,178 @@ def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F
150192 return f"{ ' ' * indent } '{ layer .layer_id } ': { layer_str } "
151193
152194
153- def _generate_forward (node : "Node" , indent , frozen_layers : dict = {}, layer_type = str ) -> str : # noqa: F821
195+ def _format_function_call (
196+ var_name : str , fun_str : str , args : list , kwargs : list [str ], indent : int
197+ ) -> str :
198+ """
199+ Utility function to format a function call assignment string.
200+
201+ :param var_name:
202+ name of the variable to assign the result to
203+ :param fun_str:
204+ string representation of the function to call
205+ :param args:
206+ list of positional arguments
207+ :param kwargs:
208+ list of keyword arguments as strings
209+ :param indent:
210+ indentation level for generated string
211+
212+ :return:
213+ formatted string representing the function call assignment
214+ """
215+ args_str = ", " .join ([f"{ arg } " for arg in args ])
216+ kwargs_str = ", " .join (kwargs )
217+ all_args = ", " .join (filter (None , [args_str , kwargs_str ]))
218+ return f"{ ' ' * indent } { var_name } = { fun_str } ({ all_args } )"
219+
220+
221+ def _process_layer_call (
222+ node : "Node" , # noqa: F821
223+ layer_type : str ,
224+ frozen_layers : dict [str , bool ],
225+ ) -> tuple [str , str ]:
226+ """
227+ Process a layer (call_module) node and return function string and optional tree string.
228+
229+ :param node:
230+ petab sciml Node object representing a layer call
231+ :param layer_type:
232+ petab sciml layer type of the node
233+ :param frozen_layers:
234+ dict of layer names to boolean indicating whether layer is frozen
235+
236+ :return:
237+ tuple of (function_string, tree_string) where tree_string is empty if no tree is needed
238+ """
239+ fun_str = f"self.layers['{ node .target } ']"
240+ tree_string = ""
241+
242+ # Handle frozen layers
243+ if node .name in frozen_layers :
244+ if frozen_layers [node .name ]:
245+ arr_attr = frozen_layers [node .name ]
246+ get_lambda = f"lambda layer: getattr(layer, '{ arr_attr } ')"
247+ replacer = "replace_fn = lambda arr: jax.lax.stop_gradient(arr)"
248+ tree_string = f"tree_{ node .name } = eqx.tree_at({ get_lambda } , { fun_str } , { replacer } )"
249+ fun_str = f"tree_{ node .name } "
250+ else :
251+ fun_str = f"jax.lax.stop_gradient({ fun_str } )"
252+
253+ # Handle vmap for certain layer types
254+ if layer_type .startswith (("Conv" , "Linear" , "LayerNorm" )):
255+ if layer_type in ("LayerNorm" ,):
256+ dims = f"len({ fun_str } .shape)+1"
257+ elif layer_type == "Linear" :
258+ dims = 2
259+ elif layer_type .endswith ("1d" ):
260+ dims = 3
261+ elif layer_type .endswith ("2d" ):
262+ dims = 4
263+ elif layer_type .endswith ("3d" ):
264+ dims = 5
265+ fun_str = f"(jax.vmap({ fun_str } ) if len({ node .args [0 ]} .shape) == { dims } else { fun_str } )"
266+
267+ return fun_str , tree_string
268+
269+
270+ def _process_activation_call (node : "Node" ) -> str : # noqa: F821
271+ """
272+ Process an activation function (call_function/call_method) node and return function string.
273+
274+ :param node:
275+ petab sciml Node object representing an activation function call
276+
277+ :return:
278+ string representation of the activation function
279+ """
280+ # Mapping of function names in sciml yaml format to equinox/custom amici implementations
281+ activation_map = {
282+ "hardtanh" : "jax.nn.hard_tanh" ,
283+ "hardsigmoid" : "jax.nn.hard_sigmoid" ,
284+ "hardswish" : "jax.nn.hard_swish" ,
285+ "tanhshrink" : "amici.jax.tanhshrink" ,
286+ "softsign" : "jax.nn.soft_sign" ,
287+ }
288+
289+ # Validate hardtanh parameters
290+ if node .target == "hardtanh" :
291+ if node .kwargs .pop ("min_val" , - 1.0 ) != - 1.0 :
292+ raise NotImplementedError (
293+ "min_val != -1.0 not supported for hardtanh"
294+ )
295+ if node .kwargs .pop ("max_val" , 1.0 ) != 1.0 :
296+ raise NotImplementedError (
297+ "max_val != 1.0 not supported for hardtanh"
298+ )
299+
300+ return activation_map .get (node .target , f"jax.nn.{ node .target } " )
301+
302+
303+ def _generate_forward (
304+ node : "Node" , # noqa: F821
305+ indent ,
306+ frozen_layers : dict [str , bool ] | None = None ,
307+ layer_type : str = "" ,
308+ ) -> str :
309+ """
310+ Generate forward pass line for a given node
311+
312+ :param node:
313+ petab sciml Node object representing a step in the forward pass
314+ :param indent:
315+ indentation level for generated string
316+ :param frozen_layers:
317+ dict of layer names to boolean indicating whether layer is frozen
318+ :param layer_type:
319+ petab sciml layer type of the node (only relevant for call_module nodes)
320+
321+ :return:
322+ string defining the forward pass implementation for the given node in equinox syntax
323+ """
324+ if frozen_layers is None :
325+ frozen_layers = {}
326+
327+ # Handle placeholder nodes
154328 if node .op == "placeholder" :
155329 # TODO: inconsistent target vs name
156330 return f"{ ' ' * indent } { node .name } = input"
157331
332+ # Handle output nodes
333+ if node .op == "output" :
334+ args_str = ", " .join ([f"{ arg } " for arg in node .args ])
335+ return f"{ ' ' * indent } { node .target } = { args_str } "
336+
337+ # Process layer calls
338+ tree_string = ""
158339 if node .op == "call_module" :
159- fun_str = f"self.layers['{ node .target } ']"
160- if node .name in frozen_layers :
161- if frozen_layers [node .name ]:
162- arr_attr = frozen_layers [node .name ]
163- get_lambda = f"lambda layer: getattr(layer, '{ arr_attr } ')"
164- replacer = (
165- "replace_fn = lambda arr: jax.lax.stop_gradient(arr)"
166- )
167- tree_string = f"tree_{ node .name } = eqx.tree_at({ get_lambda } , { fun_str } , { replacer } )"
168- fun_str = f"tree_{ node .name } "
169- else :
170- fun_str = f"jax.lax.stop_gradient({ fun_str } )"
171- tree_string = ""
172- if layer_type .startswith (("Conv" , "Linear" , "LayerNorm" )):
173- if layer_type in ("LayerNorm" ,):
174- dims = f"len({ fun_str } .shape)+1"
175- if layer_type == "Linear" :
176- dims = 2
177- if layer_type .endswith (("1d" ,)):
178- dims = 3
179- elif layer_type .endswith (("2d" ,)):
180- dims = 4
181- elif layer_type .endswith ("3d" ):
182- dims = 5
183- fun_str = f"(jax.vmap({ fun_str } ) if len({ node .args [0 ]} .shape) == { dims } else { fun_str } )"
340+ fun_str , tree_string = _process_layer_call (
341+ node , layer_type , frozen_layers
342+ )
184343
344+ # Process activation function calls
185345 if node .op in ("call_function" , "call_method" ):
186- map_fun = {
187- "hardtanh" : "jax.nn.hard_tanh" ,
188- "hardsigmoid" : "jax.nn.hard_sigmoid" ,
189- "hardswish" : "jax.nn.hard_swish" ,
190- "tanhshrink" : "amici.jax.tanhshrink" ,
191- "softsign" : "jax.nn.soft_sign" ,
192- }
193- if node .target == "hardtanh" :
194- if node .kwargs .pop ("min_val" , - 1.0 ) != - 1.0 :
195- raise NotImplementedError (
196- "min_val != -1.0 not supported for hardtanh"
197- )
198- if node .kwargs .pop ("max_val" , 1.0 ) != 1.0 :
199- raise NotImplementedError (
200- "max_val != 1.0 not supported for hardtanh"
201- )
202- fun_str = map_fun .get (node .target , f"jax.nn.{ node .target } " )
346+ fun_str = _process_activation_call (node )
203347
204- args = ", " . join ([ f" { arg } " for arg in node . args ])
348+ # Build kwargs list, filtering out unsupported arguments
205349 kwargs = [
206350 f"{ k } ={ item } "
207351 for k , item in node .kwargs .items ()
208352 if k not in ("inplace" ,)
209353 ]
210- if layer_type .startswith (("Dropout" ,)):
354+
355+ # Add key parameter for Dropout layers
356+ if layer_type .startswith ("Dropout" ):
211357 kwargs += ["key=key" ]
212- kwargs_str = ", " .join (kwargs )
358+
359+ # Format the function call
213360 if node .op in ("call_module" , "call_function" , "call_method" ):
214- if node .name in frozen_layers :
215- return f"{ ' ' * indent } { tree_string } \n { ' ' * indent } { node .name } = { fun_str } ({ args + ', ' + kwargs_str } )"
216- else :
217- return f"{ ' ' * indent } { node .name } = { fun_str } ({ args + ', ' + kwargs_str } )"
218- if node .op == "output" :
219- return f"{ ' ' * indent } { node .target } = { args } "
361+ result = _format_function_call (
362+ node .name , fun_str , node .args , kwargs , indent
363+ )
364+ # Prepend tree_string if needed for frozen layers
365+ if tree_string :
366+ return f"{ ' ' * indent } { tree_string } \n { result } "
367+ return result
368+
369+ raise NotImplementedError (f"Operation { node .op } not supported" )
0 commit comments