304 changes: 299 additions & 5 deletions opennmt/layers/transformer.py
@@ -117,6 +117,13 @@ def matmul_with_relative_representations(a, b, transpose_b=False):
return c


def calculate_attn(dot, values, dropout, training):
    """Computes the attention weights and the attended values from raw scores."""
    # Softmax in float32 for numerical stability, then cast back to the input dtype.
    attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype)
    drop_attn = common.dropout(attn, dropout, training=training)
    heads = tf.matmul(drop_attn, values)
    return heads, attn, drop_attn


class FeedForwardNetwork(tf.keras.layers.Layer):
"""Implements the Transformer's "Feed Forward" layer.

@@ -215,6 +222,9 @@ def __init__(
dropout=0.1,
return_attention=False,
maximum_relative_position=None,
max_length_full_attention=None,
local_attention_radius=0,
global_attention_length=0,
**kwargs
):
"""Initializes this layer.
@@ -227,6 +237,11 @@ def __init__(
return_attention: If ``True``, also return the attention weights.
maximum_relative_position: Maximum relative position representation
(from https://arxiv.org/abs/1803.02155).
max_length_full_attention: Maximum sequence length for full attention.
If not ``None``, use sparse attention for longer sequences
(from https://arxiv.org/abs/2004.08483).
          local_attention_radius: Radius of the sliding attention window around each
            token when sparse attention is used (also used as the chunk length).
          global_attention_length: Number of leading tokens that attend to the full
            sequence, and that every token can attend to, when sparse attention is used.
kwargs: Additional layer arguments.
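
        Example:
          A possible configuration enabling chunked sparse attention on long
          sequences. The values are illustrative, and the class name
          ``MultiHeadAttention`` with its existing ``num_heads`` and ``num_units``
          arguments are assumed from the surrounding file::

            attention = MultiHeadAttention(
                num_heads=8,
                num_units=512,
                max_length_full_attention=1024,
                local_attention_radius=128,
                global_attention_length=16,
            )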
"""
super().__init__(**kwargs)
@@ -244,6 +259,9 @@ def __init__(
self.dropout = dropout
self.return_attention = return_attention
self.maximum_relative_position = maximum_relative_position
self.max_length_full_attention = max_length_full_attention
self.local_attention_radius = local_attention_radius
self.global_attention_length = global_attention_length

def map_v1_weights(self, weights):
# V1 used conv1d layers that have a leading dimensions.
@@ -356,24 +374,102 @@ def _compute_kv(x):

cache = (keys, values)

queries_length = misc.shape_list(queries)[2]

if self.max_length_full_attention is not None:
if memory is not None:
raise ValueError("Sparse attention only supports self-attention.")
if self.maximum_relative_position is not None:
raise ValueError("Sparse attention doesn't support relative positions.")
if self.return_attention:
raise ValueError(
"Cannot return attention weights when using sparse attention."
)

if self.max_length_full_attention is not None:
use_sparse_att = (
tf.less(self.max_length_full_attention, queries_length)
& tf.less(0, queries_length)
& tf.less(self.global_attention_length, queries_length)
)
if self.global_attention_length > 0:
global_keys = keys
global_values = values
global_queries = queries
                if use_sparse_att:
                    global_queries = queries[:, :, : self.global_attention_length, :]
                    queries = queries[:, :, self.global_attention_length :, :]
queries, keys, values, num_chunks = tf.cond(
use_sparse_att,
lambda: split_qkv(
queries,
keys,
values,
self.local_attention_radius,
self.global_attention_length,
),
lambda: (queries, keys, values, 0),
)
# Dot product attention.
dot = tf.matmul(queries, keys, transpose_b=True)
if relative_repr_keys is not None:
dot += matmul_with_relative_representations(
queries, relative_repr_keys, transpose_b=True
)

if (
self.max_length_full_attention is not None
and self.global_attention_length > 0
):
            # Scores of the global tokens over the full sequence. The first
            # assignment is a placeholder for the branch without sparse attention.
            global_dot = global_queries
            if use_sparse_att:
                global_dot = tf.matmul(global_queries, global_keys, transpose_b=True)

        if mask is not None:
            mask = tf.cast(mask, dot.dtype)
if self.max_length_full_attention is not None:
if self.global_attention_length > 0:
global_mask = mask
if use_sparse_att:
if mask.shape.rank == 2:
global_mask = mask[:, tf.newaxis, :]
else:
global_mask = mask[:, : self.global_attention_length, :]
global_mask = global_mask[:, tf.newaxis, :, :]
global_dot = (global_dot * global_mask) + (
1.0 - global_mask
) * global_dot.dtype.min
mask = tf.cond(
use_sparse_att,
lambda: chunk_att_mask(
mask, self.local_attention_radius, self.global_attention_length
),
lambda: tf.expand_dims(mask, 1) if mask.shape.rank == 2 else mask,
)
            elif mask.shape.rank == 2:
                mask = tf.expand_dims(mask, 1)  # Broadcast on time dimension.
mask = tf.expand_dims(mask, 1) # Broadcast on head dimension.
dot = (dot * mask) + (1.0 - mask) * dot.dtype.min

        heads, attn, drop_attn = calculate_attn(dot, values, self.dropout, training)

if self.max_length_full_attention is not None:
if self.global_attention_length > 0:
global_heads = heads
global_attn = attn
if use_sparse_att:
global_heads, global_attn, _ = calculate_attn(
global_dot, global_values, self.dropout, training
)

heads = tf.cond(
use_sparse_att,
lambda: combine_chunks(
heads, num_chunks, queries_length - self.global_attention_length
),
lambda: heads,
)

        if relative_repr_values is not None:
heads += matmul_with_relative_representations(
drop_attn, relative_repr_values
@@ -382,6 +478,18 @@ def _compute_kv(x):
# Concatenate all heads output.
combined = combine_heads(heads)
outputs = self.linear_output(combined)

        if (
            self.max_length_full_attention is not None
            and self.global_attention_length > 0
        ):
            global_combined = combined
            global_outputs = outputs
            if use_sparse_att:
                # Project the global heads and prepend the global tokens back to
                # the local outputs to restore the original sequence length.
                global_combined = combine_heads(global_heads)
                global_outputs = self.linear_output(global_combined)
                outputs = tf.concat((global_outputs, outputs), axis=1)

if self.return_attention:
return outputs, cache, attn
return outputs, cache
@@ -610,3 +718,189 @@ def call(
outputs = self.ffn(outputs, training=training)
cache = dict(self_kv=self_kv, memory_kv=memory_kv)
return outputs, cache, attention


def split_chunks(a, chunk_length, concat_3_chunks=True, global_length=0):
"""Splits a tensor into chunks along the timesteps axis.

Args:
a: A ``tf.Tensor`` of shape :math:`[B, H, T, D]`.
chunk_length: The length of a chunk :math:`C`.
      concat_3_chunks: Optional, if ``True``, append the previous and the following chunk to each chunk.
      global_length: Optional, the number of leading timesteps reserved for global
        attention; they are removed from the chunked sequence and appended to every chunk instead.

    Returns:
      A ``tf.Tensor`` of shape :math:`[B * N, H, C (* 3) (+ G), D]`, where :math:`N` is the
      number of chunks and :math:`G` is ``global_length``, and the number of chunks
      :math:`N`. A worked shape example follows this function.
"""

if global_length:
global_a = a[:, :, :global_length, :]
a = a[:, :, global_length:, :]

batch, num_heads, timesteps, units_per_head = misc.shape_list(a)

# Pad to a factor of chunk_length.
pad_len = -timesteps % chunk_length
# batch, num_heads, timesteps padded, units_per_head
a_padded = tf.pad(tensor=a, paddings=[[0, 0], [0, 0], [0, pad_len], [0, 0]])
padded_len = misc.shape_list(a_padded)[2]

# Chunk along timesteps axis.
num_chunks = padded_len // chunk_length
chunked_shape = [batch, num_heads, num_chunks, chunk_length, units_per_head]
# batch, num_heads, num_chunks, chunk_length, units_per_head
a_chunked = tf.reshape(a_padded, chunked_shape)

# Concatenate previous and next chunk to each chunk, for overlapping.
if concat_3_chunks:
# batch, num_heads, 1 + num_chunks + 1, chunk_length, units_per_head
a_chunked_padded = tf.pad(
a_chunked, paddings=[[0, 0], [0, 0], [1, 1], [0, 0], [0, 0]]
)
# batch, num_heads, num_chunks, chunk_length*3, units_per_head
a_chunked = tf.concat(
[a_chunked_padded[:, :, i : (i + num_chunks), ...] for i in range(3)], 3
)

# Transpose and flatten first dimension (batch * num_chunks).
# batch, num_chunks, num_heads, chunk_length (*3), units_per_head
a_transposed = tf.transpose(a_chunked, perm=[0, 2, 1, 3, 4])

if global_length:
# batch, num_chunks, num_heads, global timesteps, units_per_head
expanded_global_a = tf.tile(
tf.expand_dims(global_a, 1), [1, num_chunks, 1, 1, 1]
)
a_transposed = tf.concat([a_transposed, expanded_global_a], axis=3)

input_shape = misc.shape_list(a_transposed)
output_shape = tf.concat([[batch * num_chunks], input_shape[2:]], 0)
# batch x num_chunks, num_heads, chunk_length (*3) + global_length, units_per_head
return tf.reshape(a_transposed, output_shape), num_chunks
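

# Illustrative sketch (not part of the library API): how shapes evolve in
# ``split_chunks``. This hypothetical helper only documents the expected chunked
# shapes and assumes TensorFlow 2.x eager execution with the module's imports.
def _split_chunks_shape_example():
    x = tf.random.normal([2, 4, 10, 8])  # batch=2, heads=4, timesteps=10, depth=8
    chunked, num_chunks = split_chunks(x, chunk_length=4, global_length=2)
    # 10 - 2 global timesteps = 8 local timesteps, a multiple of 4 -> 2 chunks.
    # Each chunk keeps its own 4 timesteps plus the previous and next chunks
    # (12 positions) plus the 2 global timesteps: 14 key/value positions.
    assert num_chunks == 2
    assert chunked.shape == [2 * 2, 4, 4 * 3 + 2, 8]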


def split_qkv(queries, keys, values, chunk_length, global_attention_length):
    """Prepares the projected queries, keys and values for chunked (sparse) attention.

    Queries are split into non-overlapping chunks, while keys and values are split
    into overlapping chunks (previous, current and next) with the global timesteps
    appended, so each query chunk attends to its local window and the global tokens.
    """

# batch x num_chunks, num_heads, chunk_length, units_per_head
queries, _ = split_chunks(queries, chunk_length, concat_3_chunks=False)
# batch x num_chunks, num_heads, chunk_length*3 + global_length, units_per_head
keys, _ = split_chunks(keys, chunk_length, global_length=global_attention_length)
# batch x num_chunks, num_heads, chunk_length*3 + global_length, units_per_head
values, num_chunks = split_chunks(
values, chunk_length, global_length=global_attention_length
)

return queries, keys, values, num_chunks
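

# Illustrative sketch (not part of the library API): shapes produced by
# ``split_qkv``. When global tokens are used, the caller strips them from
# ``queries`` beforehand (as done in the attention layer's ``call`` above), so
# queries and keys/values end up with the same number of chunks. This helper is
# hypothetical and assumes TensorFlow 2.x eager execution.
def _split_qkv_shape_example():
    keys = values = tf.random.normal([2, 4, 10, 8])  # 10 timesteps, 2 of them global
    queries = tf.random.normal([2, 4, 8, 8])  # global tokens already removed
    q, k, v, num_chunks = split_qkv(
        queries, keys, values, chunk_length=4, global_attention_length=2
    )
    assert num_chunks == 2
    assert q.shape == [2 * 2, 4, 4, 8]  # non-overlapping query chunks
    assert k.shape == v.shape == [2 * 2, 4, 4 * 3 + 2, 8]  # 3 chunks + 2 global keys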


def combine_chunks(a, num_chunks, unchunked_length):
    """Restores the timesteps axis of a tensor that was chunked by :func:`split_chunks`."""
    # Unchunk.
    a_shape = misc.shape_list(a)
    # batch, num_chunks, num_heads, chunk_length, units_per_head
a = tf.reshape(
a,
[
a_shape[0] // num_chunks,
num_chunks,
a_shape[1],
a_shape[2],
a_shape[3],
],
)
    # batch, num_heads, num_chunks, chunk_length, units_per_head
a = tf.transpose(a, perm=[0, 2, 1, 3, 4])
a_shape = misc.shape_list(a)
a = tf.reshape(
a,
[
a_shape[0],
a_shape[1],
a_shape[2] * a_shape[3],
a_shape[4],
],
)

# Remove padding used for chunking.
return a[:, :, :unchunked_length, :]
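

# Illustrative sketch (not part of the library API): ``combine_chunks`` inverts
# the query-side chunking done by ``split_chunks(..., concat_3_chunks=False)``,
# dropping the padding added to reach a multiple of the chunk length. This helper
# is hypothetical and assumes TensorFlow 2.x eager execution.
def _combine_chunks_roundtrip_example():
    x = tf.random.normal([2, 4, 10, 8])  # batch=2, heads=4, timesteps=10, depth=8
    chunked, num_chunks = split_chunks(x, chunk_length=4, concat_3_chunks=False)
    assert num_chunks == 3  # 10 timesteps are padded to 12, i.e. 3 chunks of 4
    assert chunked.shape == [2 * 3, 4, 4, 8]
    restored = combine_chunks(chunked, num_chunks, unchunked_length=10)
    tf.debugging.assert_near(restored, x)  # the original tensor is recovered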


def chunk_att_mask(mask, chunk_length, global_length=0):
"""Transforms an attention mask into a chunked representation.

    The chunked mask hides everything except a sliding window of radius
    ``chunk_length`` around the diagonal, plus the global timesteps appended to
    each chunk.

    Args:
      mask: A ``tf.Tensor`` of shape :math:`[B, T]` or :math:`[B, T, T]`.
      chunk_length: The length of a chunk :math:`C`.
      global_length: Optional, the number of leading timesteps used for global attention.

    Returns:
      A ``tf.Tensor`` of shape :math:`[B * N, C, C * 3 + G]`, where :math:`N` is the
      number of chunks and :math:`G` is ``global_length``. A shape example follows
      this function.
"""

mask_shape = misc.shape_list(mask)
batch = mask_shape[0]
timesteps = mask_shape[-1]
rank = len(mask_shape)

if rank == 2:
# Broadcast on queries time dimension.
mask = tf.expand_dims(mask, 1)
mask = tf.broadcast_to(mask, [batch, timesteps, timesteps])

if global_length:
global_mask = mask[:, global_length:, :global_length]
mask = mask[:, global_length:, global_length:]
timesteps = timesteps - global_length

# Pad to a factor of chunk_length.
pad_len = -timesteps % chunk_length
mask = tf.pad(tensor=mask, paddings=[[0, 0], [0, pad_len], [0, pad_len]])
if global_length:
global_mask = tf.pad(
tensor=global_mask, paddings=[[0, 0], [0, pad_len], [0, 0]]
)
padded_timesteps = misc.shape_list(mask)[-1]

# Append chunk_length padding to timestep axis, before and after.
mask_padded = tf.pad(
tensor=mask, paddings=[[0, 0], [0, 0], [chunk_length, chunk_length]]
)
padded_len = misc.shape_list(mask_padded)[-1]
mask_flattened = tf.reshape(mask_padded, shape=[batch, -1])

# Skew to the left by one and keep 2*chunk_length + 1 relevant locations.
# This corresponds to chunk_length radius around the diagonal.
skewed_len = padded_len + 1
skewed_padding_len = (
padded_timesteps * skewed_len - misc.shape_list(mask_flattened)[-1]
)
mask_padded = tf.pad(mask_flattened, paddings=[[0, 0], [0, skewed_padding_len]])
skewed_shape = [batch, -1, skewed_len]
mask_skewed = tf.reshape(mask_padded, shape=skewed_shape)
mask_skewed = mask_skewed[:, :, : chunk_length * 2 + 1]

chunk_num = padded_timesteps // chunk_length
mask_skewed_chunked = tf.reshape(mask_skewed, [batch, chunk_num, chunk_length, -1])

# Unskew each chunk to be compatible with chunked attention shape.
unskewed_len = chunk_length * 3
mask_skewed_padded = tf.pad(
mask_skewed_chunked, paddings=[[0, 0], [0, 0], [0, 0], [0, chunk_length]]
)
mask_skewed_flattened = tf.reshape(mask_skewed_padded, shape=[batch, chunk_num, -1])
mask_skewed_flattened = mask_skewed_flattened[:, :, : (chunk_length * unskewed_len)]
mask_unskewed = tf.reshape(
mask_skewed_flattened, shape=[batch, chunk_num, chunk_length, chunk_length * 3]
)

if global_length:
# batch, num_chunks, chunk_length, global_length
expanded_global_mask = tf.reshape(
global_mask, shape=[batch, chunk_num, chunk_length, global_length]
)
mask_unskewed = tf.concat([mask_unskewed, expanded_global_mask], axis=3)

# Flatten the first dimension to batch * chunk_num.
return tf.reshape(
mask_unskewed,
shape=[batch * chunk_num, chunk_length, chunk_length * 3 + global_length],
)
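

# Illustrative sketch (not part of the library API): the chunked mask produced by
# ``chunk_att_mask`` lines up with the chunked keys/values built by ``split_qkv``.
# This helper is hypothetical and assumes TensorFlow 2.x eager execution.
def _chunk_att_mask_shape_example():
    # Two sequences of up to 8 timesteps: the first has 5 valid positions, the
    # second is full length.
    mask = tf.sequence_mask([5, 8], maxlen=8, dtype=tf.float32)  # [B=2, T=8]
    chunked_mask = chunk_att_mask(mask, chunk_length=3, global_length=2)
    # 8 - 2 global timesteps = 6 local timesteps -> 2 chunks of length 3. Each
    # chunk row covers 3 * 3 = 9 local key positions (a radius of 3 around the
    # diagonal) plus the 2 global timesteps, flattened to batch * num_chunks rows.
    assert chunked_mask.shape == [2 * 2, 3, 3 * 3 + 2]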