19 changes: 19 additions & 0 deletions bigscience/gins/nc_dec_c4_prefix_lm_with_bot_before_target.gin
@@ -0,0 +1,19 @@
from __gin__ import dynamic_registration

from t5x import models
import seqio

include "bigscience/gins/nc_dec_xxl.gin"
include "t5x/configs/runs/pretrain.gin"
include "bigscience/gins/pretrainer_base.gin"

TASK_FEATURE_LENGTHS = {
    "decoder_target_tokens": 626,
    "decoder_input_tokens": 626,
    "decoder_segment_ids": 626,
    "decoder_causal_attention": 626,
    "targets": 625,  # take into account the extra BOT token between input and target
}
MIXTURE_OR_TASK_NAME = "c4_prefix_lm_objective_decoder_architecture_with_bot_seperator"

models.DecoderOnlyModel.feature_converter_cls = @seqio.PassThroughFeatureConverter
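
The pass-through converter is appropriate here because the task's final preprocessor already emits model-ready `decoder_*` features. A rough sketch of the behaviour this relies on (illustrative only, not the actual seqio.PassThroughFeatureConverter implementation):

class PassThroughSketch:
    """Forwards task features to the model unchanged: no packing, no renaming."""

    def __call__(self, ds, task_feature_lengths):
        # pack_prefix_lm_decoder_only already produced decoder_target_tokens,
        # decoder_input_tokens, decoder_loss_weights and decoder_causal_attention,
        # so there is nothing left to convert.
        return ds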
105 changes: 105 additions & 0 deletions bigscience/gins/task.py
@@ -1,6 +1,9 @@
import dataclasses
import functools

import seqio
import t5
import tensorflow as tf
from t5.data import preprocessors, get_default_vocabulary
from t5.data.preprocessors import select_random_chunk, reduce_concat_tokens, split_tokens

@@ -43,6 +46,108 @@ def full_lm(dataset, sequence_length, output_features):
},
metric_fns=[])

# We want an additional separator token between input and target.
# Inspired by https://github.com/google-research/text-to-text-transfer-transformer/blob/9844ddb4f760ae8a1d4de410578f6211e487bbf9/t5/data/tasks.py#L445

assert get_default_vocabulary().vocab_size == 32100, "Use the T5 tokenizer by default"
BOT_ID = 32000  # FIXME: this is only true for the T5 tokenizer right now.

@dataclasses.dataclass(frozen=True)
class FancyFeature(seqio.Feature):
    # When set, a separator token is inserted between input and target.
    # `bot` stands for "beginning of target".
    add_bot: bool = False
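
# Hypothetical sanity check (illustrative, not part of the task definition):
# the default T5 vocabulary has 32000 sentencepieces plus 100 extra_ids, so
# the extra_id block occupies ids [32000, 32100). BOT_ID reuses its first id,
# conventionally named <extra_id_99>.
assert BOT_ID == get_default_vocabulary().vocab_size - 100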

def pack_prefix_lm_decoder_only(ds,
                                sequence_length,
                                output_features,
                                loss_on_targets_only=True,
                                pad_id=0):
    """Randomly split the tokens for the prefix LM objective."""
    packed_length = sequence_length["decoder_input_tokens"]
    assert packed_length % 2 == 0
    # "targets" is a special key
    add_bot = output_features["decoder_input_tokens"].add_bot

    assert all(l == packed_length for key, l in sequence_length.items() if (not add_bot) or key != "targets")
    assert all(l.add_bot == add_bot for key, l in output_features.items() if key != "targets")
    if add_bot:
        assert sequence_length["targets"] == packed_length - 1
    else:
        assert sequence_length["targets"] == packed_length

    @seqio.utils.map_over_dataset(num_seeds=1)
    def pack_examples(example, seed):
        split_point = tf.random.stateless_uniform(
            (),
            minval=1,
            # With BOT the raw targets are one token shorter, which shifts
            # the valid split range by one.
            maxval=packed_length if add_bot else packed_length - 1,
            seed=seed,
            dtype=tf.int32)
        if add_bot:
            decoder_target_tokens = tf.concat(
                [
                    example['targets'][:split_point - 1],
                    # BOT reuses the id of `<extra_id_99>`. Not ideal, but the
                    # tokenizer has no `bos` token right now.
                    [BOT_ID],
                    example['targets'][split_point - 1:],
                ],
                axis=0
            )
            # The shape has to be set explicitly, otherwise the dataset's
            # tensor spec infers None.
            decoder_target_tokens = tf.reshape(decoder_target_tokens, (packed_length,))
        else:
            decoder_target_tokens = example['targets']

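        # Inputs are the targets shifted right by one: position i predicts targets[i].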
        decoder_input_tokens = seqio.utils.make_autoregressive_inputs(decoder_target_tokens)

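        # With loss_on_targets_only, only positions at or beyond the split
        # point (the target half) contribute to the loss.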
        if loss_on_targets_only:
            decoder_loss_weights = tf.cast(
                tf.range(packed_length) >= split_point, tf.int32)
        else:
            decoder_loss_weights = tf.ones((packed_length,), dtype=tf.int32)

        padding_mask = tf.cast(
            tf.not_equal(decoder_target_tokens, pad_id), dtype=tf.int32)
        decoder_loss_weights *= padding_mask

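        # Positions up to and including the split point form the input prefix
        # and receive bidirectional (non-causal) attention.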
        decoder_causal_attention = tf.cast(
            tf.range(packed_length) <= split_point, tf.int32)

        return {
            'decoder_target_tokens': decoder_target_tokens,
            'decoder_input_tokens': decoder_input_tokens,
            'decoder_loss_weights': decoder_loss_weights,
            'decoder_causal_attention': decoder_causal_attention,
        }

    return pack_examples(ds)
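
# A minimal smoke test for pack_prefix_lm_decoder_only, e.g. for a notebook or
# unit test. The toy lengths below are hypothetical; the real task uses 626/625.
def _smoke_test_pack_prefix_lm_decoder_only():
    toy_len = 8
    features = {
        key: FancyFeature(
            vocabulary=get_default_vocabulary(), add_eos=False, add_bot=True)
        for key in ("decoder_target_tokens", "decoder_input_tokens",
                    "decoder_loss_weights", "decoder_causal_attention")
    }
    lengths = {
        "decoder_target_tokens": toy_len,
        "decoder_input_tokens": toy_len,
        "decoder_causal_attention": toy_len,
        "targets": toy_len - 1,  # one position is reserved for the BOT token
    }
    ds = tf.data.Dataset.from_tensors(
        {"targets": tf.range(1, toy_len, dtype=tf.int32)})
    packed = pack_prefix_lm_decoder_only(ds, lengths, features)
    example = next(iter(packed.as_numpy_iterator()))
    assert example["decoder_target_tokens"].shape == (toy_len,)
    assert BOT_ID in example["decoder_target_tokens"]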

TaskRegistry.add(
    "c4_prefix_lm_objective_decoder_architecture_with_bot_seperator",
    source=seqio.TfdsDataSource(tfds_name="c4/en:2.2.0"),
    preprocessors=[
        functools.partial(
            preprocessors.rekey, key_map={
                "inputs": None,
                "targets": "text"
            }),
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        t5.data.preprocessors.targets_for_prefix_lm_objective,
        pack_prefix_lm_decoder_only,
    ],
    output_features={
        "decoder_target_tokens": FancyFeature(
            vocabulary=get_default_vocabulary(), add_eos=False, add_bot=True),
        "decoder_input_tokens": FancyFeature(
            vocabulary=get_default_vocabulary(), add_eos=False, add_bot=True),
        "decoder_loss_weights": FancyFeature(
            vocabulary=get_default_vocabulary(), add_eos=False, add_bot=True),
        "decoder_causal_attention": FancyFeature(
            vocabulary=get_default_vocabulary(), add_eos=False, add_bot=True),
        # All but the last stage of the preprocessing uses "targets" as the key,
        # so this output feature is necessary. It is not marked required because
        # the final preprocessor drops it.
        "targets": seqio.Feature(vocabulary=get_default_vocabulary(), required=False),
    },
    metric_fns=[])
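
# Once registered, the task can also be exercised directly through seqio; a
# quick sketch (the lengths mirror the gin config above, and this will read
# the C4 TFDS data):
def _load_one_example():
    task = seqio.get_mixture_or_task(
        "c4_prefix_lm_objective_decoder_architecture_with_bot_seperator")
    ds = task.get_dataset(
        sequence_length={
            "decoder_target_tokens": 626,
            "decoder_input_tokens": 626,
            "decoder_segment_ids": 626,
            "decoder_causal_attention": 626,
            "targets": 625,
        },
        split="train",
        shuffle=False)
    return next(iter(ds))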

# --- Improve sharding ---

# def fully_sharded_logical_axis_rules() -> LogicalAxisRules:
1 change: 1 addition & 0 deletions bigscience/scripts/setup_vm.sh
@@ -35,6 +35,7 @@ popd
#rm -rf t5x
git clone https://github.com/bigscience-workshop/t5x.git
pushd t5x
git checkout thomas/prefix_lm_add_token
pip3 install -e .
popd

23 changes: 17 additions & 6 deletions bigscience/scripts/test_seqio_dataset.py
@@ -1,13 +1,21 @@
from t5x import models
import seqio

from t5x import utils
import tensorflow as tf
from ..gins import task

def main():
    ds = utils.get_dataset(
        utils.DatasetConfig(
            "c4_v220_full_lm",
            task_feature_lengths={"targets": 626},
            "c4_prefix_lm_objective_decoder_architecture_with_bot_seperator",

            task_feature_lengths={
                "decoder_target_tokens": 626,
                "decoder_input_tokens": 626,
                "decoder_segment_ids": 626,
                "decoder_causal_attention": 626,
                "targets": 625,  # take into account the extra BOT token between input and target
            },
            split="train",
            batch_size=2048,
            shuffle=False,
@@ -19,7 +27,7 @@ def main():
        ),
        0,
        1,
        models.DecoderOnlyModel.FEATURE_CONVERTER_CLS
        seqio.PassThroughFeatureConverter,
    )
    first_element = next(iter(ds))
    print(first_element)
@@ -34,8 +42,11 @@ def main():
    print(tf.shape(first_element["decoder_target_tokens"]))
    print(tf.shape(first_element["decoder_input_tokens"]))
    print(tf.shape(first_element["decoder_loss_weights"]))
    print(tf.shape(first_element["decoder_segment_ids"]))
    print(tf.shape(first_element["decoder_positions"]))
    # print(tf.shape(first_element["decoder_segment_ids"]))
    # print(tf.shape(first_element["decoder_positions"]))
    print(tf.where(first_element["decoder_target_tokens"] == 32000))
    print(tf.where(first_element["decoder_input_tokens"] == 32000))
    print(ds.element_spec)

if __name__ == "__main__":
    main()