Skip to content

Commit a14befa

Browse files
committed
fix for keras3.13
1 parent 5d69b5d commit a14befa

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

n3fit/src/n3fit/backends/keras_backend/base_layers.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
"""
2-
This module defines custom base layers to be used by the n3fit
3-
Neural Network.
4-
These layers can use the keras standard set of activation function
5-
or implement their own.
2+
This module defines custom base layers to be used by the n3fit
3+
Neural Network.
4+
These layers can use the keras standard set of activation function
5+
or implement their own.
66
7-
For a layer to be used by n3fit it should be contained in the `layers` dictionary defined below.
8-
This dictionary has the following structure:
7+
For a layer to be used by n3fit it should be contained in the `layers` dictionary defined below.
8+
This dictionary has the following structure:
99
10-
'name of the layer' : ( Layer_class, {dictionary of arguments: defaults} )
10+
'name of the layer' : ( Layer_class, {dictionary of arguments: defaults} )
1111
12-
In order to add custom activation functions, they must be added to
13-
the `custom_activations` dictionary with the following structure:
12+
In order to add custom activation functions, they must be added to
13+
the `custom_activations` dictionary with the following structure:
1414
15-
'name of the activation' : function
15+
'name of the activation' : function
1616
17-
The names of the layer and the activation function are the ones to be used in the n3fit runcard.
17+
The names of the layer and the activation function are the ones to be used in the n3fit runcard.
1818
"""
1919

2020
from keras.layers import Dense as KerasDense
@@ -26,6 +26,7 @@
2626
from . import operations as ops
2727
from .MetaLayer import MetaLayer
2828

29+
2930
# Custom activation functions
3031
def square_activation(x):
3132
"""Squares the input"""
@@ -75,7 +76,12 @@ def ReshapedLSTM(input_tensor):
7576

7677

7778
class Dense(KerasDense, MetaLayer):
78-
pass
79+
80+
def __init__(self, *args, **kwargs):
81+
# In Keras == 3.13, np.int() is not accepted by Dense
82+
if "units" in kwargs:
83+
kwargs["units"] = int(kwargs["units"])
84+
super().__init__(*args, **kwargs)
7985

8086

8187
def dense_per_flavour(basis_size=8, kernel_initializer="glorot_normal", **dense_kwargs):

0 commit comments

Comments
 (0)