Skip to content

Commit 2b76193

Browse files
committed
refactor and add documentation to nn code
1 parent 090dfa1 commit 2b76193

File tree

1 file changed

+210
-60
lines changed
  • python/sdist/amici/jax

1 file changed

+210
-60
lines changed

python/sdist/amici/jax/nn.py

Lines changed: 210 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from pathlib import Path
22

3-
43
import equinox as eqx
54
import jax.numpy as jnp
65

7-
from amici._codegen.template import apply_template
86
from amici import amiciModulePath
7+
from amici._codegen.template import apply_template
98

109

1110
class Flatten(eqx.Module):
11+
"""Custom implementation of a torch.flatten layer for Equinox."""
12+
1213
start_dim: int
1314
end_dim: int
1415

@@ -27,13 +28,33 @@ def __call__(self, x):
2728

2829

2930
def tanhshrink(x: jnp.ndarray) -> jnp.ndarray:
31+
"""Custom implementation of the torch.nn.Tanhshrink activation function for JAX."""
3032
return x - jnp.tanh(x)
3133

3234

33-
def generate_equinox(nn_model: "NNModel", filename: Path | str, frozen_layers: dict = {}): # noqa: F821
35+
def generate_equinox(
36+
nn_model: "NNModel", # noqa: F821
37+
filename: Path | str,
38+
frozen_layers: dict[str, bool] | None = None,
39+
) -> None:
40+
"""
41+
Generate Equinox model file from petab_sciml neural network object.
42+
43+
:param nn_model:
44+
Neural network model in petab_sciml format
45+
:param filename:
46+
output filename for generated Equinox model
47+
:param frozen_layers:
48+
list of layer names to freeze during training
49+
:return:
50+
51+
"""
3452
# TODO: move to top level import and replace forward type definitions
3553
from petab_sciml import Layer
3654

55+
if frozen_layers is None:
56+
frozen_layers = {}
57+
3758
filename = Path(filename)
3859
layer_indent = 12
3960
node_indent = 8
@@ -84,6 +105,9 @@ def generate_equinox(nn_model: "NNModel", filename: Path | str, frozen_layers: d
84105

85106

86107
def _process_argval(v):
108+
"""
109+
Process argument value for layer instantiation string
110+
"""
87111
if isinstance(v, str):
88112
return f"'{v}'"
89113
if isinstance(v, bool):
@@ -92,11 +116,19 @@ def _process_argval(v):
92116

93117

94118
def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F821
95-
layer_map = {
96-
"Dropout1d": "eqx.nn.Dropout",
97-
"Dropout2d": "eqx.nn.Dropout",
98-
"Flatten": "amici.jax.Flatten",
99-
}
119+
"""
120+
Generate layer definition string for a given layer
121+
122+
:param layer:
123+
petab_sciml Layer object
124+
:param indent:
125+
indentation level for generated string
126+
:param ilayer:
127+
layer index for key generation
128+
129+
:return:
130+
string defining the layer in equinox syntax
131+
"""
100132
if layer.layer_type.startswith(
101133
("BatchNorm", "AlphaDropout", "InstanceNorm")
102134
):
@@ -110,6 +142,14 @@ def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F
110142
if layer.layer_type == "Bilinear":
111143
raise NotImplementedError("Bilinear layers not supported")
112144

145+
# mapping of layer names in sciml yaml format to equinox/custom amici implementations
146+
layer_map = {
147+
"Dropout1d": "eqx.nn.Dropout",
148+
"Dropout2d": "eqx.nn.Dropout",
149+
"Flatten": "amici.jax.Flatten",
150+
}
151+
152+
# mapping of keyword argument names in sciml yaml format to equinox/custom amici implementations
113153
kwarg_map = {
114154
"Linear": {
115155
"bias": "use_bias",
@@ -125,10 +165,12 @@ def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F
125165
"normalized_shape": "shape",
126166
},
127167
}
168+
# list of keyword arguments to ignore when generating layer, as they are not supported in equinox (see above)
128169
kwarg_ignore = {
129170
"Dropout1d": ("inplace",),
130171
"Dropout2d": ("inplace",),
131172
}
173+
# construct argument string for layer instantiation
132174
kwargs = [
133175
f"{kwarg_map.get(layer.layer_type, {}).get(k, k)}={_process_argval(v)}"
134176
for k, v in layer.args.items()
@@ -150,70 +192,178 @@ def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F
150192
return f"{' ' * indent}'{layer.layer_id}': {layer_str}"
151193

152194

153-
def _generate_forward(node: "Node", indent, frozen_layers: dict = {}, layer_type=str) -> str: # noqa: F821
195+
def _format_function_call(
196+
var_name: str, fun_str: str, args: list, kwargs: list[str], indent: int
197+
) -> str:
198+
"""
199+
Utility function to format a function call assignment string.
200+
201+
:param var_name:
202+
name of the variable to assign the result to
203+
:param fun_str:
204+
string representation of the function to call
205+
:param args:
206+
list of positional arguments
207+
:param kwargs:
208+
list of keyword arguments as strings
209+
:param indent:
210+
indentation level for generated string
211+
212+
:return:
213+
formatted string representing the function call assignment
214+
"""
215+
args_str = ", ".join([f"{arg}" for arg in args])
216+
kwargs_str = ", ".join(kwargs)
217+
all_args = ", ".join(filter(None, [args_str, kwargs_str]))
218+
return f"{' ' * indent}{var_name} = {fun_str}({all_args})"
219+
220+
221+
def _process_layer_call(
222+
node: "Node", # noqa: F821
223+
layer_type: str,
224+
frozen_layers: dict[str, bool],
225+
) -> tuple[str, str]:
226+
"""
227+
Process a layer (call_module) node and return function string and optional tree string.
228+
229+
:param node:
230+
petab sciml Node object representing a layer call
231+
:param layer_type:
232+
petab sciml layer type of the node
233+
:param frozen_layers:
234+
dict mapping layer names to a freezing spec: a falsy value wraps the whole layer in ``jax.lax.stop_gradient``; a truthy value is used as the attribute name whose array is frozen via ``eqx.tree_at``
235+
236+
:return:
237+
tuple of (function_string, tree_string) where tree_string is empty if no tree is needed
238+
"""
239+
fun_str = f"self.layers['{node.target}']"
240+
tree_string = ""
241+
242+
# Handle frozen layers
243+
if node.name in frozen_layers:
244+
if frozen_layers[node.name]:
245+
arr_attr = frozen_layers[node.name]
246+
get_lambda = f"lambda layer: getattr(layer, '{arr_attr}')"
247+
replacer = "replace_fn = lambda arr: jax.lax.stop_gradient(arr)"
248+
tree_string = f"tree_{node.name} = eqx.tree_at({get_lambda}, {fun_str}, {replacer})"
249+
fun_str = f"tree_{node.name}"
250+
else:
251+
fun_str = f"jax.lax.stop_gradient({fun_str})"
252+
253+
# Handle vmap for certain layer types
254+
if layer_type.startswith(("Conv", "Linear", "LayerNorm")):
255+
if layer_type in ("LayerNorm",):
256+
dims = f"len({fun_str}.shape)+1"
257+
elif layer_type == "Linear":
258+
dims = 2
259+
elif layer_type.endswith("1d"):
260+
dims = 3
261+
elif layer_type.endswith("2d"):
262+
dims = 4
263+
elif layer_type.endswith("3d"):
264+
dims = 5
265+
fun_str = f"(jax.vmap({fun_str}) if len({node.args[0]}.shape) == {dims} else {fun_str})"
266+
267+
return fun_str, tree_string
268+
269+
270+
def _process_activation_call(node: "Node") -> str: # noqa: F821
271+
"""
272+
Process an activation function (call_function/call_method) node and return function string.
273+
274+
:param node:
275+
petab sciml Node object representing an activation function call
276+
277+
:return:
278+
string representation of the activation function
279+
"""
280+
# Mapping of function names in sciml yaml format to equinox/custom amici implementations
281+
activation_map = {
282+
"hardtanh": "jax.nn.hard_tanh",
283+
"hardsigmoid": "jax.nn.hard_sigmoid",
284+
"hardswish": "jax.nn.hard_swish",
285+
"tanhshrink": "amici.jax.tanhshrink",
286+
"softsign": "jax.nn.soft_sign",
287+
}
288+
289+
# Validate hardtanh parameters
290+
if node.target == "hardtanh":
291+
if node.kwargs.pop("min_val", -1.0) != -1.0:
292+
raise NotImplementedError(
293+
"min_val != -1.0 not supported for hardtanh"
294+
)
295+
if node.kwargs.pop("max_val", 1.0) != 1.0:
296+
raise NotImplementedError(
297+
"max_val != 1.0 not supported for hardtanh"
298+
)
299+
300+
return activation_map.get(node.target, f"jax.nn.{node.target}")
301+
302+
303+
def _generate_forward(
304+
node: "Node", # noqa: F821
305+
indent,
306+
frozen_layers: dict[str, bool] | None = None,
307+
layer_type: str = "",
308+
) -> str:
309+
"""
310+
Generate forward pass line for a given node
311+
312+
:param node:
313+
petab sciml Node object representing a step in the forward pass
314+
:param indent:
315+
indentation level for generated string
316+
:param frozen_layers:
317+
dict mapping layer names to a freezing spec (falsy: freeze whole layer; truthy: attribute name of the array to freeze) — see ``_process_layer_call``
318+
:param layer_type:
319+
petab sciml layer type of the node (only relevant for call_module nodes)
320+
321+
:return:
322+
string defining the forward pass implementation for the given node in equinox syntax
323+
"""
324+
if frozen_layers is None:
325+
frozen_layers = {}
326+
327+
# Handle placeholder nodes
154328
if node.op == "placeholder":
155329
# TODO: inconsistent target vs name
156330
return f"{' ' * indent}{node.name} = input"
157331

332+
# Handle output nodes
333+
if node.op == "output":
334+
args_str = ", ".join([f"{arg}" for arg in node.args])
335+
return f"{' ' * indent}{node.target} = {args_str}"
336+
337+
# Process layer calls
338+
tree_string = ""
158339
if node.op == "call_module":
159-
fun_str = f"self.layers['{node.target}']"
160-
if node.name in frozen_layers:
161-
if frozen_layers[node.name]:
162-
arr_attr = frozen_layers[node.name]
163-
get_lambda = f"lambda layer: getattr(layer, '{arr_attr}')"
164-
replacer = (
165-
"replace_fn = lambda arr: jax.lax.stop_gradient(arr)"
166-
)
167-
tree_string = f"tree_{node.name} = eqx.tree_at({get_lambda}, {fun_str}, {replacer})"
168-
fun_str = f"tree_{node.name}"
169-
else:
170-
fun_str = f"jax.lax.stop_gradient({fun_str})"
171-
tree_string = ""
172-
if layer_type.startswith(("Conv", "Linear", "LayerNorm")):
173-
if layer_type in ("LayerNorm",):
174-
dims = f"len({fun_str}.shape)+1"
175-
if layer_type == "Linear":
176-
dims = 2
177-
if layer_type.endswith(("1d",)):
178-
dims = 3
179-
elif layer_type.endswith(("2d",)):
180-
dims = 4
181-
elif layer_type.endswith("3d"):
182-
dims = 5
183-
fun_str = f"(jax.vmap({fun_str}) if len({node.args[0]}.shape) == {dims} else {fun_str})"
340+
fun_str, tree_string = _process_layer_call(
341+
node, layer_type, frozen_layers
342+
)
184343

344+
# Process activation function calls
185345
if node.op in ("call_function", "call_method"):
186-
map_fun = {
187-
"hardtanh": "jax.nn.hard_tanh",
188-
"hardsigmoid": "jax.nn.hard_sigmoid",
189-
"hardswish": "jax.nn.hard_swish",
190-
"tanhshrink": "amici.jax.tanhshrink",
191-
"softsign": "jax.nn.soft_sign",
192-
}
193-
if node.target == "hardtanh":
194-
if node.kwargs.pop("min_val", -1.0) != -1.0:
195-
raise NotImplementedError(
196-
"min_val != -1.0 not supported for hardtanh"
197-
)
198-
if node.kwargs.pop("max_val", 1.0) != 1.0:
199-
raise NotImplementedError(
200-
"max_val != 1.0 not supported for hardtanh"
201-
)
202-
fun_str = map_fun.get(node.target, f"jax.nn.{node.target}")
346+
fun_str = _process_activation_call(node)
203347

204-
args = ", ".join([f"{arg}" for arg in node.args])
348+
# Build kwargs list, filtering out unsupported arguments
205349
kwargs = [
206350
f"{k}={item}"
207351
for k, item in node.kwargs.items()
208352
if k not in ("inplace",)
209353
]
210-
if layer_type.startswith(("Dropout",)):
354+
355+
# Add key parameter for Dropout layers
356+
if layer_type.startswith("Dropout"):
211357
kwargs += ["key=key"]
212-
kwargs_str = ", ".join(kwargs)
358+
359+
# Format the function call
213360
if node.op in ("call_module", "call_function", "call_method"):
214-
if node.name in frozen_layers:
215-
return f"{' ' * indent}{tree_string}\n{' ' * indent}{node.name} = {fun_str}({args + ', ' + kwargs_str})"
216-
else:
217-
return f"{' ' * indent}{node.name} = {fun_str}({args + ', ' + kwargs_str})"
218-
if node.op == "output":
219-
return f"{' ' * indent}{node.target} = {args}"
361+
result = _format_function_call(
362+
node.name, fun_str, node.args, kwargs, indent
363+
)
364+
# Prepend tree_string if needed for frozen layers
365+
if tree_string:
366+
return f"{' ' * indent}{tree_string}\n{result}"
367+
return result
368+
369+
raise NotImplementedError(f"Operation {node.op} not supported")

0 commit comments

Comments
 (0)