@@ -191,8 +191,10 @@ def call(self, x, **kwargs):
191191 Output of shape (batch_size, set_size, input_dim)
192192 """
193193
194- batch_size = x .shape [0 ]
195- h = self .mab0 (tf .stack ([self .I ] * batch_size ), x , ** kwargs )
194+ batch_size = tf .shape (x )[0 ]
195+ I_expanded = self .I [None , ...]
196+ I_tiled = tf .tile (I_expanded , [batch_size , 1 , 1 ])
197+ h = self .mab0 (I_tiled , x , ** kwargs )
196198 return self .mab1 (x , h , ** kwargs )
197199
198200
@@ -240,7 +242,7 @@ def __init__(
240242 summary_dim , attention_settings , num_dense_fc , dense_settings , use_layer_norm , ** kwargs
241243 )
242244 init = tf .keras .initializers .GlorotUniform ()
243- self .seed_vec = init (shape = (num_seeds , summary_dim ))
245+ self .seed_vec = tf . Variable ( init (shape = (num_seeds , summary_dim )), name = "seed_vec" , trainable = True )
244246 self .fc = Sequential ([Dense (** dense_settings ) for _ in range (num_dense_fc )])
245247 self .fc .add (Dense (summary_dim ))
246248
@@ -258,7 +260,9 @@ def call(self, x, **kwargs):
258260 Output of shape (batch_size, num_seeds * summary_dim)
259261 """
260262
261- batch_size = x .shape [0 ]
262263 out = self .fc (x )
263- out = self .mab (tf .stack ([self .seed_vec ] * batch_size ), out , ** kwargs )
264- return tf .reshape (out , (out .shape [0 ], - 1 ))
264+ batch_size = tf .shape (x )[0 ]
265+ seed_expanded = self .seed_vec [None , ...]
266+ seed_tiled = tf .tile (seed_expanded , [batch_size , 1 , 1 ])
267+ out = self .mab (seed_tiled , out , ** kwargs )
268+ return tf .reshape (out , (tf .shape (out )[0 ], - 1 ))
0 commit comments