|
1 | 1 | """ |
2 | | - This module defines custom base layers to be used by the n3fit |
3 | | - Neural Network. |
4 | | - These layers can use the keras standard set of activation function |
5 | | - or implement their own. |
| 2 | +This module defines custom base layers to be used by the n3fit |
| 3 | +Neural Network. |
| 4 | +These layers can use the keras standard set of activation function |
| 5 | +or implement their own. |
6 | 6 |
|
7 | | - For a layer to be used by n3fit it should be contained in the `layers` dictionary defined below. |
8 | | - This dictionary has the following structure: |
| 7 | +For a layer to be used by n3fit it should be contained in the `layers` dictionary defined below. |
| 8 | +This dictionary has the following structure: |
9 | 9 |
|
10 | | - 'name of the layer' : ( Layer_class, {dictionary of arguments: defaults} ) |
| 10 | + 'name of the layer' : ( Layer_class, {dictionary of arguments: defaults} ) |
11 | 11 |
|
12 | | - In order to add custom activation functions, they must be added to |
13 | | - the `custom_activations` dictionary with the following structure: |
| 12 | +In order to add custom activation functions, they must be added to |
| 13 | +the `custom_activations` dictionary with the following structure: |
14 | 14 |
|
15 | | - 'name of the activation' : function |
| 15 | + 'name of the activation' : function |
16 | 16 |
|
17 | | - The names of the layer and the activation function are the ones to be used in the n3fit runcard. |
| 17 | +The names of the layer and the activation function are the ones to be used in the n3fit runcard. |
18 | 18 | """ |
19 | 19 |
|
20 | 20 | from keras.layers import Dense as KerasDense |
|
26 | 26 | from . import operations as ops |
27 | 27 | from .MetaLayer import MetaLayer |
28 | 28 |
|
| 29 | + |
29 | 30 | # Custom activation functions |
30 | 31 | def square_activation(x): |
31 | 32 | """Squares the input""" |
@@ -75,7 +76,12 @@ def ReshapedLSTM(input_tensor): |
75 | 76 |
|
76 | 77 |
|
77 | 78 | class Dense(KerasDense, MetaLayer): |
78 | | - pass |
| 79 | + |
| 80 | + def __init__(self, *args, **kwargs): |
| 81 | + # In Keras == 3.13, np.int() is not accepted by Dense |
| 82 | + if "units" in kwargs: |
| 83 | + kwargs["units"] = int(kwargs["units"]) |
| 84 | + super().__init__(*args, **kwargs) |
79 | 85 |
|
80 | 86 |
|
81 | 87 | def dense_per_flavour(basis_size=8, kernel_initializer="glorot_normal", **dense_kwargs): |
|
0 commit comments