Skip to content

Commit cc0209a

Browse files
authored
MultiSegmentPacker support for 2D dense tensor (#244)
* Support for 2D dense tensor
* fixes
* style fixes
1 parent bd92aca commit cc0209a

File tree

2 files changed

+14
-13
lines changed

2 files changed

+14
-13
lines changed

keras_nlp/layers/multi_segment_packer.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,8 @@ class MultiSegmentPacker(keras.layers.Layer):
4040
is always 0, and the segment id of each `end_value` is the segment that
4141
precedes it.
4242
43-
If inputs are batched, inputs should be `tf.RaggedTensor`s with shape
44-
`[batch_size, None]` and will be packed and converted to a dense tensor with
45-
shape `[batch_size, sequence_length]`.
46-
47-
If inputs are unbatched, inputs should be dense rank-1 tensors of any shape,
48-
and will be packed to shape `[sequence_length]`.
43+
Input should be either a `tf.RaggedTensor` or a dense `tf.Tensor`, and
44+
either rank-1 or rank-2.
4945
5046
Args:
5147
sequence_length: The desired output length.
@@ -155,6 +151,13 @@ def _sanitize_inputs(self, inputs):
155151
)
156152
return inputs
157153

154+
def _convert_dense(self, x):
155+
"""Converts inputs to rank 2 ragged tensors."""
156+
if isinstance(x, tf.Tensor):
157+
return tf.RaggedTensor.from_tensor(x)
158+
else:
159+
return x
160+
158161
def _trim_inputs(self, inputs):
159162
"""Trim inputs to desired length."""
160163
num_special_tokens = len(inputs) + 1
@@ -199,22 +202,20 @@ def _combine_inputs(self, segments):
199202
def call(self, inputs):
200203
inputs = self._sanitize_inputs(inputs)
201204

202-
# If rank 1, add a batch dim and convert to ragged.
205+
# If rank 1, add a batch dim.
203206
rank_1 = inputs[0].shape.rank == 1
204207
if rank_1:
205208
inputs = [tf.expand_dims(x, 0) for x in inputs]
206-
inputs = [tf.RaggedTensor.from_tensor(x) for x in inputs]
209+
inputs = [self._convert_dense(x) for x in inputs]
207210

208211
segments = self._trim_inputs(inputs)
209212
token_ids, segment_ids = self._combine_inputs(segments)
210-
211213
# Pad to dense tensor output.
212214
shape = tf.cast([-1, self.sequence_length], "int64")
213215
token_ids = token_ids.to_tensor(
214216
shape=shape, default_value=self.pad_value
215217
)
216218
segment_ids = segment_ids.to_tensor(shape=shape)
217-
218219
# Remove the batch dim if added.
219220
if rank_1:
220221
token_ids = tf.squeeze(token_ids, 0)

keras_nlp/layers/multi_segment_packer_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ def test_trim_multiple_inputs_waterfall(self):
6767
)
6868

6969
def test_trim_batched_inputs_round_robin(self):
70-
seq1 = tf.ragged.constant([["a", "b", "c"], ["a", "b", "c"]])
71-
seq2 = tf.ragged.constant([["x", "y", "z"], ["x", "y", "z"]])
70+
seq1 = tf.constant([["a", "b", "c"], ["a", "b", "c"]])
71+
seq2 = tf.constant([["x", "y", "z"], ["x", "y", "z"]])
7272
packer = MultiSegmentPacker(
7373
7, start_value="[CLS]", end_value="[SEP]", truncator="round_robin"
7474
)
@@ -89,7 +89,7 @@ def test_trim_batched_inputs_round_robin(self):
8989

9090
def test_trim_batched_inputs_waterfall(self):
9191
seq1 = tf.ragged.constant([["a", "b", "c"], ["a", "b"]])
92-
seq2 = tf.ragged.constant([["x", "y", "z"], ["x", "y", "z"]])
92+
seq2 = tf.constant([["x", "y", "z"], ["x", "y", "z"]])
9393
packer = MultiSegmentPacker(
9494
7, start_value="[CLS]", end_value="[SEP]", truncator="waterfall"
9595
)

0 commit comments

Comments
 (0)