@@ -26,11 +26,8 @@ class StartEndPacker(keras.layers.Layer):
2626 be called after tokenization. The layer will first trim inputs to fit, then
2727 add start/end tokens, and finally pad, if necessary, to `sequence_length`.
2828
29- If input is batched, input should be a `tf.RaggedTensor` with shape
30- `[batch_size, None]` and will be packed and converted to a dense tensor with
31- shape `[batch_size, sequence_length]`.
32- If input is unbatched, input should be a dense rank-1 tensor of any shape,
33- and will be packed to shape `[sequence_length]`.
29+ Input should be either a `tf.RaggedTensor` or a dense `tf.Tensor`, and
30+ either rank-1 or rank-2.
3431
3532 Args:
3633 sequence_length: int. The desired output length.
@@ -108,26 +105,19 @@ def call(self, inputs):
108105 if not isinstance (inputs , (tf .Tensor , tf .RaggedTensor )):
109106 inputs = tf .convert_to_tensor (inputs )
110107
111- input_is_dense = isinstance (inputs , tf .Tensor )
112- input_is_ragged = isinstance (inputs , tf .RaggedTensor )
113-
114- if input_is_dense :
115- if inputs .shape .rank != 1 :
116- raise ValueError (
117- "Input must either be dense with rank 1, or ragged with "
118- "rank 2. Received dense input with "
119- f"rank={ inputs .shape .rank } "
120- )
121-
122- # Add a new axis at the beginning and convert to ragged tensor.
123- inputs = tf .RaggedTensor .from_tensor (tf .expand_dims (inputs , axis = 0 ))
124- elif input_is_ragged :
125- if inputs .shape .rank != 2 :
126- raise ValueError (
127- "Input must either be dense with rank 1, or ragged with "
128- "rank 2. Received ragged input with "
129- f"rank={ inputs .shape .rank } "
130- )
108+ input_is_1d = False
109+ if inputs .shape .rank < 1 or inputs .shape .rank > 2 :
110+ raise ValueError (
111+ "Input must either be rank 1 or rank 2. Received input with "
112+ f"rank={ inputs .shape .rank } "
113+ )
114+ elif inputs .shape .rank == 1 :
115+ input_is_1d = True
116+ # Add a new axis at the beginning.
117+ inputs = tf .expand_dims (inputs , axis = 0 )
118+ if isinstance (inputs , tf .Tensor ):
119+ # Convert to ragged tensor.
120+ inputs = tf .RaggedTensor .from_tensor (inputs )
131121
132122 batch_size = tf .shape (inputs )[0 ]
133123
@@ -148,7 +138,7 @@ def call(self, inputs):
148138 shape = (batch_size , self .sequence_length ),
149139 )
150140
151- if input_is_dense :
141+ if input_is_1d :
152142 inputs = tf .squeeze (inputs , axis = 0 )
153143
154144 return inputs
0 commit comments