Skip to content

Commit eac0371

Browse files
committed
diffusion model input
1 parent 7f67c58 commit eac0371

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

bayesflow/experimental/diffusion_model/diffusion_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def __init__(
116116
if subnet == "mlp":
117117
subnet_kwargs = DiffusionModel.MLP_DEFAULT_CONFIG | subnet_kwargs
118118
self.subnet = find_network(subnet, **subnet_kwargs)
119-
self._concatenate_subnet_input = subnet_kwargs.get("concatenate_subnet_input", True)
119+
self._subnet_concatenated_input = subnet_kwargs.get("concatenated_input", True)
120120

121121
self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros", name="output_projector")
122122

@@ -150,7 +150,7 @@ def get_config(self):
150150
"prediction_type": self._prediction_type,
151151
"loss_type": self._loss_type,
152152
"integrate_kwargs": self.integrate_kwargs,
153-
"_concatenate_subnet_input": self._concatenate_subnet_input,
153+
"subnet_concatenated_input": self._subnet_concatenated_input,
154154
}
155155
return base_config | serialize(config)
156156

@@ -218,7 +218,7 @@ def prepare_subnet_input(self, xz: Tensor, log_snr: Tensor, conditions: Tensor =
218218
Tensor
219219
The concatenated input tensor for the subnet or a tuple of tensors if concatenation is disabled.
220220
"""
221-
if self._concatenate_subnet_input:
221+
if self._subnet_concatenated_input:
222222
if conditions is None:
223223
return tensor_utils.concatenate_valid([xz, log_snr], axis=-1)
224224
else:

0 commit comments

Comments
 (0)