 from pymc import Bernoulli, Data, Deterministic, Minibatch, fit, math, sample
 from pymc import Model as PymcModel
 from pymc import StudentT as PymcStudentT
+from scipy.special import erf
 from scipy.stats import t
 from typing_extensions import Self

 )

 UpdateMethods = Literal["VI", "MCMC"]
+ActivationFunctions = Literal["tanh", "relu", "sigmoid", "gelu"]
+
+
+# Module-level activation functions for pickling compatibility
+def _pymc_relu(x):
+    """ReLU activation function for PyMC."""
+    return math.maximum(0, x)
+
+
+def _pymc_gelu(x):
+    """GELU activation function for PyMC."""
+    return 0.5 * x * (1 + math.erf(x / np.sqrt(2.0)))
+
+
+def _numpy_relu(x: np.ndarray) -> np.ndarray:
+    """ReLU activation function for NumPy."""
+    return np.maximum(0, x)
+
+
+def _numpy_gelu(x: np.ndarray) -> np.ndarray:
+    """GELU activation function for NumPy."""
+    return 0.5 * x * (1 + erf(x / np.sqrt(2.0)))
+
+
+def _stable_sigmoid(x):
+    """Stable sigmoid activation function for NumPy."""
+    return np.where(x >= 0, 1 / (1 + np.exp(-x)), np.exp(x) / (1 + np.exp(x)))


 class Model(BaseModelSO, ABC):
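The erf-based GELU and the branch-safe sigmoid added above can be sanity-checked on their own; a minimal NumPy/SciPy sketch (function bodies copied from the hunk above, test values purely illustrative):

    import numpy as np
    from scipy.special import erf

    def numpy_gelu(x):
        # Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
        return 0.5 * x * (1 + erf(x / np.sqrt(2.0)))

    def stable_sigmoid(x):
        # Select the branch whose exp() argument is non-positive, keeping the selected value finite
        return np.where(x >= 0, 1 / (1 + np.exp(-x)), np.exp(x) / (1 + np.exp(x)))

    x = np.array([-30.0, -1.0, 0.0, 1.0, 30.0])
    print(numpy_gelu(x))       # ~[-0.0, -0.159, 0.0, 0.841, 30.0]
    print(stable_sigmoid(x))   # ~[9.4e-14, 0.269, 0.5, 0.731, 1.0]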
@@ -452,12 +480,19 @@ class BaseBayesianNeuralNetwork(Model, ABC):
     update_kwargs : Optional[dict], optional
         A dictionary of keyword arguments for the update method. For MCMC, it contains 'trace' settings.
         For VI, it contains both 'trace' and 'fit' settings.
+    activation : str, optional
+        The activation function to use for hidden layers. Supported values are: "tanh", "relu", "sigmoid", "gelu" (default is "tanh").
+    use_residual_connections : bool, optional
+        Whether to use residual connections in the network. Residual connections are only added when
+        the layer output dimension is greater than or equal to the input dimension (default is False).

     Notes
     -----
-    - The model uses tanh activation for hidden layers and sigmoid activation for the output layer.
+    - The model uses the specified activation function for hidden layers and sigmoid activation for the output layer.
     - The output layer is designed for binary classification tasks, with probabilities modeled
       using a Bernoulli likelihood.
+    - When use_residual_connections is True, residual connections are added to hidden layers where the output
+      dimension is >= input dimension. For expanding dimensions, the residual is zero-padded.
     """

     model_params: BnnParams
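The zero-padding rule described in the Notes above, sketched standalone in NumPy (dimensions are illustrative and the helper name is hypothetical; this is not tied to the class internals):

    import numpy as np

    def residual_add(layer_input, activated):
        # Residual only when the layer keeps or expands the width; pad features with zeros when expanding
        in_dim, out_dim = layer_input.shape[1], activated.shape[1]
        if out_dim < in_dim:
            return activated
        if out_dim == in_dim:
            return activated + layer_input
        padded = np.pad(layer_input, ((0, 0), (0, out_dim - in_dim)), mode="constant", constant_values=0)
        return activated + padded

    x = np.ones((4, 3))                      # batch of 4, width 3
    h = np.tanh(x @ np.random.randn(3, 5))   # width expands to 5, so the residual is zero-padded
    print(residual_add(x, h).shape)          # (4, 5)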
@@ -477,9 +512,23 @@ class BaseBayesianNeuralNetwork(Model, ABC):
477512 "adam" ,
478513 "adamax" ,
479514 ]
515+ _pymc_activations : ClassVar [dict ] = {
516+ "tanh" : math .tanh ,
517+ "relu" : _pymc_relu ,
518+ "sigmoid" : math .sigmoid ,
519+ "gelu" : _pymc_gelu ,
520+ }
521+ _numpy_activations : ClassVar [dict ] = {
522+ "tanh" : np .tanh ,
523+ "relu" : _numpy_relu ,
524+ "sigmoid" : _stable_sigmoid ,
525+ "gelu" : _numpy_gelu ,
526+ }
480527
481528 update_method : str = "VI"
482529 update_kwargs : Optional [dict ] = None
530+ activation : ActivationFunctions = "tanh"
531+ use_residual_connections : bool = False
483532
484533 _default_mcmc_trace_kwargs : ClassVar [dict ] = dict (
485534 tune = 500 ,
@@ -495,6 +544,8 @@ class BaseBayesianNeuralNetwork(Model, ABC):
     _default_variational_inference_fit_kwargs: ClassVar[dict] = dict(method="advi")

     _approx_history: np.ndarray = PrivateAttr(None)
+    _numpy_activation_fn: Callable = PrivateAttr(None)
+    _pymc_activation_fn: Callable = PrivateAttr(None)

     class Config:
         arbitrary_types_allowed = True
@@ -569,6 +620,15 @@ def arrange_update_kwargs(self):
         else:
             raise ValueError(f"Unsupported pydantic version: {pydantic_version}")

+    @field_validator("activation")
+    @classmethod
+    def validate_activation(cls, v):
+        if v not in cls._pymc_activations.keys():
+            raise ValueError(
+                f"Invalid activation function: {v}. Supported activations are: {list(cls._pymc_activations.keys())}"
+            )
+        return v
+
     @property
     def approx_history(self) -> Optional[np.ndarray]:
         return self._approx_history
@@ -585,10 +645,6 @@ def optimizer(self) -> Callable:

         return _optimizer

-    @classmethod
-    def _stable_sigmoid(cls, x):
-        return np.where(x >= 0, 1 / (1 + np.exp(-x)), np.exp(x) / (1 + np.exp(x)))
-
     @classmethod
     def get_layer_params_name(cls, layer_ind: PositiveInt) -> Tuple[str, str]:
         weight_layer_params_name = f"{cls._weight_var_name}_{layer_ind}"
@@ -676,6 +732,14 @@ def input_dim(self) -> PositiveInt:
         """
         return self.model_params.bnn_layer_params[0].weight.shape[0]

+    def model_post_init(self, __context: Any) -> None:
+        """
+        Initialize activation function PrivateAttr based on the activation setting.
+        """
+        # Initialize activation functions (always set to ensure they're available after model_copy)
+        self._numpy_activation_fn = self._numpy_activations[self.activation]
+        self._pymc_activation_fn = self._pymc_activations[self.activation]
+
     def create_update_model(
         self, x: ArrayLike, y: Union[List[BinaryReward], np.ndarray], batch_size: Optional[PositiveInt] = None
     ) -> PymcModel:
@@ -720,6 +784,8 @@ def create_update_model(
             w_shape = layer_params.weight.shape  # without it n_features = 1 doesn't work
             b_shape = layer_params.bias.shape
             weight_layer_params_name, bias_layer_params_name = self.get_layer_params_name(layer_ind)
+            input_dim = w_shape[0]
+            output_dim = w_shape[1]

             # For training, use shared weights and biases
             w = PymcStudentT(
@@ -732,7 +798,20 @@ def create_update_model(
             linear_transform = math.dot(next_layer_input, w) + b

             if layer_ind < len(self.model_params.bnn_layer_params) - 1:
-                next_layer_input = math.tanh(linear_transform)
+                activated_output = self._pymc_activation_fn(linear_transform)
+
+                # Add residual connection if enabled and dimensions allow
+                if self.use_residual_connections and output_dim >= input_dim:
+                    if output_dim == input_dim:
+                        next_layer_input = activated_output + next_layer_input
+                    else:
+                        residual_padded = math.concatenate(
+                            [next_layer_input, math.zeros((next_layer_input.shape[0], output_dim - input_dim))],
+                            axis=1,
+                        )
+                        next_layer_input = activated_output + residual_padded
+                else:
+                    next_layer_input = activated_output

         # Final output processing
         logit = Deterministic(self._logit_var_name, linear_transform.squeeze())
@@ -769,6 +848,8 @@ def sample_proba(self, context: np.ndarray) -> List[ProbabilityWeight]:
             # Sample weights and biases from StudentT distributions
             w_params = layer_params.weight.params
             b_params = layer_params.bias.params
+            input_dim = layer_params.weight.shape[0]
+            output_dim = layer_params.weight.shape[1]

             # Sample weights and biases using scipy.stats
             w = t.rvs(
@@ -784,13 +865,25 @@ def sample_proba(self, context: np.ndarray) -> List[ProbabilityWeight]:
             # Linear transformation
             linear_transform = np.einsum("...i,...ij->...j", next_layer_input, w) + b

-            # Apply activation function (tanh for hidden layers, sigmoid for output)
+            # Apply activation function for hidden layers, sigmoid for output
             if layer_ind < len(self.model_params.bnn_layer_params) - 1:
-                next_layer_input = np.tanh(linear_transform)
+                activated_output = self._numpy_activation_fn(linear_transform)
+
+                # Add residual connection if enabled and dimensions allow
+                if self.use_residual_connections and output_dim >= input_dim:
+                    if output_dim == input_dim:
+                        next_layer_input = activated_output + next_layer_input
+                    else:
+                        residual_padded = np.pad(
+                            next_layer_input, ((0, 0), (0, output_dim - input_dim)), mode="constant", constant_values=0
+                        )
+                        next_layer_input = activated_output + residual_padded
+                else:
+                    next_layer_input = activated_output
             else:
                 # Output layer - apply sigmoid
                 weighted_sum = linear_transform.squeeze(-1)
-                prob = self._stable_sigmoid(weighted_sum)
+                prob = _stable_sigmoid(weighted_sum)

         return list(zip(prob, weighted_sum))

@@ -884,6 +977,8 @@ def cold_start(
         update_method: UpdateMethods = "VI",
         update_kwargs: Optional[dict] = None,
         dist_params_init: Optional[Dict[str, float]] = None,
+        activation: ActivationFunctions = "tanh",
+        use_residual_connections: bool = False,
         **kwargs,
     ) -> Self:
         """
@@ -901,6 +996,10 @@ def cold_start(
             Additional keyword arguments for the update method. Default is None.
         dist_params_init : Optional[Dict[str, float]], optional
             Initial distribution parameters for the network weights and biases. Default is None.
+        activation : str
+            The activation function to use for hidden layers. Supported values are: "tanh", "relu", "sigmoid", "gelu" (default is "tanh").
+        use_residual_connections : bool
+            Whether to use residual connections in the network (default is False).
         **kwargs
             Additional keyword arguments for the BayesianNeuralNetwork constructor.

@@ -916,7 +1015,14 @@ def cold_start(
         model_params = cls.create_model_params(
             n_features=n_features, hidden_dim_list=hidden_dim_list, **dist_params_init
         )
-        return cls(model_params=model_params, update_method=update_method, update_kwargs=update_kwargs, **kwargs)
+        return cls(
+            model_params=model_params,
+            update_method=update_method,
+            update_kwargs=update_kwargs,
+            activation=activation,
+            use_residual_connections=use_residual_connections,
+            **kwargs,
+        )

     def _reset(self):
         """
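A usage sketch for the extended signature; the keyword names come from the diff above, while the feature count and layer sizes are illustrative:

    model = BayesianNeuralNetwork.cold_start(
        n_features=8,
        hidden_dim_list=[16, 16],
        update_method="VI",
        activation="gelu",
        use_residual_connections=True,
    )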
@@ -1001,6 +1107,8 @@ def cold_start(
         update_method: UpdateMethods = "VI",
         update_kwargs: Optional[dict] = None,
         dist_params_init: Optional[Dict[str, float]] = None,
+        activation: ActivationFunctions = "tanh",
+        use_residual_connections: bool = False,
         **kwargs,
     ) -> Self:
         """
@@ -1020,6 +1128,10 @@ def cold_start(
             Additional keyword arguments for the update method.
         dist_params_init : Optional[Dict[str, float]], optional
             Initial distribution parameters for the network weights and biases.
+        activation : str
+            The activation function to use for hidden layers. Supported values are: "tanh", "relu", "sigmoid", "gelu" (default is "tanh").
+        use_residual_connections : bool
+            Whether to use residual connections in the network (default is False).
         **kwargs
             Additional keyword arguments.

@@ -1028,13 +1140,16 @@ def cold_start(
         BayesianNeuralNetworkMO
             A multi-objective BNN with the specified number of objectives.
         """
+
         models = [
             BayesianNeuralNetwork.cold_start(
                 n_features=n_features,
                 hidden_dim_list=hidden_dim_list,
                 update_method=update_method,
                 update_kwargs=update_kwargs,
                 dist_params_init=dist_params_init,
+                activation=activation,
+                use_residual_connections=use_residual_connections,
             )
             for _ in range(n_objectives)
         ]
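And analogously for the multi-objective wrapper, assuming n_objectives is accepted by this cold_start as the comprehension above suggests; all values are illustrative:

    mo_model = BayesianNeuralNetworkMO.cold_start(
        n_objectives=2,
        n_features=8,
        hidden_dim_list=[16],
        activation="relu",
    )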