diff --git a/src/gluonnlp/attention_cell.py b/src/gluonnlp/attention_cell.py
index 4773f81d46..9a9ef1576f 100644
--- a/src/gluonnlp/attention_cell.py
+++ b/src/gluonnlp/attention_cell.py
@@ -670,6 +670,163 @@ def __repr__(self):
                         dtype=self._dtype)
 
 
+def multi_head_sliding_window_dot_attn(F, query, key, value, dilation, valid_length,
+                                       window_size: int, symmetric: bool = True,
+                                       dropout: float = 0.0, scaled: bool = True,
+                                       normalized: bool = False, eps: float = 1E-6,
+                                       query_head_units: Optional[int] = None,
+                                       layout: str = 'NTK',
+                                       dtype=np.float32):
+    """Multi-head sliding window attention between the query, key and value,
+    as described in *Longformer: The Long-Document Transformer*
+    (https://arxiv.org/pdf/2004.05150.pdf).
+
+    Given a fixed window size *2w*, each token attends to the *w* tokens on its left
+    when causal attention is used (*symmetric* set to False);
+    otherwise it attends to *w* tokens on each side.
+
+    Parameters
+    ----------
+    F
+    query
+        Query. The shape is (batch_size, seq_length, num_heads, num_head_units)
+    key
+        Key. The shape is (batch_size, seq_length, num_heads, num_head_units)
+    value
+        Value. The shape is (batch_size, seq_length, num_heads, num_head_units)
+    dilation
+        Dilation. The shape is (num_heads,)
+    valid_length
+        Valid length. The shape is (batch_size,)
+    window_size
+        The one-sided window length w.
+    symmetric
+        If False, each token can only attend to itself and the previous tokens.
+    dropout
+        Dropout rate
+    scaled
+        Whether to divide the attention weights by the sqrt of the query dimension.
+    normalized
+        If turned on, the cosine distance is used, i.e::
+
+            score = <h_q / ||h_q||, h_k / ||h_k||>
+
+    eps
+        The epsilon value used in L2 normalization
+    query_head_units
+        The units of each query head. If it is None, we will estimate it via the
+        shape_array of the query.
+    layout
+        This stands for the layout of the attention cell. The shape of the input/output will
+        depend on the layout. Currently, we only support 'NTK', in which
+        'N' means the batch_size, 'T' means the length dimension, and 'K' means the head.
+    dtype
+        The data type used for the attention weights.
+
+    Returns
+    -------
+    context_vec
+        - Shape (batch_size, seq_length, num_heads * num_head_units)
+    additional_info
+        scores:
+            Shape (batch_size, seq_length, num_heads, w + w + 1) if *symmetric* is True,
+            Shape (batch_size, seq_length, num_heads, w + 1) otherwise
+        attn_weight:
+            Shape (batch_size, seq_length, num_heads, w + w + 1) if *symmetric* is True,
+            Shape (batch_size, seq_length, num_heads, w + 1) otherwise
+    """
+    if layout != 'NTK':
+        raise NotImplementedError('We only support layout = "NTK".')
+    if normalized:
+        query = l2_normalize(F, query, axis=-1, eps=eps)
+        key = l2_normalize(F, key, axis=-1, eps=eps)
+    # 1. Calculate the attention weights
+    #    scores' shape is (batch_size, seq_length, num_heads, w + w + 1) if symmetric
+    #    else (batch_size, seq_length, num_heads, w + 1)
+    scores = F.npx.sldwin_atten_score(query, key, dilation,
+                                      w=window_size, symmetric=symmetric)
+    if scaled:
+        if query_head_units is None:
+            query_shape = F.npx.shape_array(query)
+            scores = scores / F.np.sqrt(query_shape[-1])
+        else:
+            scores = scores / math.sqrt(query_head_units)
+    # mask's shape is the same as scores' shape
+    mask = F.npx.sldwin_atten_mask_like(scores, dilation, valid_length.astype(np.int32),
+                                        w=window_size, symmetric=symmetric)
+    attn_weights = masked_softmax(F, scores, mask, dtype=dtype)
+    attn_weights = F.npx.dropout(attn_weights, p=dropout)
+    # 2. Calculate the context vector
+    #    (batch_size, seq_length, num_heads, num_head_units)
+    context_vec = F.npx.sldwin_atten_context(attn_weights, value, dilation,
+                                             w=window_size, symmetric=symmetric)
+    # Merge the head and unit dimensions --> (batch_size, seq_length, num_units)
+    context_vec = F.npx.reshape(context_vec, (-2, -2, -1))
+
+    return context_vec, [scores, attn_weights]
+
+
+class MultiHeadSlidingWindowAttentionCell(HybridBlock):
+    """The multi-head sliding window attention cell used in Longformer-style models."""
+    def __init__(self, window_size, symmetric=True, query_units=None, num_heads=None,
+                 attention_dropout=0.0, scaled: bool = True, normalized: bool = False,
+                 eps: float = 1E-6, dtype='float32', layout='NTK'):
+        super().__init__()
+        self._query_units = query_units
+        self._window_size = window_size
+        self._symmetric = symmetric
+        self._num_heads = num_heads
+        self._attention_dropout = attention_dropout
+        self._scaled = scaled
+        self._normalized = normalized
+        self._eps = eps
+        self._dtype = dtype
+        self._layout = layout
+        if self._query_units is not None:
+            assert self._num_heads is not None
+            assert self._query_units % self._num_heads == 0,\
+                'The units must be divisible by the number of heads.'
+            self._query_head_units = self._query_units // self._num_heads
+        else:
+            self._query_head_units = None
+
+    @property
+    def layout(self):
+        return self._layout
+
+    def hybrid_forward(self, F, query, key, value, dilation, valid_length):
+        return multi_head_sliding_window_dot_attn(F, query=query, key=key, value=value,
+                                                  dilation=dilation, valid_length=valid_length,
+                                                  window_size=self._window_size,
+                                                  symmetric=self._symmetric,
+                                                  dropout=self._attention_dropout,
+                                                  scaled=self._scaled,
+                                                  normalized=self._normalized, eps=self._eps,
+                                                  query_head_units=self._query_head_units,
+                                                  layout=self._layout, dtype=self._dtype)
+
+    def __repr__(self):
+        s = '{name}(\n' \
+            '   window_size={window_size},\n' \
+            '   symmetric={symmetric},\n' \
+            '   query_units={query_units},\n' \
+            '   num_heads={num_heads},\n' \
+            '   attention_dropout={attention_dropout},\n' \
+            '   scaled={scaled},\n' \
+            '   normalized={normalized},\n' \
+            '   layout="{layout}",\n' \
+            '   dtype={dtype}\n' \
+            ')'
+        return s.format(name=self.__class__.__name__,
+                        window_size=self._window_size,
+                        symmetric=self._symmetric,
+                        query_units=self._query_units,
+                        num_heads=self._num_heads,
+                        attention_dropout=self._attention_dropout,
+                        scaled=self._scaled,
+                        normalized=self._normalized,
+                        layout=self._layout,
+                        dtype=self._dtype)
+
+
 class RelAttentionScoreCell(HybridBlock):
     """Get the score based on the query and relative position index.
     This is used for implementing relative attention.
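Note for reviewers: below is a minimal usage sketch of the new cell under the 'NTK' layout, assuming an MXNet build that ships the new `sldwin_atten_*` operators. The shapes and variable names are illustrative only and are not part of this diff.

```python
import mxnet as mx
import numpy as np
from gluonnlp.attention_cell import MultiHeadSlidingWindowAttentionCell

mx.npx.set_np()

batch_size, seq_length, num_heads, num_head_units = 2, 32, 4, 16
window_size = 4  # one-sided window w, so each token sees at most 2 * w + 1 positions

# Inputs follow the 'NTK' layout: (batch_size, seq_length, num_heads, num_head_units)
query = mx.np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
key = mx.np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
value = mx.np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))

# One dilation value per head (1 = no dilation) and one valid length per sample
dilation = mx.np.ones((num_heads,), dtype=np.int32)
valid_length = mx.np.full((batch_size,), seq_length, dtype=np.int32)

attn_cell = MultiHeadSlidingWindowAttentionCell(window_size=window_size, symmetric=True)
context_vec, [scores, attn_weights] = attn_cell(query, key, value, dilation, valid_length)
# context_vec has shape (batch_size, seq_length, num_heads * num_head_units)
```

The cell holds no parameters, so it can be called without `initialize()`; it returns the context vector with the head dimension merged into the last axis, together with the raw scores and attention weights.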
diff --git a/tests/test_attention_cell.py b/tests/test_attention_cell.py
index c7166f19c9..7592671c09 100644
--- a/tests/test_attention_cell.py
+++ b/tests/test_attention_cell.py
@@ -6,7 +6,8 @@
 from gluonnlp.attention_cell import\
     multi_head_dot_attn, gen_self_attn_mask, gen_mem_attn_mask,\
     MultiHeadAttentionCell,\
-    RelAttentionScoreCell
+    RelAttentionScoreCell,\
+    MultiHeadSlidingWindowAttentionCell
 from gluonnlp.utils.parameter import grad_global_norm
 
 mx.npx.set_np()
@@ -388,3 +389,77 @@ def test_multi_head_rel_attn_score(num_heads, method, bidirectional, hybridize,
     assert_allclose(rel_score.asnumpy(), original_rel_score, 1E-5, 1E-5)
     layout_query_grad_norm = np.linalg.norm(query.grad.asnumpy())
     assert_allclose(layout_query_grad_norm, original_query_grad_norm, 1E-5, 1E-5)
+
+
+def test_multi_head_sliding_window_dot_attention_cell():
+
+    def gen_sliding_window_mask_full(batch_size, seq_length, w, symmetric, d):
+        """Generate the sliding window attention mask for the full (seq_length x seq_length)
+        attention matrix.
+        """
+        mask_np = np.zeros((batch_size, seq_length, seq_length))
+        for i in range(seq_length):
+            end = (i + 1 + w * d) if symmetric else (i + 1)
+            for j in range(i - w * d, end, d):
+                if 0 <= j < seq_length:
+                    mask_np[:, i, j] = 1
+        return mask_np
+
+    def test_impl(batch_size, seq_length, num_heads, num_head_units, w, symmetric, d):
+        attn_cell = MultiHeadAttentionCell()
+        sw_attn_cell = MultiHeadSlidingWindowAttentionCell(w, symmetric)
+        # Generate the data
+        query = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
+        key = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
+        value = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
+        mask = gen_sliding_window_mask_full(batch_size, seq_length, w, symmetric, d)
+        mask = mx.np.array(mask, dtype=np.float32)
+
+        query = mx.np.array(query, dtype=np.float32)
+        key = mx.np.array(key, dtype=np.float32)
+        value = mx.np.array(value, dtype=np.float32)
+
+        query.attach_grad()
+        key.attach_grad()
+        value.attach_grad()
+
+        with mx.autograd.record():
+            out, _ = attn_cell(query, key, value, mask)
+            out.backward()
+
+        out_np = out.asnumpy()
+        grad_query = query.grad.asnumpy()
+        grad_key = key.grad.asnumpy()
+        grad_value = value.grad.asnumpy()
+
+        query.grad[:] = 0
+        key.grad[:] = 0
+        value.grad[:] = 0
+
+        dilation = mx.np.zeros((num_heads,))
+        dilation[:] = d
+        dilation = mx.np.array(dilation, dtype=np.int32)
+        valid_length = np.zeros((batch_size,))
+        valid_length[:] = seq_length
+        valid_length = mx.np.array(valid_length, dtype=np.int32)
+
+        with mx.autograd.record():
+            sw_out, _ = sw_attn_cell(query, key, value, dilation, valid_length)
+            sw_out.backward()
+
+        sw_out_np = sw_out.asnumpy()
+        sw_grad_query = query.grad.asnumpy()
+        sw_grad_key = key.grad.asnumpy()
+        sw_grad_value = value.grad.asnumpy()
+
+        assert_allclose(sw_out_np, out_np, 1E-3, 1E-3)
+        assert_allclose(sw_grad_key, grad_key, 1E-3, 1E-3)
+        assert_allclose(sw_grad_value, grad_value, 1E-3, 1E-3)
+        assert_allclose(sw_grad_query, grad_query, 1E-3, 1E-3)
+
+    for symmetric in [True, False]:
+        for d in [1, 2, 3]:
+            test_impl(4, 128, 12, 64, 16, symmetric, d)
+            test_impl(1, 8, 2, 3, 2, symmetric, d)
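Note for reviewers: the dense mask that `gen_sliding_window_mask_full` feeds to the `MultiHeadAttentionCell` baseline is a (possibly dilated) band around the diagonal, so the test checks that the sliding-window kernel matches ordinary masked attention restricted to that band. A standalone sketch of the same helper on two tiny configurations, with the expected band patterns shown in comments:

```python
import numpy as np

def gen_sliding_window_mask_full(batch_size, seq_length, w, symmetric, d):
    # Same banded-mask construction as in the test above.
    mask_np = np.zeros((batch_size, seq_length, seq_length))
    for i in range(seq_length):
        end = (i + 1 + w * d) if symmetric else (i + 1)
        for j in range(i - w * d, end, d):
            if 0 <= j < seq_length:
                mask_np[:, i, j] = 1
    return mask_np

# Symmetric window, w=1, no dilation: row i attends to {i - 1, i, i + 1}
print(gen_sliding_window_mask_full(1, 5, w=1, symmetric=True, d=1)[0])
# [[1. 1. 0. 0. 0.]
#  [1. 1. 1. 0. 0.]
#  [0. 1. 1. 1. 0.]
#  [0. 0. 1. 1. 1.]
#  [0. 0. 0. 1. 1.]]

# Causal window (symmetric=False), w=1, dilation d=2: row i attends to {i - 2, i}
print(gen_sliding_window_mask_full(1, 5, w=1, symmetric=False, d=2)[0])
# [[1. 0. 0. 0. 0.]
#  [0. 1. 0. 0. 0.]
#  [1. 0. 1. 0. 0.]
#  [0. 1. 0. 1. 0.]
#  [0. 0. 1. 0. 1.]]
```

With these masks in place, the dense attention cell should reproduce the sliding-window cell's outputs and gradients, which is exactly what the assertions above verify.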