Skip to content

Commit 749c0a6

Browse files
committed
Add more efficient tile [skip ci]
1 parent 3c22035 commit 749c0a6

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

bayesflow/amortizers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,13 +1180,16 @@ def _compute_condition(self, input_dict, **kwargs):
11801180
# p(local_n | data_n, hyper)
11811181
if input_dict.get("hyper_parameters") is not None:
11821182
_params = input_dict.get("hyper_parameters")
1183-
_conds = tf.stack([_params] * num_locals, axis=1)
1183+
_params = tf.expand_dims(_params, 1)
1184+
_conds = tf.tile(_params, [1, num_locals, 1])
11841185
local_summaries = tf.concat([local_summaries, _conds], axis=-1)
1186+
11851187
# Add shared parameters as conditions:
11861188
# p(local_n | data_n, hyper, shared)
11871189
if input_dict.get("shared_parameters") is not None:
11881190
_params = input_dict.get("shared_parameters")
1189-
_conds = tf.stack([_params] * num_locals, axis=1)
1191+
_params = tf.expand_dims(_params, 1)
1192+
_conds = tf.tile(_params, [1, num_locals, 1])
11901193
local_summaries = tf.concat([local_summaries, _conds], axis=-1)
11911194
return local_summaries, global_summaries
11921195

0 commit comments

Comments
 (0)