nn.py
from pathlib import Path
import equinox as eqx
import jax.numpy as jnp
from amici import amiciModulePath
from ..exporters.template import apply_template
class Flatten(eqx.Module):
"""Custom implementation of a `torch.flatten` layer for Equinox."""
start_dim: int
end_dim: int
def __init__(self, start_dim: int, end_dim: int):
super().__init__()
self.start_dim = start_dim
self.end_dim = end_dim
def __call__(self, x):
if self.end_dim == -1:
return jnp.reshape(x, x.shape[: self.start_dim] + (-1,))
else:
return jnp.reshape(
x, x.shape[: self.start_dim] + (-1,) + x.shape[self.end_dim :]
)
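# Illustrative usage sketch (not part of the original module): Flatten(1, -1)
# collapses everything after the first dimension, mirroring torch.flatten(x, 1).
#
#     >>> Flatten(start_dim=1, end_dim=-1)(jnp.zeros((2, 3, 4))).shape
#     (2, 12)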
def tanhshrink(x: jnp.ndarray) -> jnp.ndarray:
"""Custom implementation of the torch.nn.Tanhshrink activation function for JAX."""
return x - jnp.tanh(x)
def cat(tensors, axis: int = 0):
"""Alias for torch.cat using JAX's concatenate/stack function.
Handles both regular arrays and zero-dimensional (scalar) arrays by
using stack instead of concatenate for 0D arrays.
:param tensors:
List of arrays to concatenate
:param axis:
Dimension along which to concatenate (default: 0)
:return:
Concatenated array
"""
# Check if all tensors are 0-dimensional (scalars)
if all(jnp.ndim(t) == 0 for t in tensors):
# For 0D arrays, use stack instead of concatenate
return jnp.stack(tensors, axis=axis)
return jnp.concatenate(tensors, axis=axis)
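# Illustrative usage sketch (not part of the original module): zero-dimensional
# inputs are stacked, arrays with at least one dimension are concatenated.
#
#     >>> cat([jnp.array(1.0), jnp.array(2.0)]).shape
#     (2,)
#     >>> cat([jnp.ones((2, 3)), jnp.ones((2, 3))], axis=0).shape
#     (4, 3)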
def generate_equinox(
nn_model: "NNModel", # noqa: F821
filename: Path | str,
frozen_layers: dict[str, bool] | None = None,
) -> None:
"""
Generate Equinox model file from petab_sciml neural network object.
:param nn_model:
Neural network model in petab_sciml format
:param filename:
output filename for generated Equinox model
:param frozen_layers:
        dict of layer names to freeze during training
"""
# TODO: move to top level import and replace forward type definitions
from petab_sciml import Layer
if frozen_layers is None:
frozen_layers = {}
filename = Path(filename)
layer_indent = 12
node_indent = 8
layers = {layer.layer_id: layer for layer in nn_model.layers}
# Collect placeholder nodes to determine input handling
placeholder_nodes = [
node for node in nn_model.forward if node.op == "placeholder"
]
input_names = [node.name for node in placeholder_nodes]
# Generate input unpacking line
if len(input_names) == 1:
input_unpack = f"{input_names[0]} = input"
else:
input_unpack = f"{', '.join(input_names)} = input"
# Generate forward pass lines (excluding placeholder nodes)
forward_lines = [
_generate_forward(
node,
node_indent,
frozen_layers,
layers.get(
node.target,
Layer(layer_id="dummy", layer_type="Linear"),
).layer_type,
)
for node in nn_model.forward
]
# Filter out empty lines from placeholder processing
forward_lines = [line for line in forward_lines if line]
# Prepend input unpacking
forward_code = f"{' ' * node_indent}{input_unpack}\n" + "\n".join(
forward_lines
)
tpl_data = {
"MODEL_ID": nn_model.nn_model_id,
"LAYERS": ",\n".join(
[
_generate_layer(layer, layer_indent, ilayer)
for ilayer, layer in enumerate(nn_model.layers)
]
)[layer_indent:],
"FORWARD": forward_code[node_indent:],
"INPUT": ", ".join([f"'{inp.input_id}'" for inp in nn_model.inputs]),
"OUTPUT": ", ".join(
[
f"'{arg}'"
for arg in next(
node for node in nn_model.forward if node.op == "output"
).args
]
),
"N_LAYERS": len(nn_model.layers),
}
filename.parent.mkdir(parents=True, exist_ok=True)
apply_template(
Path(amiciModulePath) / "jax" / "nn.template.py",
filename,
tpl_data,
)
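# Sketch of intended usage (hypothetical path and identifiers; assumes `nn_model`
# is an already-constructed petab_sciml NNModel instance):
#
#     generate_equinox(nn_model, filename="generated/my_net.py")
#
# This renders amici's jax/nn.template.py with the layer definitions and forward
# pass derived from `nn_model` and writes the result to the given file.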
def _process_argval(v):
"""
Process argument value for layer instantiation string
"""
if isinstance(v, str):
return f"'{v}'"
if isinstance(v, bool):
return str(v)
return str(v)
def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F821
"""
Generate layer definition string for a given layer
:param layer:
petab_sciml Layer object
:param indent:
indentation level for generated string
:param ilayer:
layer index for key generation
:return:
string defining the layer in equinox syntax
"""
if layer.layer_type.startswith(
("BatchNorm", "AlphaDropout", "InstanceNorm")
):
raise NotImplementedError(
f"{layer.layer_type} layers currently not supported"
)
if layer.layer_type.startswith("MaxPool") and "dilation" in layer.args:
raise NotImplementedError("MaxPool layers with dilation not supported")
if layer.layer_type.startswith("Dropout") and "inplace" in layer.args:
raise NotImplementedError("Dropout layers with inplace not supported")
if layer.layer_type == "Bilinear":
raise NotImplementedError("Bilinear layers not supported")
# mapping of layer names in sciml yaml format to equinox/custom amici implementations
layer_map = {
"Dropout1d": "eqx.nn.Dropout",
"Dropout2d": "eqx.nn.Dropout",
"Flatten": "amici.jax.Flatten",
}
# mapping of keyword argument names in sciml yaml format to equinox/custom amici implementations
kwarg_map = {
"Linear": {
"bias": "use_bias",
},
"Conv1d": {
"bias": "use_bias",
},
"Conv2d": {
"bias": "use_bias",
},
"LayerNorm": {
"elementwise_affine": "use_bias", # Deprecation warning - replace LayerNorm(elementwise_affine) with LayerNorm(use_bias)
"normalized_shape": "shape",
},
}
# list of keyword arguments to ignore when generating layer, as they are not supported in equinox (see above)
kwarg_ignore = {
"Dropout1d": ("inplace",),
"Dropout2d": ("inplace",),
}
# construct argument string for layer instantiation
kwargs = [
f"{kwarg_map.get(layer.layer_type, {}).get(k, k)}={_process_argval(v)}"
for k, v in layer.args.items()
if k not in kwarg_ignore.get(layer.layer_type, ())
]
# add key for initialization
if layer.layer_type in (
"Linear",
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
):
kwargs += [f"key=keys[{ilayer}]"]
type_str = layer_map.get(layer.layer_type, f"eqx.nn.{layer.layer_type}")
layer_str = f"{type_str}({', '.join(kwargs)})"
return f"{' ' * indent}'{layer.layer_id}': {layer_str}"
def _format_function_call(
var_name: str, fun_str: str, args: list, kwargs: list[str], indent: int
) -> str:
"""
Utility function to format a function call assignment string.
:param var_name:
name of the variable to assign the result to
:param fun_str:
string representation of the function to call
:param args:
list of positional arguments
:param kwargs:
list of keyword arguments as strings
:param indent:
indentation level for generated string
:return:
formatted string representing the function call assignment
"""
args_str = ", ".join([f"{arg}" for arg in args])
kwargs_str = ", ".join(kwargs)
all_args = ", ".join(filter(None, [args_str, kwargs_str]))
return f"{' ' * indent}{var_name} = {fun_str}({all_args})"
def _process_layer_call(
node: "Node", # noqa: F821
layer_type: str,
frozen_layers: dict[str, bool],
) -> tuple[str, str]:
"""
Process a layer (call_module) node and return function string and optional tree string.
:param node:
petab sciml Node object representing a layer call
:param layer_type:
petab sciml layer type of the node
:param frozen_layers:
dict of layer names to boolean indicating whether layer is frozen
:return:
tuple of (function_string, tree_string) where tree_string is empty if no tree is needed
"""
fun_str = f"self.layers['{node.target}']"
tree_string = ""
# Handle frozen layers
if node.name in frozen_layers:
if frozen_layers[node.name]:
arr_attr = frozen_layers[node.name]
get_lambda = f"lambda layer: getattr(layer, '{arr_attr}')"
replacer = "replace_fn = lambda arr: jax.lax.stop_gradient(arr)"
tree_string = f"tree_{node.name} = eqx.tree_at({get_lambda}, {fun_str}, {replacer})"
fun_str = f"tree_{node.name}"
else:
fun_str = f"jax.lax.stop_gradient({fun_str})"
# Handle vmap for certain layer types
if layer_type.startswith(("Conv", "Linear", "LayerNorm")):
if layer_type in ("LayerNorm",):
dims = f"len({fun_str}.shape)+1"
elif layer_type == "Linear":
dims = 2
elif layer_type.endswith("1d"):
dims = 3
elif layer_type.endswith("2d"):
dims = 4
elif layer_type.endswith("3d"):
dims = 5
fun_str = f"(jax.vmap({fun_str}) if len({node.args[0]}.shape) == {dims} else {fun_str})"
return fun_str, tree_string
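# Illustrative example (not part of the original module): for an unfrozen
# call_module node targeting layer 'conv1' (layer_type 'Conv2d') with first
# argument 'x', the returned function string dispatches to jax.vmap whenever the
# input carries an extra (batch) dimension:
#
#     (jax.vmap(self.layers['conv1']) if len(x.shape) == 4 else self.layers['conv1'])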
def _process_activation_call(node: "Node") -> str: # noqa: F821
"""
Process an activation function (call_function/call_method) node and return function string.
:param node:
petab sciml Node object representing an activation function call
:return:
string representation of the activation function
"""
# Mapping of function names in sciml yaml format to equinox/custom amici implementations
activation_map = {
"hardtanh": "jax.nn.hard_tanh",
"hardsigmoid": "jax.nn.hard_sigmoid",
"hardswish": "jax.nn.hard_swish",
"tanhshrink": "amici.jax.tanhshrink",
"softsign": "jax.nn.soft_sign",
"cat": "amici.jax.cat",
}
# Validate hardtanh parameters
if node.target == "hardtanh":
if node.kwargs.pop("min_val", -1.0) != -1.0:
raise NotImplementedError(
"min_val != -1.0 not supported for hardtanh"
)
if node.kwargs.pop("max_val", 1.0) != 1.0:
raise NotImplementedError(
"max_val != 1.0 not supported for hardtanh"
)
# Handle kwarg aliasing for cat (dim -> axis)
if node.target == "cat":
if "dim" in node.kwargs:
node.kwargs["axis"] = node.kwargs.pop("dim")
# Convert list of variable names to proper bracket-enclosed list
if isinstance(node.args[0], list):
# node.args[0] is a list like ['net_input1', 'net_input2']
# We need to convert it to a single string representing the list: [net_input1, net_input2]
node.args = tuple(
["[" + ", ".join(node.args[0]) + "]"] + list(node.args[1:])
)
return activation_map.get(node.target, f"jax.nn.{node.target}")
def _generate_forward(
node: "Node", # noqa: F821
indent,
frozen_layers: dict[str, bool] | None = None,
layer_type: str = "",
) -> str:
"""
Generate forward pass line for a given node
:param node:
petab sciml Node object representing a step in the forward pass
:param indent:
indentation level for generated string
:param frozen_layers:
dict of layer names to boolean indicating whether layer is frozen
:param layer_type:
petab sciml layer type of the node (only relevant for call_module nodes)
:return:
string defining the forward pass implementation for the given node in equinox syntax
"""
if frozen_layers is None:
frozen_layers = {}
# Handle placeholder nodes - skip individual processing, handled collectively in generate_equinox
if node.op == "placeholder":
return ""
# Handle output nodes
if node.op == "output":
args_str = ", ".join([f"{arg}" for arg in node.args])
return f"{' ' * indent}{node.target} = {args_str}"
# Process layer calls
tree_string = ""
if node.op == "call_module":
fun_str, tree_string = _process_layer_call(
node, layer_type, frozen_layers
)
# Process activation function calls
if node.op in ("call_function", "call_method"):
fun_str = _process_activation_call(node)
# Build kwargs list, filtering out unsupported arguments
kwargs = [
f"{k}={item}"
for k, item in node.kwargs.items()
if k not in ("inplace",)
]
# Add key parameter for Dropout layers
if layer_type.startswith("Dropout"):
kwargs += ["key=key"]
# Format the function call
if node.op in ("call_module", "call_function", "call_method"):
result = _format_function_call(
node.name, fun_str, node.args, kwargs, indent
)
# Prepend tree_string if needed for frozen layers
if tree_string:
return f"{' ' * indent}{tree_string}\n{result}"
return result
raise NotImplementedError(f"Operation {node.op} not supported")
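# Illustrative example (not part of the original module): a call_module node named
# 'h1' targeting the Linear layer 'fc1' with args ('x',) would be rendered (at
# indent 8) as
#
#     h1 = (jax.vmap(self.layers['fc1']) if len(x.shape) == 2 else self.layers['fc1'])(x)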