@@ -109,12 +109,12 @@ class BartSeq2SeqLM(GenerativeTask):
109109 # "The quick brown fox", and the decoder inputs to "The fast". Use
110110 # `"padding_mask"` to indicate values that should not be overridden.
111111 prompt = {
112- "encoder_token_ids": tf.constant ([[0, 133, 2119, 6219, 23602, 2, 1, 1]]),
113- "encoder_padding_mask": tf.constant (
112+ "encoder_token_ids": np.array ([[0, 133, 2119, 6219, 23602, 2, 1, 1]]),
113+ "encoder_padding_mask": np.array (
114114 [[True, True, True, True, True, True, False, False]]
115115 ),
116- "decoder_token_ids": tf.constant ([[2, 0, 133, 1769, 2, 1, 1]]),
117- "decoder_padding_mask": tf.constant ([[True, True, True, True, False, False]])
116+ "decoder_token_ids": np.array ([[2, 0, 133, 1769, 2, 1, 1]]),
117+ "decoder_padding_mask": np.array ([[True, True, True, True, True, False, False]])
118118 }
119119
120120 bart_lm = keras_nlp.models.BartSeq2SeqLM.from_preset(
@@ -137,13 +137,13 @@ class BartSeq2SeqLM(GenerativeTask):
137137 Call `fit()` without preprocessing.
138138 ```python
139139 x = {
140- "encoder_token_ids": tf.constant ([[0, 133, 2119, 2, 1]] * 2),
141- "encoder_padding_mask": tf.constant ([[1, 1, 1, 1, 0]] * 2),
142- "decoder_token_ids": tf.constant ([[2, 0, 133, 1769, 2]] * 2),
143- "decoder_padding_mask": tf.constant ([[1, 1, 1, 1, 1]] * 2),
140+ "encoder_token_ids": np.array ([[0, 133, 2119, 2, 1]] * 2),
141+ "encoder_padding_mask": np.array ([[1, 1, 1, 1, 0]] * 2),
142+ "decoder_token_ids": np.array ([[2, 0, 133, 1769, 2]] * 2),
143+ "decoder_padding_mask": np.array ([[1, 1, 1, 1, 1]] * 2),
144144 }
145- y = tf.constant ([[0, 133, 1769, 2, 1]] * 2)
146- sw = tf.constant ([[1, 1, 1, 1, 0]] * 2)
145+ y = np.array ([[0, 133, 1769, 2, 1]] * 2)
146+ sw = np.array ([[1, 1, 1, 1, 0]] * 2)
147147
148148 bart_lm = keras_nlp.models.BartSeq2SeqLM.from_preset(
149149 "bart_base_en",
0 commit comments