Skip to content

Commit b5705f0

Browse files
committed
Merge branch 'Future' of https://github.com/stefanradev93/BayesFlow into Future
2 parents f11d7c4 + 9999e1c commit b5705f0

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

bayesflow/summary_networks.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ class SplitNetwork(tf.keras.Model):
198198
of data to provide an individual network for each split of the data.
199199
"""
200200

201-
def __init__(self, num_splits, split_data_configurator, network_type=InvariantNetwork, meta={}, **kwargs):
201+
def __init__(self, num_splits, split_data_configurator, network_type=InvariantNetwork, network_kwargs={}, **kwargs):
202202
"""Creates a composite network of `num_splits` sub-networks of type `network_type`, each with configuration
203203
specified by `meta`.
204204
@@ -207,14 +207,22 @@ def __init__(self, num_splits, split_data_configurator, network_type=InvariantNe
207207
num_splits : int
208208
The number if splits for the data, which will equal the number of sub-networks.
209209
split_data_configurator : callable
210-
Function that takes the arguments `i` and `x` where `i` is the index of the network
211-
and `x` are the inputs to the `SplitNetwork`. Should return the input for the corresponding network.
210+
Function that takes the arguments `i` and `x` where `i` is the index of the
211+
network and `x` are the inputs to the `SplitNetwork`. Should return the input
212+
for the corresponding network.
212213
213-
For example, to achieve a network with is permutation-invariant both vertically (i.e., across rows)
214-
and horizontally (i.e., across columns), one could to:
215-
`def config(i, x):
216-
TODO
214+
For example, to achieve a network with is permutation-invariant both
215+
vertically (i.e., across rows) and horizontally (i.e., across columns), one could to:
216+
`def split(i, x):
217+
selector = tf.where(x[:,:,0]==i, 1.0, 0.0)
218+
selected = x[:,:,1] * selector
219+
split_x = tf.stack((selector, selected), axis=-1)
220+
return split_x
217221
`
222+
where `x[:,:,0]` contains an integer indicating which split the data
223+
in `x[:,:,1]` belongs to. All values in `x[:,:,1]` that are not selected
224+
are set to zero. The selector is passed along with the modified data,
225+
indicating which rows belong to the split `i`.
218226
network_type : callable, optional, default: `InvariantNetowk`
219227
Type of neural network to use.
220228
meta : dict, optional, default: {}
@@ -227,7 +235,7 @@ def __init__(self, num_splits, split_data_configurator, network_type=InvariantNe
227235

228236
self.num_splits = num_splits
229237
self.split_data_configurator = split_data_configurator
230-
self.networks = [network_type(meta) for _ in range(num_splits)]
238+
self.networks = [network_type(**network_kwargs) for _ in range(num_splits)]
231239

232240
def call(self, x):
233241
"""Performs a forward pass through the subnetworks and concatenates their output.

0 commit comments

Comments
 (0)