import jax
import numpy as np
from ngclearn.utils.analysis.probe import Probe
from ngclearn.utils.model_utils import drop_out, softmax, gelu, layer_normalize
from ngclearn.utils.optim import adam
from jax import jit, random, numpy as jnp, lax, nn
from functools import partial as bind

def masked_fill(x: jax.Array, mask: jax.Array, value=0) -> jax.Array:
    """
    Return a copy of `x` in which every masked position is overwritten with
    `value`, while all non-masked positions keep their original entries.

    Args:
        x (jax.Array): input array to be (partially) filled

        mask (jax.Array): boolean mask, broadcastable to the shape of `x`;
            True marks the positions to overwrite

        value (int, optional): fill value written at masked positions. Defaults to 0.

    Returns:
        jax.Array: array with the same shape as `x`, with masked positions set to `value`
    """
    return jnp.where(mask, jnp.broadcast_to(value, x.shape), x)

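## Illustrative note (not part of the original module): masked_fill routes through
## jnp.where, so a True entry in `mask` selects `value` and a False entry keeps the
## original element, e.g.:
##   x    = jnp.array([[1., 2., 3.]])
##   mask = jnp.array([[False, True, False]])
##   masked_fill(x, mask, value=-jnp.inf)  ## --> [[1., -inf, 3.]]
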
@bind(jax.jit, static_argnums=[5, 6])
def cross_attention(dkey, params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array, n_heads: int=8, dropout_rate: float=0.0) -> jax.Array:
    """
    Runs the cross-attention function given a tuple of parameters and two sequences (x1 and x2).
    The function takes in a query sequence x1 and a key-value sequence x2, and returns an output of the same shape as x1.
    T is the length of the query sequence and S is the length of the key-value sequence.
    Dq is the dimension of the query sequence and Dkv is the dimension of the key-value sequence.
    H is the number of attention heads.

    Args:
        dkey: JAX key to trigger any internal noise (drop-out)

        params (tuple): tuple of parameters (Wq, bq, Wk, bk, Wv, bv, Wout, bout)

        x1 (jax.Array): query sequence. Shape: (B, T, Dq)

        x2 (jax.Array): key-value sequence. Shape: (B, S, Dkv)

        mask (jax.Array): mask tensor. Shape: (B, T, S)

        n_heads (int, optional): number of attention heads. Defaults to 8.

        dropout_rate (float, optional): dropout rate. Defaults to 0.0.

    Returns:
        jax.Array: output of cross-attention. Shape: (B, T, Dq)
    """
    B, T, Dq = x1.shape # the original shapes
    _, S, Dkv = x2.shape
    # queries come from x1 and attend over the keys/values built from x2
    Wq, bq, Wk, bk, Wv, bv, Wout, bout = params
    # projections
    q = x1 @ Wq + bq # linear transformation (B, T, D)
    k = x2 @ Wk + bk # linear transformation (B, S, D)
    v = x2 @ Wv + bv # linear transformation (B, S, D)
    hidden = q.shape[-1]
    _hidden = hidden // n_heads
    q = q.reshape((B, T, n_heads, _hidden)).transpose([0, 2, 1, 3]) # (B, H, T, E)
    k = k.reshape((B, S, n_heads, _hidden)).transpose([0, 2, 1, 3]) # (B, H, S, E)
    v = v.reshape((B, S, n_heads, _hidden)).transpose([0, 2, 1, 3]) # (B, H, S, E)
    score = jnp.einsum("BHTE,BHSE->BHTS", q, k) / jnp.sqrt(_hidden) # Q @ K^T / sqrt(E); E = D // n_heads
    if mask is not None:
        Tq, Tk = q.shape[2], k.shape[2]
        assert mask.shape == (B, Tq, Tk), (mask.shape, (B, Tq, Tk))
        _mask = mask.reshape((B, 1, Tq, Tk)) # 'b tq tk -> b 1 tq tk'
        score = masked_fill(score, _mask, value=-jnp.inf) # mask out all positions that must not be attended to
    score = jax.nn.softmax(score, axis=-1) # (B, H, T, S)
    score = score.astype(q.dtype) # (B, H, T, S)
    if dropout_rate > 0.:
        score, _ = drop_out(dkey, score, rate=dropout_rate) ## NOTE: dropout is normally applied here
    attention = jnp.einsum("BHTS,BHSE->BHTE", score, v) # (B, H, T, E)
    attention = attention.transpose([0, 2, 1, 3]).reshape((B, T, -1)) # (B, T, H, E) => (B, T, D)
    return attention @ Wout + bout # (B, T, Dq)

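## Usage sketch (illustrative only; the shapes and small dimensions below are
## assumptions, not requirements). With Dq=32, Dkv=64, attn_dim=16 and 4 heads,
## the call mirrors how run_attention_probe invokes cross_attention and yields an
## output of shape (B, T, Dq):
##   ks = random.split(random.PRNGKey(0), 7)
##   B, T, S, Dq, Dkv, A = 2, 1, 10, 32, 64, 16
##   params = (random.normal(ks[0], (Dq, A)) * 0.02,  jnp.zeros((1, A)),   ## Wq, bq
##             random.normal(ks[1], (Dkv, A)) * 0.02, jnp.zeros((1, A)),   ## Wk, bk
##             random.normal(ks[2], (Dkv, A)) * 0.02, jnp.zeros((1, A)),   ## Wv, bv
##             random.normal(ks[3], (A, Dq)) * 0.02,  jnp.zeros((1, Dq)))  ## Wout, bout
##   x1 = random.normal(ks[4], (B, T, Dq))
##   x2 = random.normal(ks[5], (B, S, Dkv))
##   mask = jnp.zeros((B, T, S), dtype=bool)  ## all-False mask => attend everywhere
##   out = cross_attention(ks[6], params, x1, x2, mask, 4, 0.0)  ## out.shape == (2, 1, 32)
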
@bind(jax.jit, static_argnums=[4, 5, 6, 7, 8])
def run_attention_probe(
        dkey, params, encodings, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=False,
        use_softmax=True
):
    """
    Runs the full nonlinear attentive probe on input encodings (typically embedding vectors produced by some other model).

    Args:
        dkey: JAX key for any internal noise to be applied

        params: parameter tuple/list of the probe

        encodings: input encoding vectors/data

        mask: optional mask to be applied to the internal cross-attention

        n_heads: number of attention heads

        dropout: if >0, triggers drop-out applied internally to cross-attention

        use_LN: use layer normalization?

        use_LN_input: use layer normalization on the input encodings?

        use_softmax: should softmax be applied to the output of the attention probe? (useful for classification)

    Returns:
        output scores/probabilities, cross-attention (hidden) features
    """
    # two separate dkeys, one for the dropout inside each attention block
    dkey1, dkey2 = random.split(dkey, 2)
    # encodings (e.g., encoded image features): (B, hw, dim)
    learnable_query, Wq, bq, Wk, bk, Wv, bv, Wout, bout,\
    Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts, Wlnattn_mu,\
    Wlnattn_scale, Whid1, bhid1, Wln_mu1, Wln_scale1, Whid2,\
    bhid2, Wln_mu2, Wln_scale2, Whid3, bhid3, Wln_mu3, Wln_scale3,\
    Wy, by, ln_in_mu, ln_in_scale, ln_in_mu2, ln_in_scale2 = params
    cross_attn_params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
    if use_LN_input:
        learnable_query = layer_normalize(learnable_query, ln_in_mu, ln_in_scale)
        encodings = layer_normalize(encodings, ln_in_mu2, ln_in_scale2)
    # Cross-attention: the learnable query attends over the encodings
    features = cross_attention(dkey1, cross_attn_params, learnable_query, encodings, mask, n_heads, dropout)
    # Self-attention: a single self-attention block with a residual/skip connection
    self_attn_params = (Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts)
    skip = features
    if use_LN:
        features = layer_normalize(features, Wlnattn_mu, Wlnattn_scale)
    features = cross_attention(dkey2, self_attn_params, features, features, None, n_heads, dropout)
    features = features + skip
    features = features[:, 0] # (B, 1, dim) => (B, dim)
    # MLP head with a residual/skip connection
    skip = features
    if use_LN: ## normalize hidden layer output of probe predictor
        features = layer_normalize(features, Wln_mu1, Wln_scale1)
    features = jnp.matmul(features, Whid1) + bhid1
    features = gelu(features)
    if use_LN: ## normalize hidden layer output of probe predictor
        features = layer_normalize(features, Wln_mu2, Wln_scale2)
    features = jnp.matmul(features, Whid2) + bhid2
    features = gelu(features)
    if use_LN: ## normalize hidden layer output of probe predictor
        features = layer_normalize(features, Wln_mu3, Wln_scale3)
    features = jnp.matmul(features, Whid3) + bhid3
    features = features + skip
    outs = jnp.matmul(features, Wy) + by
    if use_softmax: ## apply softmax output nonlinearity
        # NOTE: verify the numerical stability of this softmax; a potential
        # division by zero could yield NaN gradients
        outs = jax.nn.softmax(outs)
    return outs, features

@bind(jax.jit, static_argnums=[5, 6, 7, 8, 9])
def eval_attention_probe(dkey, params, encodings, labels, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=False, use_softmax=True):
    """
    Runs and evaluates the nonlinear attentive probe given a paired set of encoding vectors and externally assigned
    labels/regression targets.

    Args:
        dkey: JAX key to trigger any internal noise (as in drop-out)

        params: parameter tuple/list of the probe

        encodings: input encoding vectors/data

        labels: output target values (e.g., labels, regression target vectors)

        mask: optional mask to be applied to the internal cross-attention

        n_heads: number of attention heads

        dropout: if >0, triggers drop-out applied internally to cross-attention

        use_LN: use layer normalization?

        use_LN_input: use layer normalization on the input encodings?

        use_softmax: should softmax be applied to the output of the attention probe? (useful for classification)

    Returns:
        current loss value, output scores/probabilities
    """
    # encodings: (B, hw, dim)
    outs, _ = run_attention_probe(dkey, params, encodings, mask, n_heads, dropout, use_LN, use_LN_input, use_softmax)
    if use_softmax: ## Multinoulli log likelihood for 1-of-K predictions
        L = -jnp.mean(jnp.sum(jnp.log(outs.clip(min=1e-5)) * labels, axis=1, keepdims=True))
    else: ## MSE for real-valued outputs
        L = jnp.mean(jnp.sum(jnp.square(outs - labels), axis=1, keepdims=True))
    return L, outs #, features

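## Worked example (illustrative only): with softmax outputs outs = [0.7, 0.2, 0.1]
## and one-hot label [1, 0, 0], the inner sum reduces to -log(0.7) ~= 0.357, and the
## outer mean averages this per-sample value over the batch. With use_softmax=False
## the loss is instead the per-sample sum of squared errors, averaged over the batch.
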
class AttentiveProbe(Probe):
    """
    This implements a nonlinear attentive probe, which is useful for evaluating the quality of
    encodings/embeddings in light of some supervisory downstream data (e.g., label one-hot
    encodings or real-valued vector regression targets).

    Args:
        dkey: init seed key

        source_seq_length: length of the input sequence (e.g., height x width of the image feature)

        input_dim: input dimensionality of probe

        out_dim: output dimensionality of probe

        num_heads: number of cross-attention heads

        attn_dim: total dimensionality of the cross-attention projections (split evenly across heads)

        target_seq_length: length of the target (query) sequence; set to one to pool the source sequence into a single vector

        learnable_query_dim: target sequence dim (output dimension of cross-attention portion of probe)

        batch_size: size of batches to process per internal call to update (or process)

        hid_dim: dimensionality of hidden layer(s) of MLP portion of probe (currently unused; hidden sizes are derived from learnable_query_dim)

        use_LN: should layer normalization be used within MLP portions of probe or not?

        use_LN_input: should layer normalization be applied to the learnable query and input encodings?

        use_softmax: should a softmax be applied to output of probe or not?

        dropout: dropout rate applied within the attention blocks during updates

        eta: learning rate for the Adam update rule

        eta_decay: multiplicative learning rate decay applied after each update

        min_eta: floor (minimum value) for the decayed learning rate
    """
    def __init__(
            self, dkey, source_seq_length, input_dim, out_dim, num_heads=8, attn_dim=64,
            target_seq_length=1, learnable_query_dim=32, batch_size=1, hid_dim=32,
            use_LN=True, use_LN_input=False, use_softmax=True, dropout=0.5, eta=0.0002,
            eta_decay=0.0, min_eta=1e-5, **kwargs
    ):
        super().__init__(dkey, batch_size, **kwargs)
        assert attn_dim % num_heads == 0, f"`attn_dim` must be divisible by `num_heads`. Got {attn_dim} and {num_heads}."
        assert learnable_query_dim % num_heads == 0, f"`learnable_query_dim` must be divisible by `num_heads`. Got {learnable_query_dim} and {num_heads}."
        self.dkey, *subkeys = random.split(self.dkey, 26)
        self.num_heads = num_heads
        self.source_seq_length = source_seq_length
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.use_softmax = use_softmax
        self.use_LN = use_LN
        self.use_LN_input = use_LN_input
        self.dropout = dropout

        sigma = 0.02
        ## cross-attention parameters
        Wq = random.normal(subkeys[0], (learnable_query_dim, attn_dim)) * sigma
        bq = random.normal(subkeys[1], (1, attn_dim)) * sigma
        Wk = random.normal(subkeys[2], (input_dim, attn_dim)) * sigma
        bk = random.normal(subkeys[3], (1, attn_dim)) * sigma
        Wv = random.normal(subkeys[4], (input_dim, attn_dim)) * sigma
        bv = random.normal(subkeys[5], (1, attn_dim)) * sigma
        Wout = random.normal(subkeys[6], (attn_dim, learnable_query_dim)) * sigma
        bout = random.normal(subkeys[7], (1, learnable_query_dim)) * sigma
        cross_attn_params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
        ## self-attention parameters
        Wqs = random.normal(subkeys[8], (learnable_query_dim, learnable_query_dim)) * sigma
        bqs = random.normal(subkeys[9], (1, learnable_query_dim)) * sigma
        Wks = random.normal(subkeys[10], (learnable_query_dim, learnable_query_dim)) * sigma
        bks = random.normal(subkeys[11], (1, learnable_query_dim)) * sigma
        Wvs = random.normal(subkeys[12], (learnable_query_dim, learnable_query_dim)) * sigma
        bvs = random.normal(subkeys[13], (1, learnable_query_dim)) * sigma
        Wouts = random.normal(subkeys[14], (learnable_query_dim, learnable_query_dim)) * sigma
        bouts = random.normal(subkeys[15], (1, learnable_query_dim)) * sigma
        Wlnattn_mu = jnp.zeros((1, learnable_query_dim)) ## LN parameter (applied to output of attention)
        Wlnattn_scale = jnp.ones((1, learnable_query_dim)) ## LN parameter (applied to output of attention)
        self_attn_params = (Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts, Wlnattn_mu, Wlnattn_scale)
        learnable_query = jnp.zeros((batch_size, target_seq_length, learnable_query_dim)) # (B, T, D); T = target_seq_length, matching the mask below
        self.mask = np.zeros((self.batch_size, target_seq_length, source_seq_length)).astype(bool) ## mask tensor
        self.dev_mask = np.zeros((self.dev_batch_size, target_seq_length, source_seq_length)).astype(bool)
        ## MLP parameters
        Whid1 = random.normal(subkeys[16], (learnable_query_dim, learnable_query_dim)) * sigma
        bhid1 = random.normal(subkeys[17], (1, learnable_query_dim)) * sigma
        Wln_mu1 = jnp.zeros((1, learnable_query_dim)) ## LN parameter
        Wln_scale1 = jnp.ones((1, learnable_query_dim)) ## LN parameter
        Whid2 = random.normal(subkeys[18], (learnable_query_dim, learnable_query_dim * 4)) * sigma
        bhid2 = random.normal(subkeys[19], (1, learnable_query_dim * 4)) * sigma
        Wln_mu2 = jnp.zeros((1, learnable_query_dim)) ## LN parameter
        Wln_scale2 = jnp.ones((1, learnable_query_dim)) ## LN parameter
        Whid3 = random.normal(subkeys[20], (learnable_query_dim * 4, learnable_query_dim)) * sigma
        bhid3 = random.normal(subkeys[21], (1, learnable_query_dim)) * sigma
        Wln_mu3 = jnp.zeros((1, learnable_query_dim * 4)) ## LN parameter
        Wln_scale3 = jnp.ones((1, learnable_query_dim * 4)) ## LN parameter
        Wy = random.normal(subkeys[22], (learnable_query_dim, out_dim)) * sigma
        by = random.normal(subkeys[23], (1, out_dim)) * sigma
        mlp_params = (Whid1, bhid1, Wln_mu1, Wln_scale1, Whid2, bhid2, Wln_mu2, Wln_scale2, Whid3, bhid3, Wln_mu3, Wln_scale3, Wy, by)
        # finally, define LN parameters for the inputs to the cross-attention
        ln_in_mu = jnp.zeros((1, learnable_query_dim)) ## LN parameter
        ln_in_scale = jnp.ones((1, learnable_query_dim)) ## LN parameter
        ln_in_mu2 = jnp.zeros((1, input_dim)) ## LN parameter
        ln_in_scale2 = jnp.ones((1, input_dim)) ## LN parameter
        ln_in_params = (ln_in_mu, ln_in_scale, ln_in_mu2, ln_in_scale2)
        self.probe_params = (learnable_query, *cross_attn_params, *self_attn_params, *mlp_params, *ln_in_params)

        ## set up gradient calculator
        self.grad_fx = jax.value_and_grad(eval_attention_probe, argnums=1, has_aux=True)
        ## set up update rule/optimizer
        self.optim_params = adam.adam_init(self.probe_params)
        ## learning rate scheduling
        self.eta = eta
        self.eta_decay = eta_decay
        self.min_eta = min_eta

        # finally, the dkey used for any internal noise (drop-out)
        self.noise_key = subkeys[24]

    def process(self, embeddings, dkey=None):
        noise_key = self.noise_key
        if dkey is not None:
            dkey, *subkeys = random.split(dkey, 2)
            noise_key = subkeys[0]
        ## inference mode: dropout is disabled (rate 0.0) and the dev mask is used
        outs, feats = run_attention_probe(
            noise_key, self.probe_params, embeddings, self.dev_mask, self.num_heads, 0.0,
            use_LN=self.use_LN, use_LN_input=self.use_LN_input, use_softmax=self.use_softmax
        )
        return outs

    def update(self, embeddings, labels, dkey=None):
        noise_key = self.noise_key
        if dkey is not None:
            dkey, *subkeys = random.split(dkey, 2)
            noise_key = subkeys[0]
        outputs, grads = self.grad_fx(
            noise_key, self.probe_params, embeddings, labels, self.mask, self.num_heads, dropout=self.dropout,
            use_LN=self.use_LN, use_LN_input=self.use_LN_input, use_softmax=self.use_softmax
        )
        loss, predictions = outputs
        ## adjust parameters of probe
        self.optim_params, self.probe_params = adam.adam_step(
            self.optim_params, self.probe_params, grads, eta=self.eta
        )
        ## decay the learning rate (floored at min_eta)
        self.eta = max(self.min_eta, self.eta - self.eta_decay * self.eta)
        return loss, predictions
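
## Minimal usage sketch (illustrative only; the shapes and hyper-parameters below
## are assumptions, not requirements of the class). It fits the probe on a random
## batch of encodings paired with random one-hot labels for a few Adam steps.
if __name__ == "__main__":
    demo_key = random.PRNGKey(1234)
    demo_key, *ks = random.split(demo_key, 4)
    B, S, D, C = 8, 16, 64, 10 ## batch size, source length, encoding dim, #classes
    probe = AttentiveProbe(
        ks[0], source_seq_length=S, input_dim=D, out_dim=C, num_heads=4,
        attn_dim=64, learnable_query_dim=32, batch_size=B, use_softmax=True
    )
    embeddings = random.normal(ks[1], (B, S, D)) ## stand-in for real model encodings
    labels = jax.nn.one_hot(random.randint(ks[2], (B,), 0, C), C) ## (B, C) one-hot targets
    for step in range(5): ## a few optimization steps
        loss, preds = probe.update(embeddings, labels)
        print(f"step {step}: loss = {float(loss):.4f}")
    ## scoring a batch (assumes the probe's dev batch size matches B):
    # scores = probe.process(embeddings) ## class probabilities, shape (B, C)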