Skip to content

Commit b5e0ef2

Browse files
authored
Add support for StartEndPacker packing 2D tensor (#240)
* Added support for packing 2D tensor.
* minor edit to test
* minor fixes
* minor name changes
* input shape check fixes and docstring simplification
* minor changes
1 parent 0c99503 commit b5e0ef2

File tree

2 files changed

+21
-35
lines changed

2 files changed

+21
-35
lines changed

keras_nlp/layers/start_end_packer.py

Lines changed: 16 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -26,11 +26,8 @@ class StartEndPacker(keras.layers.Layer):
2626
be called after tokenization. The layer will first trim inputs to fit, then
2727
add start/end tokens, and finally pad, if necessary, to `sequence_length`.
2828
29-
If input is batched, input should be a `tf.RaggedTensor` with shape
30-
`[batch_size, None]` and will be packed and converted to a dense tensor with
31-
shape `[batch_size, sequence_length]`.
32-
If input is unbatched, input should be a dense rank-1 tensor of any shape,
33-
and will be packed to shape `[sequence_length]`.
29+
Input should be either a `tf.RaggedTensor` or a dense `tf.Tensor`, and
30+
either rank-1 or rank-2.
3431
3532
Args:
3633
sequence_length: int. The desired output length.
@@ -108,26 +105,19 @@ def call(self, inputs):
108105
if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
109106
inputs = tf.convert_to_tensor(inputs)
110107

111-
input_is_dense = isinstance(inputs, tf.Tensor)
112-
input_is_ragged = isinstance(inputs, tf.RaggedTensor)
113-
114-
if input_is_dense:
115-
if inputs.shape.rank != 1:
116-
raise ValueError(
117-
"Input must either be dense with rank 1, or ragged with "
118-
"rank 2. Received dense input with "
119-
f"rank={inputs.shape.rank}"
120-
)
121-
122-
# Add a new axis at the beginning and convert to ragged tensor.
123-
inputs = tf.RaggedTensor.from_tensor(tf.expand_dims(inputs, axis=0))
124-
elif input_is_ragged:
125-
if inputs.shape.rank != 2:
126-
raise ValueError(
127-
"Input must either be dense with rank 1, or ragged with "
128-
"rank 2. Received ragged input with "
129-
f"rank={inputs.shape.rank}"
130-
)
108+
input_is_1d = False
109+
if inputs.shape.rank < 1 or inputs.shape.rank > 2:
110+
raise ValueError(
111+
"Input must either be rank 1 or rank 2. Received input with "
112+
f"rank={inputs.shape.rank}"
113+
)
114+
elif inputs.shape.rank == 1:
115+
input_is_1d = True
116+
# Add a new axis at the beginning.
117+
inputs = tf.expand_dims(inputs, axis=0)
118+
if isinstance(inputs, tf.Tensor):
119+
# Convert to ragged tensor.
120+
inputs = tf.RaggedTensor.from_tensor(inputs)
131121

132122
batch_size = tf.shape(inputs)[0]
133123

@@ -148,7 +138,7 @@ def call(self, inputs):
148138
shape=(batch_size, self.sequence_length),
149139
)
150140

151-
if input_is_dense:
141+
if input_is_1d:
152142
inputs = tf.squeeze(inputs, axis=0)
153143

154144
return inputs

keras_nlp/layers/start_end_packer_test.py

Lines changed: 5 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -29,15 +29,12 @@ def test_dense_input(self):
2929
expected_output = [5, 6, 7, 0, 0]
3030
self.assertAllEqual(output, expected_output)
3131

32-
def test_dense_input_error(self):
32+
def test_dense_2D_input(self):
3333
input_data = tf.constant([[5, 6, 7]])
3434
start_end_packer = StartEndPacker(sequence_length=5)
35-
with self.assertRaisesRegex(
36-
ValueError,
37-
"Input must either be dense with rank 1, or ragged with rank 2. "
38-
"Received dense input with rank=2",
39-
):
40-
start_end_packer(input_data)
35+
output = start_end_packer(input_data)
36+
expected_output = [[5, 6, 7, 0, 0]]
37+
self.assertAllEqual(output, expected_output)
4138

4239
def test_ragged_input(self):
4340
input_data = tf.ragged.constant([[5, 6, 7], [8, 9, 10, 11]])
@@ -51,8 +48,7 @@ def test_ragged_input_error(self):
5148
start_end_packer = StartEndPacker(sequence_length=5)
5249
with self.assertRaisesRegex(
5350
ValueError,
54-
"Input must either be dense with rank 1, or ragged with rank 2. "
55-
"Received ragged input with "
51+
"Input must either be rank 1 or rank 2. Received input with "
5652
"rank=3",
5753
):
5854
start_end_packer(input_data)

0 commit comments

Comments
 (0)