import jax
import numpy as np
from ngclearn.utils.analysis.probe import Probe
from ngclearn.utils.model_utils import drop_out, softmax, gelu, layer_normalize
from ngclearn.utils.optim import adam
from jax import jit, random, numpy as jnp, lax, nn
from functools import partial as bind

def masked_fill(x: jax.Array, mask: jax.Array, value=0) -> jax.Array:
    """
    Return a copy of `x` in which every position where `mask` is True is
    replaced by `value`, while non-masked positions keep their original values.

    Args:
        x (jax.Array): input array to (partially) fill.
        mask (jax.Array): boolean array broadcastable to `x.shape`; True marks the entries to overwrite.
        value (int, optional): fill value written at masked positions. Defaults to 0.

    Returns:
        jax.Array: array of the same shape as `x`, filled at masked positions.
    """
    return jnp.where(mask, jnp.broadcast_to(value, x.shape), x)

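## Quick illustration of `masked_fill` (comments only, not part of the original module):
##   masked_fill(jnp.array([1., 2., 3.]), jnp.array([False, True, False]), value=-jnp.inf)
##   # result: [ 1., -inf, 3. ]
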
@bind(jax.jit, static_argnums=[4, 5])
def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array, n_heads: int=8, dropout_rate: float=0.0):
    B, T, Dq = x1.shape # query-side shape
    _, S, Dkv = x2.shape # key/value-side shape
    # x1 supplies the queries; x2 supplies the keys and values (x1 attends over x2)
    Wq, bq, Wk, bk, Wv, bv, Wout, bout = params
    # linear projections
    q = x1 @ Wq + bq # (B, T, D)
    k = x2 @ Wk + bk # (B, S, D)
    v = x2 @ Wv + bv # (B, S, D)
    hidden = q.shape[-1]
    _hidden = hidden // n_heads # per-head dimensionality
    q = q.reshape((B, T, n_heads, _hidden)).transpose([0, 2, 1, 3]) # (B, H, T, E)
    k = k.reshape((B, S, n_heads, _hidden)).transpose([0, 2, 1, 3]) # (B, H, S, E)
    v = v.reshape((B, S, n_heads, _hidden)).transpose([0, 2, 1, 3]) # (B, H, S, E)
    score = jnp.einsum("BHTE,BHSE->BHTS", q, k) / jnp.sqrt(_hidden) # (Q K^T) / sqrt(d); d = D // n_heads
    if mask is not None:
        Tq, Tk = q.shape[2], k.shape[2]
        assert mask.shape == (B, Tq, Tk), (mask.shape, (B, Tq, Tk))
        _mask = mask.reshape((B, 1, Tq, Tk)) # 'b tq tk -> b 1 tq tk'
        score = masked_fill(score, _mask, value=-jnp.inf) # mask out positions that must not be attended to
    score = jax.nn.softmax(score, axis=-1) # (B, H, T, S)
    score = score.astype(q.dtype) # (B, H, T, S)
    if dropout_rate > 0.:
        score = drop_out(input=score, rate=dropout_rate) ## NOTE: dropout is normally applied to the attention weights here
    attention = jnp.einsum("BHTS,BHSE->BHTE", score, v) # (B, H, T, E)
    attention = attention.transpose([0, 2, 1, 3]).reshape((B, T, -1)) # (B, H, T, E) => (B, T, D)
    return attention @ Wout + bout # (B, T, Dq)

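## Minimal shape sketch for `cross_attention` (comments only, not part of the original
## module); all sizes below are arbitrary assumptions chosen for the example:
##   B, T, S, Dq, Dkv, H, Dh = 2, 1, 16, 31, 12, 8, 64   # Dh must be divisible by H
##   ks = random.split(random.PRNGKey(0), 7)
##   params = (random.normal(ks[0], (Dq, Dh)), jnp.zeros((1, Dh)),   # Wq, bq
##             random.normal(ks[1], (Dkv, Dh)), jnp.zeros((1, Dh)),  # Wk, bk
##             random.normal(ks[2], (Dkv, Dh)), jnp.zeros((1, Dh)),  # Wv, bv
##             random.normal(ks[3], (Dh, Dq)), jnp.zeros((1, Dq)))   # Wout, bout
##   x1 = random.normal(ks[4], (B, T, Dq))   # queries (e.g., a learnable query)
##   x2 = random.normal(ks[5], (B, S, Dkv))  # keys/values (e.g., patch embeddings)
##   out = cross_attention(params, x1, x2, None, H)  # out.shape == (B, T, Dq)
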
@bind(jax.jit, static_argnums=[3, 4, 5, 6])
def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_softmax=True):
    # encodings: (B, hw, dim)
    learnable_query, Wq, bq, Wk, bk, Wv, bv, Wout, bout, Whid, bhid, Wln_mu, Wln_scale, Wy, by = params
    attn_params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
    features = cross_attention(attn_params, learnable_query, encodings, mask, n_heads, dropout)
    features = features[:, 0] # (B, 1, dim) => (B, dim)
    hids = jnp.matmul((features + learnable_query[:, 0]), Whid) + bhid ## skip connection from the learnable query
    hids = gelu(hids)
    if use_LN: ## normalize hidden layer output of probe predictor
        hids = layer_normalize(hids, Wln_mu, Wln_scale)
    outs = jnp.matmul(hids, Wy) + by
    if use_softmax: ## apply softmax output nonlinearity
        outs = softmax(outs)
    return outs, features

@bind(jax.jit, static_argnums=[4, 5, 6, 7])
def eval_attention_probe(params, encodings, labels, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_softmax=True):
    # encodings: (B, hw, dim)
    outs, _ = run_attention_probe(params, encodings, mask, n_heads, dropout, use_LN, use_softmax)
    if use_softmax: ## multinoulli log-likelihood for 1-of-K predictions
        L = -jnp.mean(jnp.sum(jnp.log(outs) * labels, axis=1, keepdims=True))
    else: ## MSE for real-valued outputs
        L = jnp.mean(jnp.sum(jnp.square(outs - labels), axis=1, keepdims=True))
    return L, outs

class AttentiveProbe(Probe):
    """
    An attention-based probe: a learnable query cross-attends over an input sequence
    of embeddings, pooling it into a single feature vector, and an MLP then maps this
    pooled feature to the probe's prediction.

    Args:
        dkey: init seed key

        source_seq_length: length of input sequence (e.g., height x width of the image feature map)

        input_dim: input dimensionality of probe

        out_dim: output dimensionality of probe

        num_heads: number of cross-attention heads

        head_dim: total dimensionality of the query/key/value projections, split evenly across the cross-attention heads (should be divisible by num_heads)

        target_seq_length: length of the target (query) sequence; kept at one so that the source sequence is pooled down to a single position

        learnable_query_dim: dimensionality of the learnable query (output dimension of the cross-attention portion of the probe)

        batch_size: size of batches to process per internal call to update (or process)

        hid_dim: dimensionality of hidden layer(s) of MLP portion of probe

        use_LN: should layer normalization be used within MLP portion of probe or not?

        use_softmax: should a softmax be applied to output of probe or not?
    """
    def __init__(
        self, dkey, source_seq_length, input_dim, out_dim, num_heads=8, head_dim=64,
        target_seq_length=1, learnable_query_dim=31, batch_size=1, hid_dim=32, use_LN=True, use_softmax=True, **kwargs
    ):
        super().__init__(dkey, batch_size, **kwargs)
        self.dkey, *subkeys = random.split(self.dkey, 13) ## 12 subkeys, one per randomly initialized parameter tensor
        self.num_heads = num_heads
        self.source_seq_length = source_seq_length
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.use_softmax = use_softmax
        self.use_LN = use_LN

        sigma = 0.05
        ## cross-attention parameters
        Wq = random.normal(subkeys[0], (learnable_query_dim, head_dim)) * sigma
        bq = random.normal(subkeys[1], (1, head_dim)) * sigma
        Wk = random.normal(subkeys[2], (input_dim, head_dim)) * sigma
        bk = random.normal(subkeys[3], (1, head_dim)) * sigma
        Wv = random.normal(subkeys[4], (input_dim, head_dim)) * sigma
        bv = random.normal(subkeys[5], (1, head_dim)) * sigma
        Wout = random.normal(subkeys[6], (head_dim, learnable_query_dim)) * sigma
        bout = random.normal(subkeys[7], (1, learnable_query_dim)) * sigma
        learnable_query = jnp.zeros((batch_size, 1, learnable_query_dim)) # (B, T, D)
        self.mask = np.zeros((batch_size, target_seq_length, source_seq_length)).astype(bool) ## mask tensor (all False => attend everywhere)
        ## MLP parameters
        Whid = random.normal(subkeys[8], (learnable_query_dim, hid_dim)) * sigma
        bhid = random.normal(subkeys[9], (1, hid_dim)) * sigma
        Wln_mu = jnp.zeros((1, hid_dim))
        Wln_scale = jnp.ones((1, hid_dim))
        Wy = random.normal(subkeys[10], (hid_dim, out_dim)) * sigma
        by = random.normal(subkeys[11], (1, out_dim)) * sigma
        self.probe_params = (learnable_query, Wq, bq, Wk, bk, Wv, bv, Wout, bout, Whid, bhid, Wln_mu, Wln_scale, Wy, by)

        ## set up gradient calculator
        self.grad_fx = jax.value_and_grad(eval_attention_probe, argnums=0, has_aux=True)
        ## set up update rule/optimizer
        self.optim_params = adam.adam_init(self.probe_params)
        self.eta = 0.001

    def process(self, embedding_sequence):
        outs, feats = run_attention_probe(
            self.probe_params, embedding_sequence, self.mask, self.num_heads, 0.0, use_LN=self.use_LN,
            use_softmax=self.use_softmax
        )
        return outs

    def update(self, embedding_sequence, labels):
        ## compute partial derivatives / adjustments to probe parameters
        outputs, grads = self.grad_fx(
            self.probe_params, embedding_sequence, labels, self.mask, self.num_heads, dropout=0., use_LN=self.use_LN,
            use_softmax=self.use_softmax
        )
        loss, predictions = outputs
        ## adjust parameters of probe
        self.optim_params, self.probe_params = adam.adam_step(
            self.optim_params, self.probe_params, grads, eta=self.eta
        )
        return loss, predictions

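## Illustrative usage sketch (not part of the original module). It assumes the `Probe`
## base class needs no keyword arguments beyond `dkey` and `batch_size`; all sizes
## below are arbitrary example values.
if __name__ == "__main__":
    B, HW, D, K = 4, 16, 12, 10 ## batch size, source length (h*w), feature dim, number of classes
    probe = AttentiveProbe(
        random.PRNGKey(1234), source_seq_length=HW, input_dim=D, out_dim=K, num_heads=4,
        head_dim=32, learnable_query_dim=16, batch_size=B, hid_dim=32, use_LN=True, use_softmax=True
    )
    enc = random.normal(random.PRNGKey(0), (B, HW, D)) ## dummy embedding sequence
    lab = nn.one_hot(jnp.arange(B) % K, K) ## dummy one-hot labels, shape (B, K)
    loss, preds = probe.update(enc, lab) ## one Adam step on the probe parameters
    probs = probe.process(enc) ## (B, K) class probabilities
    print(float(loss), probs.shape)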