@@ -40,12 +40,8 @@ class MultiSegmentPacker(keras.layers.Layer):
4040 is always 0, and the segment id of each `end_value` is the segment that
4141 precedes it.
4242
43- If inputs are batched, inputs should be `tf.RaggedTensor`s with shape
44- `[batch_size, None]` and will be packed and converted to a dense tensor with
45- shape `[batch_size, sequence_length]`.
46-
47- If inputs are unbatched, inputs should be dense rank-1 tensors of any shape,
48- and will be packed to shape `[sequence_length]`.
43+ Each input should be either a `tf.RaggedTensor` or a dense `tf.Tensor`,
44+ of rank 1 or rank 2.
4945
5046 Args:
5147 sequence_length: The desired output length.
@@ -155,6 +151,13 @@ def _sanitize_inputs(self, inputs):
155151 )
156152 return inputs
157153
def _convert_dense(self, x):
    """Wrap a dense tensor as a ragged tensor; pass other inputs through.

    Args:
        x: A single input. If it is a `tf.Tensor` it is converted with
            `tf.RaggedTensor.from_tensor`; any other value (presumably
            already a `tf.RaggedTensor` -- confirm against callers) is
            returned unchanged.

    Returns:
        The ragged form of `x` when `x` is a dense `tf.Tensor`, otherwise
        `x` itself.
    """
    # Guard clause: only dense tensors need conversion.
    if not isinstance(x, tf.Tensor):
        return x
    return tf.RaggedTensor.from_tensor(x)
160+
158161 def _trim_inputs (self , inputs ):
159162 """Trim inputs to desired length."""
160163 num_special_tokens = len (inputs ) + 1
@@ -199,22 +202,20 @@ def _combine_inputs(self, segments):
199202 def call (self , inputs ):
200203 inputs = self ._sanitize_inputs (inputs )
201204
202- # If rank 1, add a batch dim and convert to ragged .
205+ # If rank 1, add a batch dim.
203206 rank_1 = inputs [0 ].shape .rank == 1
204207 if rank_1 :
205208 inputs = [tf .expand_dims (x , 0 ) for x in inputs ]
206- inputs = [tf . RaggedTensor . from_tensor (x ) for x in inputs ]
209+ inputs = [self . _convert_dense (x ) for x in inputs ]
207210
208211 segments = self ._trim_inputs (inputs )
209212 token_ids , segment_ids = self ._combine_inputs (segments )
210-
211213 # Pad to dense tensor output.
212214 shape = tf .cast ([- 1 , self .sequence_length ], "int64" )
213215 token_ids = token_ids .to_tensor (
214216 shape = shape , default_value = self .pad_value
215217 )
216218 segment_ids = segment_ids .to_tensor (shape = shape )
217-
218219 # Remove the batch dim if added.
219220 if rank_1 :
220221 token_ids = tf .squeeze (token_ids , 0 )
0 commit comments