@@ -45,21 +45,19 @@ class FlowMatching(InferenceNetwork):
4545 }
4646
4747 INTEGRATE_DEFAULT_CONFIG = {
48- "method" : "rk45" ,
49- "steps" : "adaptive" ,
50- "tolerance" : 1e-3 ,
51- "min_steps" : 10 ,
52- "max_steps" : 100 ,
48+ "method" : "euler" ,
49+ "steps" : 100 ,
5350 }
5451
5552 def __init__ (
5653 self ,
5754 subnet : str | type = "mlp" ,
5855 base_distribution : str = "normal" ,
59- use_optimal_transport : bool = False ,
56+ use_optimal_transport : bool = True ,
6057 loss_fn : str = "mse" ,
6158 integrate_kwargs : dict [str , any ] = None ,
6259 optimal_transport_kwargs : dict [str , any ] = None ,
60+ subnet_kwargs : dict [str , any ] = None ,
6361 ** kwargs ,
6462 ):
6563 """
@@ -97,23 +95,17 @@ def __init__(
9795
9896 self .use_optimal_transport = use_optimal_transport
9997
100- new_integrate_kwargs = FlowMatching .INTEGRATE_DEFAULT_CONFIG .copy ()
101- new_integrate_kwargs .update (integrate_kwargs or {})
102- self .integrate_kwargs = new_integrate_kwargs
103-
104- new_optimal_transport_kwargs = FlowMatching .OPTIMAL_TRANSPORT_DEFAULT_CONFIG .copy ()
105- new_optimal_transport_kwargs .update (optimal_transport_kwargs or {})
106- self .optimal_transport_kwargs = new_optimal_transport_kwargs
98+ self .integrate_kwargs = FlowMatching .INTEGRATE_DEFAULT_CONFIG | (integrate_kwargs or {})
99+ self .optimal_transport_kwargs = FlowMatching .OPTIMAL_TRANSPORT_DEFAULT_CONFIG | (optimal_transport_kwargs or {})
107100
108101 self .loss_fn = keras .losses .get (loss_fn )
109102
110103 self .seed_generator = keras .random .SeedGenerator ()
111104
105+ subnet_kwargs = subnet_kwargs or {}
106+
112107 if subnet == "mlp" :
113- subnet_kwargs = FlowMatching .MLP_DEFAULT_CONFIG .copy ()
114- subnet_kwargs .update (kwargs .get ("subnet_kwargs" , {}))
115- else :
116- subnet_kwargs = kwargs .get ("subnet_kwargs" , {})
108+ subnet_kwargs = FlowMatching .MLP_DEFAULT_CONFIG | subnet_kwargs
117109
118110 self .subnet = find_network (subnet , ** subnet_kwargs )
119111 self .output_projector = keras .layers .Dense (units = None , bias_initializer = "zeros" )
@@ -154,23 +146,23 @@ def from_config(cls, config):
154146 config = deserialize_value_or_type (config , "subnet" )
155147 return cls (** config )
156148
157- def velocity (self , xz : Tensor , t : float | Tensor , conditions : Tensor = None , training : bool = False ) -> Tensor :
158- t = keras .ops .convert_to_tensor (t )
159- t = expand_right_as (t , xz )
160- t = keras .ops .broadcast_to (t , keras .ops .shape (xz )[:- 1 ] + (1 ,))
149+ def velocity (self , xz : Tensor , time : float | Tensor , conditions : Tensor = None , training : bool = False ) -> Tensor :
150+ time = keras .ops .convert_to_tensor (time , dtype = keras . ops . dtype ( xz ) )
151+ time = expand_right_as (time , xz )
152+ time = keras .ops .broadcast_to (time , keras .ops .shape (xz )[:- 1 ] + (1 ,))
161153
162154 if conditions is None :
163- xtc = keras .ops .concatenate ([xz , t ], axis = - 1 )
155+ xtc = keras .ops .concatenate ([xz , time ], axis = - 1 )
164156 else :
165- xtc = keras .ops .concatenate ([xz , t , conditions ], axis = - 1 )
157+ xtc = keras .ops .concatenate ([xz , time , conditions ], axis = - 1 )
166158
167159 return self .output_projector (self .subnet (xtc , training = training ), training = training )
168160
169161 def _velocity_trace (
170- self , xz : Tensor , t : Tensor , conditions : Tensor = None , max_steps : int = None , training : bool = False
162+ self , xz : Tensor , time : Tensor , conditions : Tensor = None , max_steps : int = None , training : bool = False
171163 ) -> (Tensor , Tensor ):
172164 def f (x ):
173- return self .velocity (x , t , conditions = conditions , training = training )
165+ return self .velocity (x , time = time , conditions = conditions , training = training )
174166
175167 v , trace = jacobian_trace (f , xz , max_steps = max_steps , seed = self .seed_generator , return_output = True )
176168
@@ -181,8 +173,8 @@ def _forward(
181173 ) -> Tensor | tuple [Tensor , Tensor ]:
182174 if density :
183175
184- def deltas (t , xz ):
185- v , trace = self ._velocity_trace (xz , t , conditions = conditions , training = training )
176+ def deltas (time , xz ):
177+ v , trace = self ._velocity_trace (xz , time = time , conditions = conditions , training = training )
186178 return {"xz" : v , "trace" : trace }
187179
188180 state = {"xz" : x , "trace" : keras .ops .zeros (keras .ops .shape (x )[:- 1 ] + (1 ,), dtype = keras .ops .dtype (x ))}
@@ -193,8 +185,8 @@ def deltas(t, xz):
193185
194186 return z , log_density
195187
196- def deltas (t , xz ):
197- return {"xz" : self .velocity (xz , t , conditions = conditions , training = training )}
188+ def deltas (time , xz ):
189+ return {"xz" : self .velocity (xz , time = time , conditions = conditions , training = training )}
198190
199191 state = {"xz" : x }
200192 state = integrate (deltas , state , start_time = 1.0 , stop_time = 0.0 , ** (self .integrate_kwargs | kwargs ))
@@ -208,8 +200,8 @@ def _inverse(
208200 ) -> Tensor | tuple [Tensor , Tensor ]:
209201 if density :
210202
211- def deltas (t , xz ):
212- v , trace = self ._velocity_trace (xz , t , conditions = conditions , training = training )
203+ def deltas (time , xz ):
204+ v , trace = self ._velocity_trace (xz , time = time , conditions = conditions , training = training )
213205 return {"xz" : v , "trace" : trace }
214206
215207 state = {"xz" : z , "trace" : keras .ops .zeros (keras .ops .shape (z )[:- 1 ] + (1 ,), dtype = keras .ops .dtype (z ))}
@@ -220,8 +212,8 @@ def deltas(t, xz):
220212
221213 return x , log_density
222214
223- def deltas (t , xz ):
224- return {"xz" : self .velocity (xz , t , conditions = conditions , training = training )}
215+ def deltas (time , xz ):
216+ return {"xz" : self .velocity (xz , time = time , conditions = conditions , training = training )}
225217
226218 state = {"xz" : z }
227219 state = integrate (deltas , state , start_time = 0.0 , stop_time = 1.0 , ** (self .integrate_kwargs | kwargs ))
@@ -258,7 +250,7 @@ def compute_metrics(
258250
259251 base_metrics = super ().compute_metrics (x1 , conditions , stage )
260252
261- predicted_velocity = self .velocity (x , t , conditions , training = stage == "training" )
253+ predicted_velocity = self .velocity (x , time = t , conditions = conditions , training = stage == "training" )
262254
263255 loss = self .loss_fn (target_velocity , predicted_velocity )
264256 loss = keras .ops .mean (loss )
0 commit comments