diff --git a/bigscience/gins/enc_dec_c4_span_corruption_normalised_loss.gin b/bigscience/gins/enc_dec_c4_span_corruption_normalised_loss.gin
new file mode 100644
index 000000000..cc3d5e914
--- /dev/null
+++ b/bigscience/gins/enc_dec_c4_span_corruption_normalised_loss.gin
@@ -0,0 +1,152 @@
+from __gin__ import dynamic_registration
+import __main__ as train_script
+import seqio
+from t5x import adafactor
+from t5x.examples.t5 import network
+from t5x import gin_utils
+from t5x import models
+from t5x import partitioning
+from t5x import trainer
+from t5x import utils
+import task
+
+# Macros:
+# ==============================================================================
+DROPOUT_RATE = 0.0
+LABEL_SMOOTHING = 0.0
+LOSS_NORMALIZING_FACTOR = 233472  # 2048 * 114
+MIXTURE_OR_TASK_MODULE = 't5.data.mixtures'
+MIXTURE_OR_TASK_NAME = 'c4_v220_span_corruption'
+MODEL = @models.EncoderDecoderModel()
+MODEL_DIR = 'gs://bigscience-t5x/enc_dec_c4_span_corruption_normalised_loss'
+OPTIMIZER = @adafactor.Adafactor()
+RANDOM_SEED = None
+SHUFFLE_TRAIN_EXAMPLES = True
+TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 114}
+TRAIN_STEPS = 524288
+USE_CACHED_TASKS = True
+USE_HARDWARE_RNG = False
+VOCABULARY = @seqio.SentencePieceVocabulary()
+Z_LOSS = 0.0001
+
+# Parameters for adafactor.Adafactor:
+# ==============================================================================
+adafactor.Adafactor.decay_rate = 0.8
+adafactor.Adafactor.logical_factor_rules = \
+    @adafactor.standard_logical_factor_rules()
+adafactor.Adafactor.step_offset = 0
+
+# Parameters for utils.CheckpointConfig:
+# ==============================================================================
+utils.CheckpointConfig.restore = @utils.RestoreCheckpointConfig()
+utils.CheckpointConfig.save = @utils.SaveCheckpointConfig()
+
+# Parameters for utils.create_learning_rate_scheduler:
+# ==============================================================================
+utils.create_learning_rate_scheduler.base_learning_rate = 1.0
+utils.create_learning_rate_scheduler.factors = 'constant * rsqrt_decay'
+utils.create_learning_rate_scheduler.warmup_steps = 10000
+
+# Parameters for train/utils.DatasetConfig:
+# ==============================================================================
+train/utils.DatasetConfig.batch_size = 2048
+train/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
+train/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
+train/utils.DatasetConfig.pack = True
+train/utils.DatasetConfig.seed = None
+train/utils.DatasetConfig.shuffle = %SHUFFLE_TRAIN_EXAMPLES
+train/utils.DatasetConfig.split = 'train'
+train/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
+train/utils.DatasetConfig.use_cached = True
+train/utils.DatasetConfig.use_custom_packing_ops = False
+
+# Parameters for train_eval/utils.DatasetConfig:
+# ==============================================================================
+train_eval/utils.DatasetConfig.batch_size = 2048
+train_eval/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
+train_eval/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
+train_eval/utils.DatasetConfig.pack = True
+train_eval/utils.DatasetConfig.seed = 42
+train_eval/utils.DatasetConfig.shuffle = False
+train_eval/utils.DatasetConfig.split = 'validation'
+train_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
+train_eval/utils.DatasetConfig.use_cached = True
+train_eval/utils.DatasetConfig.use_custom_packing_ops = False
+
+# Parameters for models.EncoderDecoderModel:
+# ==============================================================================
+models.EncoderDecoderModel.input_vocabulary = %VOCABULARY
+models.EncoderDecoderModel.module = @network.Transformer()
+models.EncoderDecoderModel.optimizer_def = %OPTIMIZER
+models.EncoderDecoderModel.output_vocabulary = %VOCABULARY
+
+# Parameters for models.EncoderDecoderModel.loss_fn:
+# ==============================================================================
+models.EncoderDecoderModel.loss_fn.label_smoothing = %LABEL_SMOOTHING
+models.EncoderDecoderModel.loss_fn.loss_normalizing_factor = \
+    %LOSS_NORMALIZING_FACTOR
+models.EncoderDecoderModel.loss_fn.z_loss = %Z_LOSS
+
+# Parameters for partitioning.ModelBasedPjitPartitioner:
+# ==============================================================================
+partitioning.ModelBasedPjitPartitioner.logical_axis_rules = \
+    @task.fully_sharded_logical_axis_rules()
+partitioning.ModelBasedPjitPartitioner.model_parallel_submesh = (4, 1, 1, 1)
+partitioning.ModelBasedPjitPartitioner.num_partitions = 1
+
+# Parameters for utils.RestoreCheckpointConfig:
+# ==============================================================================
+utils.RestoreCheckpointConfig.path = []
+
+# Parameters for utils.SaveCheckpointConfig:
+# ==============================================================================
+utils.SaveCheckpointConfig.dtype = 'float32'
+utils.SaveCheckpointConfig.keep = None
+utils.SaveCheckpointConfig.period = 2000
+utils.SaveCheckpointConfig.save_dataset = False
+
+# Parameters for seqio.SentencePieceVocabulary:
+# ==============================================================================
+seqio.SentencePieceVocabulary.sentencepiece_model_file = \
+    'gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model'
+
+# Parameters for network.T5Config:
+# ==============================================================================
+network.T5Config.dropout_rate = %DROPOUT_RATE
+network.T5Config.dtype = 'bfloat16'
+network.T5Config.emb_dim = 4096
+network.T5Config.head_dim = 64
+network.T5Config.logits_via_embedding = True
+network.T5Config.mlp_activations = ('gelu', 'linear')
+network.T5Config.mlp_dim = 10240
+network.T5Config.num_decoder_layers = 24
+network.T5Config.num_encoder_layers = 24
+network.T5Config.num_heads = 64
+network.T5Config.vocab_size = 32128
+
+# Parameters for train_script.train:
+# ==============================================================================
+train_script.train.checkpoint_cfg = @utils.CheckpointConfig()
+train_script.train.eval_period = 1000
+train_script.train.eval_steps = 100
+train_script.train.infer_eval_dataset_cfg = None
+train_script.train.model = %MODEL
+train_script.train.model_dir = %MODEL_DIR
+train_script.train.partitioner = @partitioning.ModelBasedPjitPartitioner()
+train_script.train.random_seed = None
+train_script.train.stats_period = 200
+train_script.train.summarize_config_fn = @gin_utils.summarize_gin_config
+train_script.train.total_steps = %TRAIN_STEPS
+train_script.train.train_dataset_cfg = @train/utils.DatasetConfig()
+train_script.train.train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
+train_script.train.trainer_cls = @trainer.Trainer
+train_script.train.use_hardware_rng = %USE_HARDWARE_RNG
+
+# Parameters for trainer.Trainer:
+# ==============================================================================
+trainer.Trainer.learning_rate_fn = @utils.create_learning_rate_scheduler()
+trainer.Trainer.num_microbatches = None
+
+# Parameters for network.Transformer:
+# ==============================================================================
+network.Transformer.config = @network.T5Config()
diff --git a/bigscience/scripts/setup_vm.sh b/bigscience/scripts/setup_vm.sh
index 0032724b6..70e033453 100644
--- a/bigscience/scripts/setup_vm.sh
+++ b/bigscience/scripts/setup_vm.sh
@@ -29,6 +29,7 @@ popd
 
 #rm -rf t5x
 git clone https://github.com/bigscience-workshop/t5x.git
 pushd t5x
+git checkout thomas/test_loss_normalisation
 pip3 install -e .
 popd