diff --git a/bigscience/gins/nc_dec_c4_prefix_lm_with_bot_before_target.gin b/bigscience/gins/nc_dec_c4_prefix_lm_with_bot_before_target.gin
new file mode 100644
index 000000000..fe8e334b8
--- /dev/null
+++ b/bigscience/gins/nc_dec_c4_prefix_lm_with_bot_before_target.gin
@@ -0,0 +1,19 @@
+from __gin__ import dynamic_registration
+
+from t5x import models
+import seqio
+
+include "bigscience/gins/nc_dec_xxl.gin"
+include "t5x/configs/runs/pretrain.gin"
+include "bigscience/gins/pretrainer_base.gin"
+
+TASK_FEATURE_LENGTHS = {
+    "decoder_target_tokens": 626,
+    "decoder_input_tokens": 626,
+    "decoder_segment_ids": 626,
+    "decoder_causal_attention": 626,
+    "targets": 625  # we have to take into account the extra token between input and target
+}
+MIXTURE_OR_TASK_NAME = "c4_prefix_lm_objective_decoder_architecture_with_bot_seperator"
+
+models.DecoderOnlyModel.feature_converter_cls = @seqio.PassThroughFeatureConverter
\ No newline at end of file
diff --git a/bigscience/gins/task.py b/bigscience/gins/task.py
index 7db13c4ba..e585f271c 100644
--- a/bigscience/gins/task.py
+++ b/bigscience/gins/task.py
@@ -1,6 +1,9 @@
+import dataclasses
 import functools
 
 import seqio
+import t5
+import tensorflow as tf
 from t5.data import preprocessors, get_default_vocabulary
 from t5.data.preprocessors import select_random_chunk, reduce_concat_tokens, split_tokens
@@ -43,6 +46,108 @@ def full_lm(dataset, sequence_length, output_features):
     },
     metric_fns=[])
 
+# We want input and target to have an additional token between them.
+# Inspired by https://github.com/google-research/text-to-text-transfer-transformer/blob/9844ddb4f760ae8a1d4de410578f6211e487bbf9/t5/data/tasks.py#L445
+
+assert get_default_vocabulary().vocab_size == 32100, "Use T5 tokenizer by default"
+BOT_ID = 32000  # FIXME: this is only true for the T5 tokenizer right now.
+@dataclasses.dataclass(frozen=True)
+class FancyFeature(seqio.Feature):
+  # This token is used to separate input and target. `bot` stands for "beginning of target".
+  add_bot: bool = False
+
+def pack_prefix_lm_decoder_only(ds,
+                                sequence_length,
+                                output_features,
+                                loss_on_targets_only=True,
+                                pad_id=0):
+  """Randomly split the tokens for the prefix LM objective."""
+  packed_length = sequence_length["decoder_input_tokens"]
+  assert packed_length % 2 == 0
+  # "targets" is a special key
+  add_bot = output_features["decoder_input_tokens"].add_bot
+
+  assert all(l == packed_length for key, l in sequence_length.items() if (not add_bot) or key != "targets")
+  assert all(l.add_bot == add_bot for key, l in output_features.items() if key != "targets")
+  if add_bot:
+    assert sequence_length["targets"] == packed_length - 1
+  else:
+    assert sequence_length["targets"] == packed_length
+
+  @seqio.utils.map_over_dataset(num_seeds=1)
+  def pack_examples(example, seed):
+    split_point = tf.random.stateless_uniform((),
+                                              minval=1,
+                                              # Adding an extra token costs a bit.
+                                              maxval=packed_length if add_bot else packed_length - 1,
+                                              seed=seed,
+                                              dtype=tf.int32)
+    if add_bot:
+      decoder_target_tokens = tf.concat(
+          [
+              example['targets'][:split_point - 1],
+              # bot will be the same as _. Not ideal, but the tokenizer doesn't have `bos` right now.
+              [BOT_ID],
+              example['targets'][split_point - 1:],
+          ],
+          axis=0
+      )
+      # This has to be specified, otherwise the dataset tensor spec assigns None to the shape.
+      decoder_target_tokens = tf.reshape(decoder_target_tokens, (packed_length,))
+    else:
+      decoder_target_tokens = example['targets']
+
+    decoder_input_tokens = seqio.utils.make_autoregressive_inputs(decoder_target_tokens)
+
+    if loss_on_targets_only:
+      decoder_loss_weights = tf.cast(
+          tf.range(packed_length) >= split_point, tf.int32)
+    else:
+      decoder_loss_weights = tf.ones((packed_length,), dtype=tf.int32)
+
+    padding_mask = tf.cast(
+        tf.not_equal(decoder_target_tokens, pad_id), dtype=tf.int32)
+    decoder_loss_weights *= padding_mask
+
+    decoder_causal_attention = tf.cast(
+        tf.range(packed_length) <= split_point, tf.int32)
+
+    return {
+        'decoder_target_tokens': decoder_target_tokens,
+        'decoder_input_tokens': decoder_input_tokens,
+        'decoder_loss_weights': decoder_loss_weights,
+        'decoder_causal_attention': decoder_causal_attention,
+    }
+
+  return pack_examples(ds)
+
+TaskRegistry.add(
+    "c4_prefix_lm_objective_decoder_architecture_with_bot_seperator",
+    source=seqio.TfdsDataSource(tfds_name="c4/en:2.2.0"),
+    preprocessors=[
+        functools.partial(
+            preprocessors.rekey, key_map={
+                "inputs": None,
+                "targets": "text"
+            }),
+        seqio.preprocessors.tokenize,
+        seqio.CacheDatasetPlaceholder(),
+        t5.data.preprocessors.targets_for_prefix_lm_objective,
+        pack_prefix_lm_decoder_only,
+    ],
+    output_features={
+        "decoder_target_tokens": FancyFeature(vocabulary=get_default_vocabulary(), add_eos=False, add_bot=True),
+        "decoder_input_tokens": FancyFeature(vocabulary=get_default_vocabulary(), add_eos=False, add_bot=True),
+        "decoder_loss_weights": FancyFeature(vocabulary=get_default_vocabulary(), add_eos=False, add_bot=True),
+        "decoder_causal_attention": FancyFeature(
+            vocabulary=get_default_vocabulary(), add_eos=False, add_bot=True),
+        # All but the last stage of the preprocessing uses "targets" as the key,
+        # so this output feature is necessary. It is not marked required because
+        # the final preprocessor drops it.
+        "targets": seqio.Feature(vocabulary=get_default_vocabulary(), required=False),
+    },
+    metric_fns=[])
+
 
 # --- Improve sharding --- #
 def fully_sharded_logical_axis_rules() -> LogicalAxisRules:
diff --git a/bigscience/scripts/setup_vm.sh b/bigscience/scripts/setup_vm.sh
index 2b993b1d6..8b56d9a82 100644
--- a/bigscience/scripts/setup_vm.sh
+++ b/bigscience/scripts/setup_vm.sh
@@ -35,6 +35,7 @@ popd
 #rm -rf t5x
 git clone https://github.com/bigscience-workshop/t5x.git
 pushd t5x
+git checkout thomas/prefix_lm_add_token
 pip3 install -e .
 popd
diff --git a/bigscience/scripts/test_seqio_dataset.py b/bigscience/scripts/test_seqio_dataset.py
index 6e66b6f69..87a3ba554 100644
--- a/bigscience/scripts/test_seqio_dataset.py
+++ b/bigscience/scripts/test_seqio_dataset.py
@@ -1,4 +1,5 @@
-from t5x import models
+import seqio
+
 from t5x import utils
 import tensorflow as tf
 from ..gins import task
@@ -6,8 +7,15 @@ def main():
     ds = utils.get_dataset(
         utils.DatasetConfig(
-            "c4_v220_full_lm",
-            task_feature_lengths={"targets": 626},
+            "c4_prefix_lm_objective_decoder_architecture_with_bot_seperator",
+
+            task_feature_lengths={
+                "decoder_target_tokens": 626,
+                "decoder_input_tokens": 626,
+                "decoder_segment_ids": 626,
+                "decoder_causal_attention": 626,
+                "targets": 625  # we have to take into account the extra token between input and target
+            },
             split="train",
             batch_size=2048,
             shuffle=False,
@@ -19,7 +27,7 @@ def main():
         ),
         0,
         1,
-        models.DecoderOnlyModel.FEATURE_CONVERTER_CLS
+        seqio.PassThroughFeatureConverter,
     )
     first_element = next(iter(ds))
     print(first_element)
@@ -34,8 +42,11 @@ def main():
     print(tf.shape(first_element["decoder_target_tokens"]))
     print(tf.shape(first_element["decoder_input_tokens"]))
     print(tf.shape(first_element["decoder_loss_weights"]))
-    print(tf.shape(first_element["decoder_segment_ids"]))
-    print(tf.shape(first_element["decoder_positions"]))
+    # print(tf.shape(first_element["decoder_segment_ids"]))
+    # print(tf.shape(first_element["decoder_positions"]))
+    print(tf.where(first_element["decoder_target_tokens"] == 32000))
+    print(tf.where(first_element["decoder_input_tokens"] == 32000))
+    print(ds.element_spec)
 
 if __name__ == "__main__":
     main()
\ No newline at end of file