190 changes: 187 additions & 3 deletions opennmt/layers/transformer.py
@@ -117,6 +117,125 @@ def matmul_with_relative_representations(a, b, transpose_b=False):
return c


def split_chunks(a, chunk_length, concat_3_chunks=True):
"""Splits a tensor into chunks along the timesteps axis.

Args:
a: A ``tf.Tensor`` of shape :math:`[B, H, T, D]`.
chunk_length: The length of a chunk :math:`C`.
concat_3_chunks: If ``True``, concatenate each chunk with its previous and next chunk along the chunk length axis.

Returns:
A tuple with a ``tf.Tensor`` of shape :math:`[B * N, H, C (* 3), D]` and the number of chunks :math:`N`.
"""

batch, num_heads, timesteps, units_per_head = misc.shape_list(a)

# Pad to a multiple of chunk_length.
rank = a.shape.rank
timestep_axis = rank - 2
pad_len = -timesteps % chunk_length
paddings = pad_len * tf.one_hot([-1, timestep_axis], rank, axis=0, dtype=tf.int32)
# batch, num_heads, timesteps padded, units_per_head
a_padded = tf.pad(tensor=a, paddings=paddings)
padded_len = misc.shape_list(a_padded)[timestep_axis]

# Chunk along timesteps axis.
num_chunks = padded_len // chunk_length
chunked_shape = [batch, num_heads, num_chunks, chunk_length, units_per_head]
# batch, num_heads, num_chunks, chunk_length, units_per_head
a_chunked = tf.reshape(a_padded, chunked_shape)

# Concatenate the previous and next chunk to each chunk so that attention windows overlap.
if concat_3_chunks:
paddings = tf.one_hot([2, 2], rank + 1, axis=0, dtype=tf.int32)
# batch, num_heads, 1 + num_chunks + 1, chunk_length, units_per_head
a_chunked_padded = tf.pad(a_chunked, paddings)
# batch, num_heads, num_chunks, chunk_length*3, units_per_head
a_chunked = tf.concat(
[a_chunked_padded[:, :, i : (i + num_chunks), ...] for i in range(3)], 3
)

# Transpose and flatten first dimension (batch * num_chunks).
# batch, num_chunks, num_heads, chunk_length (*3), units_per_head
a_transposed = tf.transpose(a_chunked, perm=[0, 2, 1, 3, 4])
input_shape = misc.shape_list(a_transposed)
output_shape = tf.concat([[batch * num_chunks], input_shape[2:]], 0)
# batch x num_chunks, num_heads, chunk_length (*3), units_per_head
return tf.reshape(a_transposed, output_shape), num_chunks
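# Illustrative sketch (not part of this patch), assuming eager execution and the
# module layout above: how split_chunks reshapes a small [B, H, T, D] tensor.
import tensorflow as tf
from opennmt.layers import transformer

x = tf.random.normal([2, 4, 7, 8])  # batch=2, heads=4, timesteps=7, depth=8

# Timesteps are padded from 7 to 9, i.e. 3 chunks of length 3.
chunks, num_chunks = transformer.split_chunks(x, chunk_length=3, concat_3_chunks=False)
print(num_chunks, chunks.shape)  # 3 (6, 4, 3, 8)

# With the default concat_3_chunks=True, each chunk also includes its neighbors.
chunks, num_chunks = transformer.split_chunks(x, chunk_length=3)
print(num_chunks, chunks.shape)  # 3 (6, 4, 9, 8)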


def chunk_att_mask(mask, chunk_length):
"""Transforms an attention mask into a chunked representation.

The chunked mask masks everything but a sliding diagonal band with a radius of ``chunk_length`` around each position.

Args:
mask: A ``tf.Tensor`` of shape :math:`[B, T]` or :math:`[B, T, T]`.
chunk_length: The length of a chunk :math:`C`.

Returns:
A ``tf.Tensor`` of shape :math:`[B * N, C, C * 3]`, where :math:`N` is the number of chunks.
"""

mask_shape = misc.shape_list(mask)
batch = mask_shape[0]
timesteps = mask_shape[-1]
rank = len(mask_shape)

if rank == 2:
# Broadcast on queries time dimension.
mask = tf.expand_dims(mask, 1)
mask = tf.broadcast_to(mask, [batch, timesteps, timesteps])
rank = 3

# Pad to a multiple of chunk_length.
pad_len = -timesteps % chunk_length
mask = tf.pad(tensor=mask, paddings=[[0, 0], [0, pad_len], [0, pad_len]])
padded_timesteps = misc.shape_list(mask)[-1]

# Pad the timestep axis with chunk_length positions, before and after.
paddings = chunk_length * tf.one_hot(
[rank - 1, rank - 1], rank, axis=0, dtype=tf.int32
)
mask_padded = tf.pad(tensor=mask, paddings=paddings)
padded_len = misc.shape_list(mask_padded)[-1]
mask_flattened = tf.reshape(mask_padded, shape=[batch, -1])

# Skew to the left by one and keep 2*chunk_length + 1 relevant locations.
# This corresponds to chunk_length radius around the diagonal.
skewed_len = padded_len + 1
skewed_padding_len = (
padded_timesteps * skewed_len - misc.shape_list(mask_flattened)[-1]
)
skewed_paddings = skewed_padding_len * tf.one_hot(
[-1, rank - 2], rank - 1, axis=0, dtype=tf.int32
)
mask_padded = tf.pad(mask_flattened, paddings=skewed_paddings)
skewed_shape = [batch, -1, skewed_len]
mask_skewed = tf.reshape(mask_padded, shape=skewed_shape)
mask_skewed = mask_skewed[:, :, : chunk_length * 2 + 1]

chunk_num = padded_timesteps // chunk_length
mask_skewed_chunked = tf.reshape(mask_skewed, [batch, chunk_num, chunk_length, -1])

# Unskew each chunk to be compatible with chunked attention shape.
unskewed_len = chunk_length * 3
unskewed_paddings = chunk_length * tf.one_hot(
[-1, rank], rank + 1, axis=0, dtype=tf.int32
)
mask_skewed_padded = tf.pad(mask_skewed_chunked, paddings=unskewed_paddings)
mask_skewed_flattened = tf.reshape(mask_skewed_padded, shape=[batch, chunk_num, -1])
mask_skewed_flattened = mask_skewed_flattened[:, :, : (chunk_length * unskewed_len)]
mask_unskewed = tf.reshape(
mask_skewed_flattened, shape=[batch, chunk_num, chunk_length, chunk_length * 3]
)

# Flatten the first dimension to batch * chunk_num.
return tf.reshape(
mask_unskewed, shape=[batch * chunk_num, chunk_length, chunk_length * 3]
)
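# Illustrative sketch (not part of this patch): chunking a padding mask built with
# tf.sequence_mask, matching the shapes documented above.
import tensorflow as tf
from opennmt.layers import transformer

mask = tf.sequence_mask([2, 4, 3], maxlen=5)  # [batch=3, timesteps=5]

# Timesteps are padded from 5 to 6, i.e. 3 chunks of length 2; each chunk row keeps
# a band of 2 * chunk_length + 1 positions around its own token.
chunked = transformer.chunk_att_mask(mask, chunk_length=2)
print(chunked.shape)  # (9, 2, 6) = (batch * num_chunks, chunk_length, chunk_length * 3)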


class FeedForwardNetwork(tf.keras.layers.Layer):
"""Implements the Transformer's "Feed Forward" layer.

@@ -214,6 +333,8 @@ def __init__(
dropout=0.1,
return_attention=False,
maximum_relative_position=None,
max_length_full_attention=None,
local_attention_radius=None,
**kwargs
):
"""Initializes this layer.
@@ -225,6 +346,9 @@ def __init__(
return_attention: If ``True``, also return the attention weights.
maximum_relative_position: Maximum relative position representation
(from https://arxiv.org/abs/1803.02155).
max_length_full_attention: Maximum sequence length for full attention.
Longer sequences fall back to sparse local attention. If ``None``, full
attention is always used.
local_attention_radius: Attention radius around each token for the local
sliding attention, i.e. the chunk length used by the sparse attention.
kwargs: Additional layer arguments.
"""
super().__init__(**kwargs)
@@ -242,6 +366,8 @@ def __init__(
self.dropout = dropout
self.return_attention = return_attention
self.maximum_relative_position = maximum_relative_position
self.max_length_full_attention = max_length_full_attention
self.local_attention_radius = local_attention_radius

def map_v1_weights(self, weights):
# V1 used conv1d layers that have a leading dimensions.
@@ -354,15 +480,38 @@ def _compute_kv(x):

cache = (keys, values)

queries_length = misc.shape_list(queries)[2]

use_sparse_att = False
if self.max_length_full_attention is not None:
if memory is not None:
raise ValueError("Sparse attention only supports self-attention.")
if self.maximum_relative_position is not None:
raise ValueError("Sparse attention doesn't support relative positions.")
use_sparse_att = queries_length > self.max_length_full_attention

chunk_length = self.local_attention_radius
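# Each query chunk of length chunk_length attends to itself plus the previous and
# next chunk, giving a sliding window of radius local_attention_radius.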
# Dot product attention.
dot = tf.matmul(queries, keys, transpose_b=True)
if use_sparse_att:
# batch x num_chunks, num_heads, chunk_length, units_per_head
queries_chunked, _ = split_chunks(
queries, chunk_length, concat_3_chunks=False
)
# batch x num_chunks, num_heads, chunk_length*3, units_per_head
keys_chunked, _ = split_chunks(keys, chunk_length)
# batch x num_chunks, num_heads, chunk_length, chunk_length*3
dot = tf.matmul(queries_chunked, keys_chunked, transpose_b=True)
else:
dot = tf.matmul(queries, keys, transpose_b=True)
if relative_repr_keys is not None:
dot += matmul_with_relative_representations(
queries, relative_repr_keys, transpose_b=True
)
if mask is not None:
mask = tf.cast(mask, tf.float32)
if mask.shape.rank == 2:
if use_sparse_att:
mask = chunk_att_mask(mask, chunk_length)
elif mask.shape.rank == 2:
mask = tf.expand_dims(mask, 1) # Broadcast on time dimension.
mask = tf.expand_dims(mask, 1) # Broadcast on head dimension.
dot = tf.cast(
@@ -371,7 +520,42 @@ def _compute_kv(x):
)
attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype)
drop_attn = common.dropout(attn, self.dropout, training=training)
heads = tf.matmul(drop_attn, values)
if use_sparse_att:
# batch x num_chunks, num_heads, chunk_length*3, units_per_head
values_chunked, num_chunks = split_chunks(values, chunk_length)
# batch x num_chunks, num_heads, chunk_length, units_per_head
heads = tf.matmul(drop_attn, values_chunked)

# Unchunk: merge the chunks back into the timesteps axis.
heads_shape = misc.shape_list(heads)
# batch, num_chunks, num_heads, chunk_length, self.num_units_per_head
heads = tf.reshape(
heads,
[
heads_shape[0] // num_chunks,
num_chunks,
heads_shape[1],
heads_shape[2],
heads_shape[3],
],
)
# batch, num_heads, num_chunks, chunk_length, self.num_units_per_head
heads = tf.transpose(heads, perm=[0, 2, 1, 3, 4])
heads_shape = misc.shape_list(heads)
heads = tf.reshape(
heads,
[
heads_shape[0],
heads_shape[1],
heads_shape[2] * heads_shape[3],
heads_shape[4],
],
)

# Remove padding used for chunking.
heads = heads[:, :, :queries_length, :]
else:
heads = tf.matmul(drop_attn, values)
if relative_repr_values is not None:
heads += matmul_with_relative_representations(
drop_attn, relative_repr_values
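For reference, a minimal usage sketch of the new constructor arguments (the parameter values are illustrative only, and the call is assumed to return the attention output and cache as in the existing layer). Sparse local attention kicks in once the query length exceeds ``max_length_full_attention``, with chunks of length ``local_attention_radius``:

import tensorflow as tf
from opennmt.layers import transformer

attention = transformer.MultiHeadAttention(
    4,  # num_heads
    20,  # num_units
    max_length_full_attention=64,  # longer inputs switch to sparse attention
    local_attention_radius=16,  # chunk length of the sliding window
)
x = tf.random.uniform([2, 128, 20])
mask = tf.sequence_mask([128, 100])
outputs, _ = attention(x, mask=mask)  # 128 > 64, so the chunked path is taken
print(outputs.shape)  # (2, 128, 20)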
64 changes: 64 additions & 0 deletions opennmt/tests/transformer_test.py
@@ -116,6 +116,62 @@ def testRelativePositions(self):
[[2, 3, 4, 4], [1, 2, 3, 4], [0, 1, 2, 3], [0, 0, 1, 2]],
)

@parameterized.expand([[2, True], [2, False], [3, True], [3, False]])
def testSplitChunks(self, chunk_length, concat_3_chunks):
batch = 3
length = [5, 3, 7]
num_heads = 4
depth = 10

inputs = tf.random.normal(
[batch, num_heads, max(length), depth], dtype=tf.float32
)
split, num_chunks = transformer.split_chunks(
inputs, chunk_length=chunk_length, concat_3_chunks=concat_3_chunks
)
split_shape = split.shape
self.assertEqual(num_chunks, split_shape[0] / batch)
self.assertEqual(num_heads, split_shape[1])
chunk_length_eval = chunk_length * 3 if concat_3_chunks else chunk_length
self.assertEqual(chunk_length_eval, split_shape[2])
self.assertEqual(depth, split_shape[3])

@parameterized.expand(
[[tf.bool, 2], [tf.float32, 2], [tf.bool, 3], [tf.float32, 3]]
)
def testChunkAttentionMask(self, dtype, chunk_length):
length = [2, 4, 3]
batch = len(length)
maximum_length = 5
mask = tf.sequence_mask(lengths=length, maxlen=maximum_length, dtype=dtype)
mask_chunked = transformer.chunk_att_mask(mask, chunk_length=chunk_length)
output_shape = mask_chunked.shape
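# Ceiling division: the number of chunks needed to cover maximum_length.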
num_chunks = abs(-maximum_length // chunk_length)
self.assertEqual(num_chunks, output_shape[0] / batch)
self.assertEqual(chunk_length, output_shape[1])
self.assertEqual(chunk_length * 3, output_shape[2])

self.assertIs(mask_chunked.dtype, dtype)

expected = np.zeros(output_shape, dtype=dtype.as_numpy_dtype)

token_radius = chunk_length * 2 + 1
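# Build the expected mask: each row of a chunk can attend to a window of token_radius
# positions centered on its own position in the 3-chunk key block, clipped at the
# sequence start (no previous chunk for the first chunk) and at the true sequence length.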
for b in range(batch):
seq_length = length[b]
for ch in range(num_chunks):
end = chunk_length + seq_length - chunk_length * ch
if end > 0:
chunk_idx = b * num_chunks + ch
for ch_l in range(chunk_length):
seq_length_idx = ch * chunk_length + ch_l
if seq_length_idx < maximum_length:
start_idx = ch_l if ch != 0 else chunk_length
end_idx = min(end, token_radius + ch_l)
expected[chunk_idx][ch_l][start_idx:end_idx] = 1

mask_chunked = self.evaluate(mask_chunked)
self.assertAllEqual(mask_chunked, expected)

def testFeedForwardNetwork(self):
ffn = transformer.FeedForwardNetwork(20, 10)
x = tf.random.uniform([4, 5, 10])
@@ -159,6 +215,14 @@ def testMultiHeadSelfAttentionRelativePositionsWithCache(self):
cache = (tf.zeros([4, 4, 0, 5]), tf.zeros([4, 4, 0, 5]))
_, cache = attention(x, cache=cache)

def testMultiHeadSelfAttentionSparse(self):
attention = transformer.MultiHeadAttention(
4, 20, local_attention_radius=2, max_length_full_attention=3
)
x = tf.random.uniform([2, 9, 10])
mask = tf.sequence_mask([9, 7])
attention(x, mask=mask)

def testMultiHeadSelfAttentionRelativeGradients(self):
attention = transformer.MultiHeadAttention(4, 20, maximum_relative_position=6)
