 from collections.abc import Sequence
+
 import keras
 from keras.saving import register_keras_serializable as serializable
 
 from bayesflow.types import Shape, Tensor
 from bayesflow.utils import (
     expand_right_as,
+    find_network,
+    integrate,
+    jacobian_trace,
     keras_kwargs,
     optimal_transport,
     serialize_value_or_type,
     deserialize_value_or_type,
 )
 from ..inference_network import InferenceNetwork
-from .integrators import EulerIntegrator
-from .integrators import RK2Integrator
-from .integrators import RK4Integrator
 
 
 @serializable(package="bayesflow.networks")
@@ -30,47 +31,71 @@ def __init__(
         self,
         subnet: str | type = "mlp",
         base_distribution: str = "normal",
-        integrator: str = "euler",
         use_optimal_transport: bool = False,
+        loss_fn: str = "mse",
+        integrate_kwargs: dict[str, any] = None,
         optimal_transport_kwargs: dict[str, any] = None,
         **kwargs,
     ):
         super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs))
 
         self.use_optimal_transport = use_optimal_transport
-        self.optimal_transport_kwargs = optimal_transport_kwargs or {
-            "method": "sinkhorn",
-            "cost": "euclidean",
-            "regularization": 0.1,
-            "max_steps": 1000,
-            "tolerance": 1e-4,
-        }
+
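+        # default to adaptive RK45 integration when no explicit settings are provided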
+        if integrate_kwargs is None:
+            integrate_kwargs = {
+                "method": "rk45",
+                "steps": "adaptive",
+                "tolerance": 1e-3,
+                "min_steps": 10,
+                "max_steps": 100,
+            }
+
+        self.integrate_kwargs = integrate_kwargs
+
+        if optimal_transport_kwargs is None:
+            optimal_transport_kwargs = {
+                "method": "sinkhorn",
+                "cost": "euclidean",
+                "regularization": 0.1,
+                "max_steps": 100,
+                "tolerance": 1e-4,
+            }
+
+        self.loss_fn = keras.losses.get(loss_fn)
+
+        self.optimal_transport_kwargs = optimal_transport_kwargs
 
         self.seed_generator = keras.random.SeedGenerator()
 
-        match integrator:
-            case "euler":
-                self.integrator = EulerIntegrator(subnet, **kwargs)
-            case "rk2":
-                self.integrator = RK2Integrator(subnet, **kwargs)
-            case "rk4":
-                self.integrator = RK4Integrator(subnet, **kwargs)
-            case _:
-                raise NotImplementedError(f"No support for {integrator} integration")
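+        # the velocity field is parameterized by a configurable subnet followed by a
+        # linear projection back to the state dimensionality (units are set in build())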
+        self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {}))
+        self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros")
 
         # serialization: store all parameters necessary to call __init__
         self.config = {
             "base_distribution": base_distribution,
-            "integrator": integrator,
             "use_optimal_transport": use_optimal_transport,
             "optimal_transport_kwargs": optimal_transport_kwargs,
+            "integrate_kwargs": integrate_kwargs,
             **kwargs,
         }
         self.config = serialize_value_or_type(self.config, "subnet", subnet)
 
     def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
-        super().build(xz_shape)
-        self.integrator.build(xz_shape, conditions_shape)
+        super().build(xz_shape, conditions_shape=conditions_shape)
+
+        self.output_projector.units = xz_shape[-1]
+        input_shape = list(xz_shape)
+
+        # construct time vector
+        input_shape[-1] += 1
+        if conditions_shape is not None:
+            input_shape[-1] += conditions_shape[-1]
+
+        input_shape = tuple(input_shape)
+
+        self.subnet.build(input_shape)
+        out_shape = self.subnet.compute_output_shape(input_shape)
+        self.output_projector.build(out_shape)
 
     def get_config(self):
         base_config = super().get_config()
@@ -81,32 +106,80 @@ def from_config(cls, config):
         config = deserialize_value_or_type(config, "subnet")
         return cls(**config)
 
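+    # predicted velocity field: the subnet sees the current state, the time, and any
+    # conditions concatenated along the last axis; the projector maps back to the state size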
+    def velocity(self, xz: Tensor, t: float | Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
+        t = keras.ops.convert_to_tensor(t)
+        t = expand_right_as(t, xz)
+        t = keras.ops.broadcast_to(t, keras.ops.shape(xz)[:-1] + (1,))
+
+        if conditions is None:
+            xtc = keras.ops.concatenate([xz, t], axis=-1)
+        else:
+            xtc = keras.ops.concatenate([xz, t, conditions], axis=-1)
+
+        return self.output_projector(self.subnet(xtc, training=training), training=training)
+
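+    # velocity together with an estimate of the trace of its Jacobian w.r.t. xz,
+    # as required by the instantaneous change-of-variables formula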
+    def _velocity_trace(
+        self, xz: Tensor, t: Tensor, conditions: Tensor = None, max_steps: int = None, training: bool = False
+    ) -> (Tensor, Tensor):
+        def f(x):
+            return self.velocity(x, t, conditions=conditions, training=training)
+
+        v, trace = jacobian_trace(f, xz, max_steps=max_steps, seed=self.seed_generator, return_output=True)
+
+        return v, keras.ops.expand_dims(trace, axis=-1)
+
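+    # forward direction: transport data x (t=1) to latents z (t=0) by integrating the velocity field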
     def _forward(
         self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
     ) -> Tensor | tuple[Tensor, Tensor]:
-        steps = kwargs.get("steps", 100)
-
         if density:
-            z, trace = self.integrator(x, conditions=conditions, steps=steps, density=True)
-            log_prob = self.base_distribution.log_prob(z)
-            log_density = log_prob + trace
+
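+            # jointly integrate the state and the accumulated Jacobian trace from t=1 to t=0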
+            def deltas(t, xz):
+                v, trace = self._velocity_trace(xz, t, conditions=conditions, training=training)
+                return {"xz": v, "trace": trace}
+
+            state = {"xz": x, "trace": keras.ops.zeros(keras.ops.shape(x)[:-1] + (1,), dtype=keras.ops.dtype(x))}
+            state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs))
+
+            z = state["xz"]
+            log_density = self.base_distribution.log_prob(z) + keras.ops.squeeze(state["trace"], axis=-1)
+
             return z, log_density
 
-        z = self.integrator(x, conditions=conditions, steps=steps, density=False)
+        def deltas(t, xz):
+            return {"xz": self.velocity(xz, t, conditions=conditions, training=training)}
+
+        state = {"xz": x}
+        state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs))
+
+        z = state["xz"]
+
         return z
 
     def _inverse(
         self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
     ) -> Tensor | tuple[Tensor, Tensor]:
-        steps = kwargs.get("steps", 100)
-
         if density:
-            x, trace = self.integrator(z, conditions=conditions, steps=steps, density=True, inverse=True)
-            log_prob = self.base_distribution.log_prob(z)
-            log_density = log_prob - trace
+
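+            # inverse direction: transport latents z (t=0) to data x (t=1), subtracting the accumulated trace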
+            def deltas(t, xz):
+                v, trace = self._velocity_trace(xz, t, conditions=conditions, training=training)
+                return {"xz": v, "trace": trace}
+
+            state = {"xz": z, "trace": keras.ops.zeros(keras.ops.shape(z)[:-1] + (1,), dtype=keras.ops.dtype(z))}
+            state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs))
+
+            x = state["xz"]
+            log_density = self.base_distribution.log_prob(z) - keras.ops.squeeze(state["trace"], axis=-1)
+
             return x, log_density
 
-        x = self.integrator(z, conditions=conditions, steps=steps, density=False, inverse=True)
+        def deltas(t, xz):
+            return {"xz": self.velocity(xz, t, conditions=conditions, training=training)}
+
+        state = {"xz": z}
+        state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs))
+
+        x = state["xz"]
+
         return x
 
     def compute_metrics(
@@ -118,7 +191,7 @@ def compute_metrics(
         else:
             # not pre-configured, resample
             x1 = x
-            x0 = keras.random.normal(keras.ops.shape(x1), dtype=keras.ops.dtype(x1), seed=self.seed_generator)
+            x0 = self.base_distribution.sample(keras.ops.shape(x1), seed=self.seed_generator)
 
         if self.use_optimal_transport:
             x1, x0, conditions = optimal_transport(
@@ -133,9 +206,9 @@ def compute_metrics(
 
         base_metrics = super().compute_metrics(x1, conditions, stage)
 
-        predicted_velocity = self.integrator.velocity(x, t, conditions)
+        predicted_velocity = self.velocity(x, t, conditions, training=stage == "training")
 
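+        # flow matching objective: discrepancy between predicted and target velocities, averaged over the batch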
-        loss = keras.losses.mean_squared_error(target_velocity, predicted_velocity)
+        loss = self.loss_fn(target_velocity, predicted_velocity)
         loss = keras.ops.mean(loss)
 
         return base_metrics | {"loss": loss}