90 changes: 80 additions & 10 deletions opennmt/layers/transformer.py
@@ -117,7 +117,7 @@ def matmul_with_relative_representations(a, b, transpose_b=False):
return c


def split_chunks(a, chunk_length, concat_3_chunks=True):
def split_chunks(a, chunk_length, concat_3_chunks=True, global_length=0):
"""Splits a tensor into chunks along the timesteps axis.

Args:
@@ -129,6 +129,10 @@ def split_chunks(a, chunk_length, concat_3_chunks=True):
A ``tf.Tensor`` of shape :math:`[B * N, H, C (* 3) + G, D]`, where :math:`N` is the chunk number and :math:`G` is ``global_length``.
"""

if global_length:
global_a = a[:, :, :global_length, :]
a = a[:, :, global_length:, :]

batch, num_heads, timesteps, units_per_head = misc.shape_list(a)

# Pad to a factor of chunk_length.
@@ -157,9 +161,17 @@ def split_chunks(a, chunk_length, concat_3_chunks=True):
# Transpose and flatten first dimension (batch * num_chunks).
# batch, num_chunks, num_heads, chunk_length (*3), units_per_head
a_transposed = tf.transpose(a_chunked, perm=[0, 2, 1, 3, 4])

if global_length:
# batch, num_chunks, num_heads, global timesteps, units_per_head
expanded_global_a = tf.tile(
tf.expand_dims(global_a, 1), [1, num_chunks, 1, 1, 1]
)
a_transposed = tf.concat([a_transposed, expanded_global_a], axis=3)

input_shape = misc.shape_list(a_transposed)
output_shape = tf.concat([[batch * num_chunks], input_shape[2:]], 0)
# batch x num_chunks, num_heads, chunk_length (*3), units_per_head
# batch x num_chunks, num_heads, chunk_length (*3) + global_length, units_per_head
return tf.reshape(a_transposed, output_shape), num_chunks
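
A quick way to sanity-check the new ``global_length`` argument — a minimal sketch, assuming a dev install where ``opennmt.layers.transformer`` exposes ``split_chunks`` (the sizes below are made up for illustration):

import tensorflow as tf

from opennmt.layers import transformer

# batch, num_heads, timesteps, units_per_head
a = tf.random.uniform([2, 4, 10, 8])
split, num_chunks = transformer.split_chunks(
    a, chunk_length=3, concat_3_chunks=True, global_length=2
)
# The first 2 timesteps are sliced off as global tokens before chunking, so
# num_chunks == ceil((10 - 2) / 3) == 3, and each chunk is extended with the
# global tokens: the timesteps axis becomes 3 * 3 + 2 == 11.
print(num_chunks)   # 3
print(split.shape)  # (6, 4, 11, 8) == (batch * num_chunks, num_heads, 11, 8)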


@@ -194,7 +206,7 @@ def combine_chunks(a, num_chunks, unchunked_length):
return a[:, :, :unchunked_length, :]


def chunk_att_mask(mask, chunk_length):
def chunk_att_mask(mask, chunk_length, global_length=0):
"""Transforms an attention mask into a chunked representation.

The chunked mask masks everything except a sliding diagonal window with a radius of ``chunk_length``.
@@ -207,6 +219,10 @@ def chunk_att_mask(mask, chunk_length):
A ``tf.Tensor`` of shape :math:`[B * N, C, C * 3 + G]`, where :math:`N` is the number of chunks and :math:`G` is ``global_length``.
"""

if global_length:
global_mask = mask[:, :global_length]
mask = mask[:, global_length:]

mask_shape = misc.shape_list(mask)
batch = mask_shape[0]
timesteps = mask_shape[-1]
@@ -254,9 +270,17 @@ def chunk_att_mask(mask, chunk_length):
mask_skewed_flattened, shape=[batch, chunk_num, chunk_length, chunk_length * 3]
)

if global_length:
# batch, num_chunks, chunk_length, global_length
expanded_global_mask = tf.tile(
global_mask[:, tf.newaxis, tf.newaxis, :], [1, chunk_num, chunk_length, 1]
)
mask_unskewed = tf.concat([mask_unskewed, expanded_global_mask], axis=3)

# Flatten the first dimension to batch * chunk_num.
return tf.reshape(
mask_unskewed, shape=[batch * chunk_num, chunk_length, chunk_length * 3]
mask_unskewed,
shape=[batch * chunk_num, chunk_length, chunk_length * 3 + global_length],
)
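
The mask helper can be checked the same way — a minimal sketch under the same dev-install assumption, with lengths mirroring the unit test below:

import tensorflow as tf

from opennmt.layers import transformer

mask = tf.sequence_mask([2, 4, 3], maxlen=5, dtype=tf.float32)
chunked = transformer.chunk_att_mask(mask, chunk_length=2, global_length=1)
# The leading global position is sliced off before chunking, so
# num_chunks == ceil((5 - 1) / 2) == 2 and the result has shape
# [batch * num_chunks, chunk_length, chunk_length * 3 + global_length]:
print(chunked.shape)  # (6, 2, 7)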


@@ -359,6 +383,7 @@ def __init__(
maximum_relative_position=None,
max_length_full_attention=None,
local_attention_radius=None,
global_attention_length=0,
**kwargs
):
"""Initializes this layer.
@@ -374,6 +399,7 @@ def __init__(
If not ``None``, use sparse attention for longer sequences
(from https://arxiv.org/abs/2004.08483).
local_attention_radius: Attention radius around each token for local sliding attention.
global_attention_length: Number of leading tokens used for global attention when sparse attention is enabled.
kwargs: Additional layer arguments.
"""
super().__init__(**kwargs)
@@ -393,6 +419,7 @@ def __init__(
self.maximum_relative_position = maximum_relative_position
self.max_length_full_attention = max_length_full_attention
self.local_attention_radius = local_attention_radius
self.global_attention_length = global_attention_length

def map_v1_weights(self, weights):
# V1 used conv1d layers that have a leading dimensions.
@@ -513,17 +540,33 @@ def _compute_kv(x):
raise ValueError("Sparse attention only supports self-attention.")
if self.maximum_relative_position is not None:
raise ValueError("Sparse attention doesn't support relative positions.")
if self.return_attention:
raise ValueError(
"Cannot return attention weights when using sparse attention."
)

use_sparse_att = queries_length > self.max_length_full_attention

chunk_length = self.local_attention_radius
# Dot product attention.
if use_sparse_att:
if self.global_attention_length:
global_queries = queries[:, :, : self.global_attention_length, :]
queries = queries[:, :, self.global_attention_length :, :]
global_keys = keys
global_values = values
global_dot = tf.matmul(global_queries, global_keys, transpose_b=True)

# batch x num_chunks, num_heads, chunk_length, units_per_head
queries, _ = split_chunks(queries, chunk_length, concat_3_chunks=False)
# batch x num_chunks, num_heads, chunk_length*3, units_per_head
keys, _ = split_chunks(keys, chunk_length)
# batch x num_chunks, num_heads, chunk_length*3, units_per_head
values, num_chunks = split_chunks(values, chunk_length)
# batch x num_chunks, num_heads, chunk_length*3 + global_length, units_per_head
keys, _ = split_chunks(
keys, chunk_length, global_length=self.global_attention_length
)
# batch x num_chunks, num_heads, chunk_length*3 + global_length, units_per_head
values, num_chunks = split_chunks(
values, chunk_length, global_length=self.global_attention_length
)
dot = tf.matmul(queries, keys, transpose_b=True)
if relative_repr_keys is not None:
dot += matmul_with_relative_representations(
@@ -532,19 +575,39 @@ def _compute_kv(x):
if mask is not None:
mask = tf.cast(mask, tf.float32)
if use_sparse_att:
mask = chunk_att_mask(mask, chunk_length)
global_mask = mask[:, tf.newaxis, tf.newaxis, :]
mask = chunk_att_mask(mask, chunk_length, self.global_attention_length)
elif mask.shape.rank == 2:
mask = tf.expand_dims(mask, 1) # Broadcast on time dimension.
mask = tf.expand_dims(mask, 1) # Broadcast on head dimension.
dot = tf.cast(
tf.cast(dot, tf.float32) * mask + ((1.0 - mask) * tf.float32.min),
dot.dtype,
)
if use_sparse_att and self.global_attention_length:
global_dot = tf.cast(
tf.cast(global_dot, tf.float32) * global_mask
+ ((1.0 - global_mask) * tf.float32.min),
global_dot.dtype,
)

attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype)
drop_attn = common.dropout(attn, self.dropout, training=training)
heads = tf.matmul(drop_attn, values)

if use_sparse_att and self.global_attention_length:
global_attn = tf.cast(
tf.nn.softmax(tf.cast(global_dot, tf.float32)), global_dot.dtype
)
global_drop_attn = common.dropout(
global_attn, self.dropout, training=training
)
global_heads = tf.matmul(global_drop_attn, global_values)

if use_sparse_att:
heads = combine_chunks(heads, num_chunks, queries_length)
heads = combine_chunks(
heads, num_chunks, queries_length - self.global_attention_length
)
if relative_repr_values is not None:
heads += matmul_with_relative_representations(
drop_attn, relative_repr_values
@@ -553,6 +616,13 @@ def _compute_kv(x):
# Concatenate all heads output.
combined = combine_heads(heads)
outputs = self.linear_output(combined)
if use_sparse_att and self.global_attention_length:
global_combined = combine_heads(global_heads)
global_outputs = self.linear_output(
global_combined
)  # TODO: use separate global linear input and output layers?
outputs = tf.concat((global_outputs, outputs), axis=1)

if self.return_attention:
return outputs, cache, attn
return outputs, cache
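
A usage sketch for the new option (mirroring the updated ``testMultiHeadSelfAttentionSparse`` below). Note the concatenation order above: the outputs of the ``global_attention_length`` leading tokens come first, so the output keeps the input token order:

import tensorflow as tf

from opennmt.layers import transformer

attention = transformer.MultiHeadAttention(
    4,  # num_heads
    20,  # num_units
    local_attention_radius=2,
    max_length_full_attention=3,
    global_attention_length=2,
)
x = tf.random.uniform([2, 9, 10])  # 9 > max_length_full_attention => sparse path
mask = tf.sequence_mask([9, 7])
outputs, _ = attention(x, mask=mask)
print(outputs.shape)  # (2, 9, 20); outputs[:, :2] are the global-token outputs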
82 changes: 69 additions & 13 deletions opennmt/tests/transformer_test.py
@@ -116,8 +116,23 @@ def testRelativePositions(self):
[[2, 3, 4, 4], [1, 2, 3, 4], [0, 1, 2, 3], [0, 0, 1, 2]],
)

@parameterized.expand([[2, True], [2, False], [3, True], [3, False]])
def testSplitChunks(self, chunk_length, concat_3_chunks):
@parameterized.expand(
[
[2, True],
[2, False],
[3, True],
[3, False],
[2, True, 1],
[2, False, 1],
[3, True, 1],
[3, False, 1],
[2, True, 2],
[2, False, 2],
[3, True, 2],
[3, False, 2],
]
)
def testSplitChunks(self, chunk_length, concat_3_chunks, global_length=0):
batch = 3
length = [5, 3, 7]
num_heads = 4
@@ -127,33 +142,62 @@ def testSplitChunks(self, chunk_length, concat_3_chunks):
[batch, num_heads, max(length), depth], dtype=tf.float32
)
split, num_chunks = transformer.split_chunks(
inputs, chunk_length=chunk_length, concat_3_chunks=concat_3_chunks
inputs,
chunk_length=chunk_length,
concat_3_chunks=concat_3_chunks,
global_length=global_length,
)
split_shape = split.shape
self.assertEqual(num_chunks, split_shape[0] / batch)
self.assertEqual(num_heads, split_shape[1])
chunk_length_eval = chunk_length * 3 if concat_3_chunks else chunk_length
chunk_length_eval += global_length
self.assertEqual(chunk_length_eval, split_shape[2])
self.assertEqual(depth, split_shape[3])

@parameterized.expand(
[[tf.bool, 2], [tf.float32, 2], [tf.bool, 3], [tf.float32, 3]]
[
[tf.bool, 2],
[tf.float32, 2],
[tf.bool, 3],
[tf.float32, 3],
[tf.bool, 2, 1],
[tf.float32, 2, 1],
[tf.bool, 3, 1],
[tf.float32, 3, 1],
[tf.bool, 2, 2],
[tf.float32, 2, 2],
[tf.bool, 3, 2],
[tf.float32, 3, 2],
]
)
def testChunkAttentionMask(self, dtype, chunk_length):
def testChunkAttentionMask(self, dtype, chunk_length, global_length=0):
length = [2, 4, 3]
batch = len(length)
maximum_length = 5
mask = tf.sequence_mask(lengths=length, maxlen=maximum_length, dtype=dtype)
mask_chunked = transformer.chunk_att_mask(mask, chunk_length=chunk_length)
output_shape = mask_chunked.shape
num_chunks = abs(-maximum_length // chunk_length)
self.assertEqual(num_chunks, output_shape[0] / batch)
self.assertEqual(chunk_length, output_shape[1])
self.assertEqual(chunk_length * 3, output_shape[2])
mask_chunked = transformer.chunk_att_mask(
mask, chunk_length=chunk_length, global_length=global_length
)
(
output_batch_times_chunks,
output_chunk_length,
output_expanded_chunk_length,
) = mask_chunked.shape
if global_length:
maximum_length = maximum_length - global_length
length = [el - global_length for el in length]
num_chunks = abs(-maximum_length // chunk_length)
self.assertEqual(num_chunks * batch, output_batch_times_chunks)
self.assertEqual(chunk_length, output_chunk_length)
self.assertEqual(chunk_length * 3 + global_length, output_expanded_chunk_length)

self.assertIs(mask_chunked.dtype, dtype)

expected = np.zeros(output_shape, dtype=dtype.as_numpy_dtype)
expected = np.zeros(
(output_batch_times_chunks, output_chunk_length, chunk_length * 3),
dtype=dtype.as_numpy_dtype,
)

token_radius = chunk_length * 2 + 1
for b in range(batch):
@@ -170,6 +214,14 @@ def testChunkAttentionMask(self, dtype, chunk_length):
expected[chunk_idx][ch_l][start_idx:end_idx] = 1

mask_chunked = self.evaluate(mask_chunked)
if global_length:
expanded_mask = np.repeat(mask, num_chunks, axis=0)
expanded_mask = np.repeat(
expanded_mask[:, np.newaxis, :], chunk_length, axis=1
)
expected = tf.concat(
(expected, expanded_mask[:, :, :global_length]), axis=2
)
self.assertAllEqual(mask_chunked, expected)

def testFeedForwardNetwork(self):
@@ -217,7 +269,11 @@ def testMultiHeadSelfAttentionRelativePositionsWithCache(self):

def testMultiHeadSelfAttentionSparse(self):
attention = transformer.MultiHeadAttention(
4, 20, local_attention_radius=2, max_length_full_attention=3
4,
20,
local_attention_radius=2,
max_length_full_attention=3,
global_attention_length=2,
)
x = tf.random.uniform([2, 9, 10])
mask = tf.sequence_mask([9, 7])