Skip to content

Commit c313871

Browse files
committed
correctly update default flow matching kwargs
1 parent f065b54 commit c313871

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,13 @@ def __init__(
6666

6767
self.use_optimal_transport = use_optimal_transport
6868

69-
self.integrate_kwargs = integrate_kwargs or FlowMatching.INTEGRATE_DEFAULT_CONFIG.copy()
70-
self.optimal_transport_kwargs = optimal_transport_kwargs or FlowMatching.OPTIMAL_TRANSPORT_DEFAULT_CONFIG.copy()
69+
new_integrate_kwargs = FlowMatching.INTEGRATE_DEFAULT_CONFIG.copy()
70+
new_integrate_kwargs.update(integrate_kwargs)
71+
self.integrate_kwargs = new_integrate_kwargs
72+
73+
new_optimal_transport_kwargs = FlowMatching.OPTIMAL_TRANSPORT_DEFAULT_CONFIG.copy()
74+
new_optimal_transport_kwargs.update(optimal_transport_kwargs)
75+
self.optimal_transport_kwargs = new_optimal_transport_kwargs
7176

7277
self.loss_fn = keras.losses.get(loss_fn)
7378

0 commit comments

Comments
 (0)