We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3c22035 commit 749c0a6Copy full SHA for 749c0a6
bayesflow/amortizers.py
@@ -1180,13 +1180,16 @@ def _compute_condition(self, input_dict, **kwargs):
1180
# p(local_n | data_n, hyper)
1181
if input_dict.get("hyper_parameters") is not None:
1182
_params = input_dict.get("hyper_parameters")
1183
- _conds = tf.stack([_params] * num_locals, axis=1)
+ _params = tf.expand_dims(_params, 1)
1184
+ _conds = tf.tile(_params, [1, num_locals, 1])
1185
local_summaries = tf.concat([local_summaries, _conds], axis=-1)
1186
+
1187
# Add shared parameters as conditions:
1188
# p(local_n | data_n, hyper, shared)
1189
if input_dict.get("shared_parameters") is not None:
1190
_params = input_dict.get("shared_parameters")
1191
1192
1193
1194
return local_summaries, global_summaries
1195
0 commit comments