@@ -198,7 +198,7 @@ class SplitNetwork(tf.keras.Model):
198198 of data to provide an individual network for each split of the data.
199199 """
200200
201- def __init__ (self , num_splits , split_data_configurator , network_type = InvariantNetwork , meta = {}, ** kwargs ):
201+ def __init__ (self , num_splits , split_data_configurator , network_type = InvariantNetwork , network_kwargs = {}, ** kwargs ):
202202 """Creates a composite network of `num_splits` sub-networks of type `network_type`, each with configuration
203203 specified by `meta`.
204204
@@ -207,14 +207,22 @@ def __init__(self, num_splits, split_data_configurator, network_type=InvariantNe
207207 num_splits : int
208208 The number if splits for the data, which will equal the number of sub-networks.
209209 split_data_configurator : callable
210- Function that takes the arguments `i` and `x` where `i` is the index of the network
211- and `x` are the inputs to the `SplitNetwork`. Should return the input for the corresponding network.
210+ Function that takes the arguments `i` and `x` where `i` is the index of the
211+ network and `x` are the inputs to the `SplitNetwork`. Should return the input
212+ for the corresponding network.
212213
213- For example, to achieve a network with is permutation-invariant both vertically (i.e., across rows)
214- and horizontally (i.e., across columns), one could to:
215- `def config(i, x):
216- TODO
214+ For example, to achieve a network with is permutation-invariant both
215+ vertically (i.e., across rows) and horizontally (i.e., across columns), one could to:
216+ `def split(i, x):
217+ selector = tf.where(x[:,:,0]==i, 1.0, 0.0)
218+ selected = x[:,:,1] * selector
219+ split_x = tf.stack((selector, selected), axis=-1)
220+ return split_x
217221 `
222+ where `x[:,:,0]` contains an integer indicating which split the data
223+ in `x[:,:,1]` belongs to. All values in `x[:,:,1]` that are not selected
224+ are set to zero. The selector is passed along with the modified data,
225+ indicating which rows belong to the split `i`.
218226 network_type : callable, optional, default: `InvariantNetowk`
219227 Type of neural network to use.
220228 meta : dict, optional, default: {}
@@ -227,7 +235,7 @@ def __init__(self, num_splits, split_data_configurator, network_type=InvariantNe
227235
228236 self .num_splits = num_splits
229237 self .split_data_configurator = split_data_configurator
230- self .networks = [network_type (meta ) for _ in range (num_splits )]
238+ self .networks = [network_type (** network_kwargs ) for _ in range (num_splits )]
231239
232240 def call (self , x ):
233241 """Performs a forward pass through the subnetworks and concatenates their output.
0 commit comments