@@ -198,8 +198,8 @@ def __init__(
198198 features from an input set using a set of seed vectors (typically one for a single summary) with ``summary_dim``
199199 output dimensions.
200200
201- Recommnded : When using transformers as summary networks, you may want to use a smaller learning rate
202- during training, e.g., setting ``default_lr=1e-5 `` in a ``Trainer`` instance.
201+ Recommended: When using transformers as summary networks, you may want to use a smaller learning rate
202+ during training, e.g., setting ``default_lr=1e-4`` in a ``Trainer`` instance.
203203
204204 Parameters
205205 ----------
@@ -211,7 +211,7 @@ def __init__(
211211
212212 ``attention_settings=dict(num_heads=4, key_dim=32)``
213213
214- You may also want to include dropout regularization in small-to-medium data regimes:
214+ You may also want to include stronger dropout regularization in small-to-medium data regimes:
215215
216216 ``attention_settings=dict(num_heads=4, key_dim=32, dropout=0.1)``
217217
@@ -235,7 +235,7 @@ def __init__(
235235 The number of self-attention blocks to use before pooling.
236236 num_inducing_points : int or None, optional, default: 32
237237 The number of inducing points. Should be lower than the smallest set size.
238- If ``None`` selected, a vanilla self-attenion block (SAB) will be used, otherwise
238+ If ``None`` selected, a vanilla self-attention block (SAB) will be used, otherwise
239239 ISAB blocks will be used. For ``num_attention_blocks > 1``, we currently recommend
240240 always using some number of inducing points.
241241 num_seeds : int, optional, default: 1
@@ -355,9 +355,9 @@ def __init__(
355355 num_dense_s1 = num_dense_s1 ,
356356 num_dense_s2 = num_dense_s2 ,
357357 num_dense_s3 = num_dense_s3 ,
358- dense_s1_args = defaults .DEFAULT_SETTING_DENSE_INVARIANT if dense_s1_args is None else dense_s1_args ,
359- dense_s2_args = defaults .DEFAULT_SETTING_DENSE_INVARIANT if dense_s2_args is None else dense_s2_args ,
360- dense_s3_args = defaults .DEFAULT_SETTING_DENSE_INVARIANT if dense_s3_args is None else dense_s3_args ,
358+ dense_s1_args = defaults .DEFAULT_SETTING_DENSE_DEEP_SET if dense_s1_args is None else dense_s1_args ,
359+ dense_s2_args = defaults .DEFAULT_SETTING_DENSE_DEEP_SET if dense_s2_args is None else dense_s2_args ,
360+ dense_s3_args = defaults .DEFAULT_SETTING_DENSE_DEEP_SET if dense_s3_args is None else dense_s3_args ,
361361 pooling_fun = pooling_fun ,
362362 )
363363
@@ -369,7 +369,7 @@ def __init__(
369369 self .out_layer = Dense (summary_dim , activation = "linear" )
370370 self .summary_dim = summary_dim
371371
372- def call (self , x ):
372+ def call (self , x , ** kwargs ):
373373 """Performs the forward pass of a learnable deep invariant transformation consisting of
374374 a sequence of equivariant transforms followed by an invariant transform.
375375
@@ -385,10 +385,10 @@ def call(self, x):
385385 """
386386
387387 # Pass through series of augmented equivariant transforms
388- out_equiv = self .equiv_layers (x )
388+ out_equiv = self .equiv_layers (x , ** kwargs )
389389
390390 # Pass through final invariant layer
391- out = self .out_layer (self .inv (out_equiv ) )
391+ out = self .out_layer (self .inv (out_equiv , ** kwargs ), ** kwargs )
392392
393393 return out
394394
@@ -443,7 +443,7 @@ def __init__(
443443 conv_settings : dict or None, optional, default: None
444444 The arguments passed to the `MultiConv1D` internal networks. If `None`,
445445 defaults will be used from `default_settings`. If a dictionary is provided,
446- it should contain the followin keys:
446+ it should contain the following keys:
447447 - layer_args (dict) : arguments for `tf.keras.layers.Conv1D` without kernel_size
448448 - min_kernel_size (int) : the minimum kernel size (>= 1)
449449 - max_kernel_size (int) : the maximum kernel size
@@ -508,8 +508,8 @@ class SplitNetwork(tf.keras.Model):
508508 of data to provide an individual network for each split of the data.
509509 """
510510
511- def __init__ (self , num_splits , split_data_configurator , network_type = InvariantNetwork , network_kwargs = {}, ** kwargs ):
512- """Creates a composite network of `num_splits` sub-networks of type `network_type`, each with configuration
511+ def __init__ (self , num_splits , split_data_configurator , network_type = DeepSet , network_kwargs = {}, ** kwargs ):
512+ """Creates a composite network of `num_splits` subnetworks of type `network_type`, each with configuration
513513 specified by `network_kwargs`.
514514
515515 Parameters
@@ -535,7 +535,7 @@ def __init__(self, num_splits, split_data_configurator, network_type=InvariantNe
535535 indicating which rows belong to the split `i`.
536536 network_type : callable, optional, default: `DeepSet`
537537 Type of neural network to use.
538- meta : dict, optional, default: {}
538+ network_kwargs : dict, optional, default: {}
539539 A dictionary containing the configuration for the networks.
540540 **kwargs
541541 Optional keyword arguments to be passed to the `tf.keras.Model` superclass.
@@ -547,7 +547,7 @@ def __init__(self, num_splits, split_data_configurator, network_type=InvariantNe
547547 self .split_data_configurator = split_data_configurator
548548 self .networks = [network_type (** network_kwargs ) for _ in range (num_splits )]
549549
550- def call (self , x ):
550+ def call (self , x , ** kwargs ):
551551 """Performs a forward pass through the subnetworks and concatenates their output.
552552
553553 Parameters
@@ -561,7 +561,7 @@ def call(self, x):
561561 Output of shape (batch_size, out_dim)
562562 """
563563
564- out = [self .networks [i ](self .split_data_configurator (i , x )) for i in range (self .num_splits )]
564+ out = [self .networks [i ](self .split_data_configurator (i , x ), ** kwargs ) for i in range (self .num_splits )]
565565 out = tf .concat (out , axis = - 1 )
566566 return out
567567
@@ -602,7 +602,7 @@ def call(self, x, return_all=False, **kwargs):
602602
603603 Parameters
604604 ----------
605- data : tf.Tensor of shape (batch_size, ..., data_dim)
605+ x : tf.Tensor of shape (batch_size, ..., data_dim)
606606 Example, hierarchical data sets with two levels:
607607 (batch_size, D, L, x_dim) -> reduces to (batch_size, out_dim).
608608 return_all : boolean, optional, default: False
0 commit comments