@@ -11,11 +11,12 @@ class Probe():
1111
1212 """
1313 def __init__ (
14- self , dkey , batch_size = 4 , ** kwargs
14+ self , dkey , batch_size = 1 , dev_batch_size = 1 , ** kwargs
1515 ):
1616 #dkey, *subkeys = random.split(dkey, 3)
1717 self .dkey = dkey
1818 self .batch_size = batch_size
19+ self .dev_batch_size = dev_batch_size
1920
2021 def process (self , embeddings ):
2122 predictions = None
@@ -25,24 +26,29 @@ def update(self, embeddings, labels):
2526 L = predictions = None
2627 return L , predictions
2728
28- def predict (self , data ):
29+ def predict (self , data , batch_size = None ):
2930 """
3031 Runs this probe's inference scheme over a pool of data.
3132
3233 Args:
3334 data: a dataset or design tensor/matrix containing encoding vector sequences; shape (N, T, embed_dim) or (N, embed_dim)
3435
36+ batch_size: optional batch-size argument (Default: None, will use training batch size)
37+
3538 Returns:
3639 the output scores/predictions made by this probe
3740 """
41+ _batch_size = batch_size
42+ if _batch_size is None :
43+ _batch_size = self .batch_size
3844 _data = data
3945 if len (_data .shape ) < 3 :
4046 _data = jnp .expand_dims (_data , axis = 1 )
4147
4248 n_samples , seq_len , dim = _data .shape
43- n_batches = int (n_samples / self . batch_size )
49+ n_batches = int (n_samples / _batch_size )
4450 s_ptr = 0
45- e_ptr = self . batch_size
51+ e_ptr = _batch_size
4652 Y_mu = []
4753 for b in range (n_batches ):
4854 x_mb = _data [s_ptr :e_ptr , :, :] ## slice out 3D batch tensor
@@ -132,7 +138,7 @@ def fit(self, dataset, dev_dataset=None, n_iter=50, patience=20):
132138
133139 impatience += 1
134140 if dev_data is not None :
135- Ymu = self .predict (dev_data )
141+ Ymu = self .predict (dev_data , batch_size = self . dev_batch_size )
136142 acc = jnp .sum (jnp .equal (jnp .argmax (Ymu , axis = 1 ), jnp .argmax (dev_labels , axis = 1 ))) / (dev_labels .shape [0 ] * 1. )
137143 if acc > best_acc :
138144 best_acc = acc
0 commit comments