@@ -162,7 +162,9 @@ def __init__(
162162 self .w = w .astype (prec ) if w is not None else None
163163 self .b = b .astype (prec ) if b is not None else None
164164 self .idt = idt .astype (prec ) if idt is not None else None
165- self .activation_function = activation_function
165+ self .activation_function = (
166+ activation_function if activation_function is not None else "none"
167+ )
166168 self .resnet = resnet
167169 self .check_type_consistency ()
168170
@@ -354,6 +356,24 @@ def call(self, x: np.ndarray) -> np.ndarray:
354356
355357
356358class EmbeddingNet (NativeNet ):
359+ """The embedding network.
360+
361+ Parameters
362+ ----------
363+ in_dim
364+ Input dimension.
365+ neuron
366+ The number of neurons in each layer. The output dimension
367+ is the same as the dimension of the last layer.
368+ activation_function
369+ The activation function.
370+ resnet_dt
371+ Use time step at the resnet architecture.
372+ precision
373+ Floating point precision for the model paramters.
374+
375+ """
376+
357377 def __init__ (
358378 self ,
359379 in_dim ,
@@ -370,8 +390,8 @@ def __init__(
370390 layers .append (
371391 NativeLayer (
372392 rng .normal (size = (i_in , i_ot )),
373- b = rng .normal (size = (ii )),
374- idt = rng .normal (size = (ii )) if resnet_dt else None ,
393+ b = rng .normal (size = (i_ot )),
394+ idt = rng .normal (size = (i_ot )) if resnet_dt else None ,
375395 activation_function = activation_function ,
376396 resnet = True ,
377397 precision = precision ,
@@ -417,6 +437,95 @@ def deserialize(cls, data: dict) -> "EmbeddingNet":
417437 return obj
418438
419439
440+ class FittingNet (EmbeddingNet ):
441+ """The fitting network. It may be implemented as an embedding
442+ net connected with a linear output layer.
443+
444+ Parameters
445+ ----------
446+ in_dim
447+ Input dimension.
448+ out_dim
449+ Output dimension
450+ neuron
451+ The number of neurons in each hidden layer.
452+ activation_function
453+ The activation function.
454+ resnet_dt
455+ Use time step at the resnet architecture.
456+ precision
457+ Floating point precision for the model paramters.
458+ bias_out
459+ The last linear layer has bias.
460+
461+ """
462+
463+ def __init__ (
464+ self ,
465+ in_dim ,
466+ out_dim ,
467+ neuron : List [int ] = [24 , 48 , 96 ],
468+ activation_function : str = "tanh" ,
469+ resnet_dt : bool = False ,
470+ precision : str = DEFAULT_PRECISION ,
471+ bias_out : bool = True ,
472+ ):
473+ super ().__init__ (
474+ in_dim ,
475+ neuron = neuron ,
476+ activation_function = activation_function ,
477+ resnet_dt = resnet_dt ,
478+ precision = precision ,
479+ )
480+ rng = np .random .default_rng ()
481+ i_in , i_ot = neuron [- 1 ], out_dim
482+ self .layers .append (
483+ NativeLayer (
484+ rng .normal (size = (i_in , i_ot )),
485+ b = rng .normal (size = (i_ot )) if bias_out else None ,
486+ idt = None ,
487+ activation_function = None ,
488+ resnet = False ,
489+ precision = precision ,
490+ )
491+ )
492+ self .out_dim = out_dim
493+ self .bias_out = bias_out
494+
495+ def serialize (self ) -> dict :
496+ """Serialize the network to a dict.
497+
498+ Returns
499+ -------
500+ dict
501+ The serialized network.
502+ """
503+ return {
504+ "in_dim" : self .in_dim ,
505+ "out_dim" : self .out_dim ,
506+ "neuron" : self .neuron .copy (),
507+ "activation_function" : self .activation_function ,
508+ "resnet_dt" : self .resnet_dt ,
509+ "precision" : self .precision ,
510+ "bias_out" : self .bias_out ,
511+ "layers" : [layer .serialize () for layer in self .layers ],
512+ }
513+
514+ @classmethod
515+ def deserialize (cls , data : dict ) -> "FittingNet" :
516+ """Deserialize the network from a dict.
517+
518+ Parameters
519+ ----------
520+ data : dict
521+ The dict to deserialize from.
522+ """
523+ layers = data .pop ("layers" )
524+ obj = cls (** data )
525+ NativeNet .__init__ (obj , layers )
526+ return obj
527+
528+
420529class NetworkCollection :
421530 """A collection of networks for multiple elements.
422531
@@ -439,6 +548,7 @@ class NetworkCollection:
439548 NETWORK_TYPE_MAP : ClassVar [Dict [str , type ]] = {
440549 "network" : NativeNet ,
441550 "embedding_network" : EmbeddingNet ,
551+ "fitting_network" : FittingNet ,
442552 }
443553
444554 def __init__ (
0 commit comments