@@ -274,7 +274,7 @@ class StudentTArray(PyBanditsBaseModel):
     nu: Union[List[PositiveFloat], List[List[PositiveFloat]]]
 
     @staticmethod
-    def convert_list_to_array(input_list: Union[List[float], List[List[float]]]) -> bool:
+    def maybe_convert_list_to_array(input_list: Union[List[float], List[List[float]]]) -> np.ndarray:
         if len(input_list) == 0:
             is_valid_input = False
 
@@ -292,19 +292,16 @@ def convert_list_to_array(input_list: Union[List[float], List[List[float]]]) ->
         else:
             raise ValueError("Input list must be a 1D or 2D list with the same length for all inner lists.")
 
-    @model_validator(mode="after")
+    @model_validator(mode="before")
     @classmethod
     def validate_input_shapes(cls, values):
-        if pydantic_version == PYDANTIC_VERSION_1:
-            mu_arr = cls.convert_list_to_array(values.get("mu"))
-            sigma_arr = cls.convert_list_to_array(values.get("sigma"))
-            nu_arr = cls.convert_list_to_array(values.get("nu"))
-        elif pydantic_version == PYDANTIC_VERSION_2:
-            mu_arr = cls.convert_list_to_array(values.mu)
-            sigma_arr = cls.convert_list_to_array(values.sigma)
-            nu_arr = cls.convert_list_to_array(values.nu)
-        else:
-            raise ValueError(f"Unsupported pydantic version: {pydantic_version}")
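+        # With mode="before", `values` is the raw input (a dict here), so mu/sigma/nu may still be lists or ndarrays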
+        mu_input = values.get("mu")
+        sigma_input = values.get("sigma")
+        nu_input = values.get("nu")
+
+        mu_arr = cls.maybe_convert_list_to_array(mu_input)
+        sigma_arr = cls.maybe_convert_list_to_array(sigma_input)
+        nu_arr = cls.maybe_convert_list_to_array(nu_input)
 
         if (mu_arr.shape != sigma_arr.shape) or (mu_arr.shape != nu_arr.shape):
             raise ValueError(
@@ -315,6 +312,9 @@ def validate_input_shapes(cls, values):
         if any(dim_len == 0 for dim_len in mu_arr.shape):
             raise ValueError("mu, sigma, and nu must have at least one element in every dimension.")
 
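+        # Convert any ndarray inputs back to plain lists so the List-typed model fields validate cleanly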
+        for key, value in zip(["mu", "sigma", "nu"], [mu_input, sigma_input, nu_input]):
+            if isinstance(value, np.ndarray):
+                values[key] = value.tolist()
         return values
 
     @classmethod
@@ -331,9 +331,9 @@ def cold_start(
         if any(dim_len == 0 for dim_len in shape):
             raise ValueError("shape of mu, sigma, and nu must have at least one element in every dimension.")
 
-        mu = np.full(shape, mu).tolist()
-        sigma = np.full(shape, sigma).tolist()
-        nu = np.full(shape, nu).tolist()
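+        # Passing ndarrays is fine: the "before" validator above converts them to lists after the shape checks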
+        mu = np.full(shape, mu)
+        sigma = np.full(shape, sigma)
+        nu = np.full(shape, nu)
         return cls(mu=mu, sigma=sigma, nu=nu)
 
     @property
@@ -449,9 +449,6 @@ class BaseBayesianNeuralNetwork(Model, ABC):
     )
 
     _default_variational_inference_fit_kwargs: ClassVar[dict] = dict(method="advi")
-    _default_variational_inference_trace_kwargs: ClassVar[dict] = dict(
-        draws=1000, progressbar=False, return_inferencedata=False
-    )
 
     _approx_history: np.ndarray = PrivateAttr(None)
 
@@ -470,12 +467,7 @@ def arrange_update_kwargs(cls, values):
             update_kwargs = dict()
 
         if update_method == "VI":
-            update_kwargs["trace"] = {
-                **cls._default_variational_inference_trace_kwargs,
-                **update_kwargs.get("trace", {}),
-            }
             update_kwargs["fit"] = {**cls._default_variational_inference_fit_kwargs, **update_kwargs.get("fit", {})}
-
             optimizer_type = update_kwargs.get("optimizer_type", None)
 
             if optimizer_type is not None:
@@ -507,10 +499,6 @@ def arrange_update_kwargs(self):
             self.update_kwargs = dict()
 
         if self.update_method == "VI":
-            self.update_kwargs["trace"] = {
-                **self._default_variational_inference_trace_kwargs,
-                **self.update_kwargs.get("trace", {}),
-            }
             self.update_kwargs["fit"] = {
                 **self._default_variational_inference_fit_kwargs,
                 **self.update_kwargs.get("fit", {}),
@@ -673,14 +661,14 @@ def create_update_model(
        3. Apply sigmoid activation at the output
        4. Use Bernoulli likelihood for binary classification
        """
-
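+        # Cast labels once up front so both the full-data and minibatch paths feed integer observations to the likelihood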
+        y = np.array(y, dtype=np.int32)
         with PymcModel() as _model:
             # Define data variables
             if batch_size is None:
                 bnn_output = Data("bnn_output", y)
                 bnn_input = Data("bnn_input", x)
             else:
-                bnn_input, bnn_output = Minibatch(x, np.array(y).astype("int32"), batch_size=batch_size)
+                bnn_input, bnn_output = Minibatch(x, y, batch_size=batch_size)
 
             next_layer_input = bnn_input
 
@@ -750,7 +738,7 @@ def sample_proba(self, context: np.ndarray) -> List[ProbabilityWeight]:
             )
 
             # Linear transformation
-            linear_transform = np.sum(next_layer_input[..., None] * w, axis=1) + b
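+            # Batched affine map: contract the shared feature axis with each sampled weight matrix, broadcasting over leading sample axes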
+            linear_transform = np.einsum("...i,...ij->...j", next_layer_input, w) + b
 
             # Apply activation function (tanh for hidden layers, sigmoid for output)
             if layer_ind < len(self.model_params.bnn_layer_params) - 1:
@@ -797,29 +785,53 @@ def _update(self, context: np.ndarray, rewards: List[BinaryReward]):
             else:
                 approx = fit(**update_kwargs["fit"])
 
-            trace = approx.sample(**update_kwargs["trace"])
             self._approx_history = approx.hist
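+            # ADVI yields a mean-field approximation; read its flat mean/std vectors once and slice them per model parameter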
+            approx_mean_eval = approx.mean.eval()
+            approx_std_eval = approx.std.eval()
+            approx_posterior_mapping = {
+                param: (approx_mean_eval[slice_], approx_std_eval[slice_])
+                for (param, (_, slice_, _, _)) in approx.ordering.items()
+            }
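+            # Write the fitted posterior moments back into each layer's StudentT parameters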
+            for layer_ind, layer_params in enumerate(self.model_params.bnn_layer_params):
+                weight_layer_params_name, bias_layer_params_name = self.get_layer_params_name(layer_ind)
+                w_shape = layer_params.weight.shape
+                b_shape = layer_params.bias.shape
+                w_mu = approx_posterior_mapping[weight_layer_params_name][0].reshape(w_shape)
+                w_sigma = approx_posterior_mapping[weight_layer_params_name][1].reshape(w_shape)
+                b_mu = approx_posterior_mapping[bias_layer_params_name][0].reshape(b_shape)
+                b_sigma = approx_posterior_mapping[bias_layer_params_name][1].reshape(b_shape)
+                layer_params.weight = StudentTArray(
+                    mu=w_mu, sigma=w_sigma, nu=self.model_params.bnn_layer_params[layer_ind].weight.nu
+                )
+                layer_params.bias = StudentTArray(
+                    mu=b_mu, sigma=b_sigma, nu=self.model_params.bnn_layer_params[layer_ind].bias.nu
+                )
+                self.model_params.bnn_layer_params[layer_ind] = layer_params
         elif self.update_method == "MCMC":
             # MCMC
             trace = sample(**self.update_kwargs["trace"])
+
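+            # Summarize the posterior samples into per-parameter means and stds for the StudentT parameters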
+            for layer_ind, layer_params in enumerate(self.model_params.bnn_layer_params):
+                weight_layer_params_name, bias_layer_params_name = self.get_layer_params_name(layer_ind)
+
+                w_mu = np.mean(trace[weight_layer_params_name], axis=0)
+                w_sigma = np.std(trace[weight_layer_params_name], axis=0)
+                layer_params.weight = StudentTArray(
+                    mu=w_mu.tolist(),
+                    sigma=w_sigma.tolist(),
+                    nu=self.model_params.bnn_layer_params[layer_ind].weight.nu,
+                )
+
+                b_mu = np.mean(trace[bias_layer_params_name], axis=0)
+                b_sigma = np.std(trace[bias_layer_params_name], axis=0)
+                layer_params.bias = StudentTArray(
+                    mu=b_mu.tolist(),
+                    sigma=b_sigma.tolist(),
+                    nu=self.model_params.bnn_layer_params[layer_ind].bias.nu,
+                )
         else:
             raise ValueError("Invalid update method.")
 
-        for layer_ind, layer_params in enumerate(self.model_params.bnn_layer_params):
-            weight_layer_params_name, bias_layer_params_name = self.get_layer_params_name(layer_ind)
-
-            w_mu = np.mean(trace[weight_layer_params_name], axis=0)
-            w_sigma = np.std(trace[weight_layer_params_name], axis=0)
-            layer_params.weight = StudentTArray(
-                mu=w_mu.tolist(), sigma=w_sigma.tolist(), nu=self.model_params.bnn_layer_params[layer_ind].weight.nu
-            )
-
-            b_mu = np.mean(trace[bias_layer_params_name], axis=0)
-            b_sigma = np.std(trace[bias_layer_params_name], axis=0)
-            layer_params.bias = StudentTArray(
-                mu=b_mu.tolist(), sigma=b_sigma.tolist(), nu=self.model_params.bnn_layer_params[layer_ind].bias.nu
-            )
-
     @classmethod
     def cold_start(
         cls,