From ed0cb6ef848eb8e350de4979e8aea36e7413fd62 Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Wed, 22 Dec 2021 15:47:46 +0100 Subject: [PATCH 1/4] Add HuggingFaceDatasetsSource --- bigscience/gins/task.py | 100 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/bigscience/gins/task.py b/bigscience/gins/task.py index 1870f9e24..80049cb57 100644 --- a/bigscience/gins/task.py +++ b/bigscience/gins/task.py @@ -1,8 +1,13 @@ import functools +from typing import Optional, Sequence +import datasets import seqio +import t5 +from seqio.dataset_providers import _validate_args, ShardInfo from t5.data import preprocessors, get_default_vocabulary from t5.data.preprocessors import select_random_chunk, reduce_concat_tokens, split_tokens +import tensorflow as tf from promptsource import seqio_tasks @@ -168,6 +173,101 @@ def full_lm(dataset, sequence_length, output_features): default_rate=functools.partial(seqio.mixing_rate_num_examples, maximum=500_000), ) # eval mixture does not need to be capped + +class HuggingFaceDatasetsSource(seqio.DataSource): + def __init__( + self, + dataset_name: str, + subset_name: str, + num_shards: int, + caching_permitted: bool = True + ): + """HuggingFaceDatasetsSource constructor. + Args: + dataset_name: HF dataset name. + subset_name: HF dataset subset. + num_shards: The number of shards, this is useful when processing large files in parallel. + caching_permitted: indicates whether this data source may be cached. + Default True. + """ + self._dataset_fn = dataset_name + self._subset_name = subset_name + self._num_shards = num_shards + + # Get dataset information + info = datasets.get_dataset_infos(dataset_name) + subset_name = subset_name + splits = list(info[subset_name].splits.keys()) + num_input_examples = {split_name: split_info.num_examples for split_name, split_info in info[subset_name].splits.items()} + self._columns = list(info[subset_name].features.keys()) + + super().__init__( + splits=splits, + num_input_examples=num_input_examples, + caching_permitted=caching_permitted) + + @property + def supports_arbitrary_sharding(self) -> bool: + return False + + def get_dataset( + self, + split: str, + shuffle: bool = True, + seed: Optional[int] = None, + shard_info: Optional[ShardInfo] = None + ) -> tf.data.Dataset: + dataset = datasets.load_dataset( + self._dataset_fn, + self._subset_name, + split=split, + ) + dataset = dataset.shard(num_shards=shard_info.num_shards, index=shard_info.index) + if shuffle: + dataset = dataset.shuffle(seed) + return dataset.to_tf_dataset( + columns=self._columns, + batch_size=1, + shuffle=False + ) + + def list_shards(self, split: str) -> Sequence[str]: + return [str(i) for i in range(self._num_shards)] + +TaskRegistry.add( + "oscar_fr_lm_objective", + source=HuggingFaceDatasetsSource( + "oscar", + "unshuffled_deduplicated_fr", + num_shards=50 + ), + preprocessors=[ + functools.partial( + preprocessors.rekey, key_map={ + "inputs": None, + "targets": "text" + }), + seqio.preprocessors.tokenize, + seqio.CacheDatasetPlaceholder(), + t5.data.preprocessors.targets_for_prefix_lm_objective, + t5.data.preprocessors.pack_prefix_lm_encoder_decoder, + ], + output_features={ + "encoder_input_tokens": seqio.Feature(vocabulary=get_default_vocabulary(), add_eos=False), + "decoder_target_tokens": seqio.Feature(vocabulary=get_default_vocabulary(), add_eos=False), + "decoder_input_tokens": seqio.Feature(vocabulary=get_default_vocabulary(), add_eos=False), + "encoder_segment_ids": seqio.Feature(vocabulary=get_default_vocabulary(), add_eos=False), + "encoder_positions": seqio.Feature(vocabulary=get_default_vocabulary(), add_eos=False), + "decoder_segment_ids": seqio.Feature(vocabulary=get_default_vocabulary(), add_eos=False), + "decoder_positions": seqio.Feature(vocabulary=get_default_vocabulary(), add_eos=False), + "decoder_loss_weights": seqio.Feature(vocabulary=get_default_vocabulary(), add_eos=False), + # All but the last stage of the preprocessing uses "targets" as the key, + # so this output feature is necessary. It is not marked required because + # the final preprocessor drops it. + "targets": seqio.Feature(vocabulary=get_default_vocabulary(), add_eos=True), + }, + metric_fns=[]) + # --- Improve sharding --- # def fully_sharded_logical_axis_rules() -> LogicalAxisRules: From a17d6c76f4769a34b771288d35137e417e66ff54 Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Wed, 22 Dec 2021 17:35:34 +0100 Subject: [PATCH 2/4] Add script to run caching on DataFlow --- bigscience/beam/cache_datasets.sh | 16 ++++ bigscience/beam/setup.py | 127 ++++++++++++++++++++++++++++++ bigscience/beam/task.py | 106 +++++++++++++++++++++++++ bigscience/gins/task.py | 102 ------------------------ 4 files changed, 249 insertions(+), 102 deletions(-) create mode 100644 bigscience/beam/cache_datasets.sh create mode 100644 bigscience/beam/setup.py create mode 100644 bigscience/beam/task.py diff --git a/bigscience/beam/cache_datasets.sh b/bigscience/beam/cache_datasets.sh new file mode 100644 index 000000000..67232061b --- /dev/null +++ b/bigscience/beam/cache_datasets.sh @@ -0,0 +1,16 @@ +# Need to install seqio +# gcloud auth application-default login + +MODULE_IMPORT=beam.task +TASK_NAME=mt0.oscar +JOB_NAME=mt0oscar # the name must consist of only the characters [-a-z0-9], starting with a letter and ending with a letter or number +BUCKET=gs://bigscience-t5x # Don't know is cache needs to be task specific or not ... +PROJECT=bigscience +REGION=us-central2 # TODO: Check if we can have a generic us region +NUM_WORKERS=32 # TODO: We might need a log more than this + +seqio_cache_tasks \ + --module_import=$MODULE_IMPORT \ + --tasks=${TASK_NAME} \ + --output_cache_dir=${BUCKET}/multilingual_t0/v0.3 \ + --pipeline_options="--runner=DataflowRunner,--project=$PROJECT,--region=$REGION,--job_name=$JOB_NAME,--staging_location=$BUCKET/binaries,--temp_location=$BUCKET/tmp,--setup_file=$PWD/setup.py,--num_workers=$NUM_WORKERS,--autoscaling_algorithm=NONE,--machine_type=n1-highmem-2" \ No newline at end of file diff --git a/bigscience/beam/setup.py b/bigscience/beam/setup.py new file mode 100644 index 000000000..442d7e8f3 --- /dev/null +++ b/bigscience/beam/setup.py @@ -0,0 +1,127 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Setup.py module for the workflow's worker utilities. +All the workflow related code is gathered in a package that will be built as a +source distribution, staged in the staging area for the workflow being run and +then installed in the workers when they start running. +This behavior is triggered by specifying the --setup_file command line option +when running the workflow for remote execution. +""" + +# pytype: skip-file + +import subprocess +from distutils.command.build import build as _build # type: ignore + +import setuptools + + +# This class handles the pip install mechanism. +class build(_build): # pylint: disable=invalid-name + """A build command class that will be invoked during package install. + The package built using the current setup.py will be staged and later + installed in the worker using `pip install package'. This class will be + instantiated during install for this specific scenario and will trigger + running the custom commands specified. + """ + sub_commands = _build.sub_commands + [('CustomCommands', None)] + + +# Some custom command to run during setup. The command is not essential for this +# workflow. It is used here as an example. Each command will spawn a child +# process. Typically, these commands will include steps to install non-Python +# packages. For instance, to install a C++-based library libjpeg62 the following +# two commands will have to be added: +# +# ['apt-get', 'update'], +# ['apt-get', '--assume-yes', 'install', 'libjpeg62'], +# +# First, note that there is no need to use the sudo command because the setup +# script runs with appropriate access. +# Second, if apt-get tool is used then the first command needs to be 'apt-get +# update' so the tool refreshes itself and initializes links to download +# repositories. Without this initial step the other apt-get install commands +# will fail with package not found errors. Note also --assume-yes option which +# shortcuts the interactive confirmation. +# +# Note that in this example custom commands will run after installing required +# packages. If you have a PyPI package that depends on one of the custom +# commands, move installation of the dependent package to the list of custom +# commands, e.g.: +# +# ['pip', 'install', 'my_package'], +# +# TODO(BEAM-3237): Output from the custom commands are missing from the logs. +# The output of custom commands (including failures) will be logged in the +# worker-startup log. +CUSTOM_COMMANDS = [ + ['echo', 'Custom command worked!'], + ['pip', 'install', 'seqio'], + ['pip', 'install', 't5[cache-tasks]'], + ['pip', 'install', 'datasets'], + +] + + +class CustomCommands(setuptools.Command): + """A setuptools Command class able to run arbitrary commands.""" + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def RunCustomCommand(self, command_list): + print('Running command: %s' % command_list) + p = subprocess.Popen( + command_list, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + # Can use communicate(input='y\n'.encode()) if the command run requires + # some confirmation. + stdout_data, _ = p.communicate() + print('Command output: %s' % stdout_data) + if p.returncode != 0: + raise RuntimeError( + 'Command %s failed: exit code: %s' % (command_list, p.returncode)) + + def run(self): + for command in CUSTOM_COMMANDS: + self.RunCustomCommand(command) + + +# Configure the required packages and scripts to install. +# Note that the Python Dataflow containers come with numpy already installed +# so this dependency will not trigger anything to be installed unless a version +# restriction is specified. +REQUIRED_PACKAGES = [ + 'numpy', +] + +setuptools.setup( + name='beam', + version='0.0.1', + description='Cache datasets set workflow package.', + install_requires=REQUIRED_PACKAGES, + packages=setuptools.find_packages(), + cmdclass={ + # Command class instantiated and run during pip install scenarios. + 'build': build, + 'CustomCommands': CustomCommands, + }) \ No newline at end of file diff --git a/bigscience/beam/task.py b/bigscience/beam/task.py new file mode 100644 index 000000000..20b78c9c8 --- /dev/null +++ b/bigscience/beam/task.py @@ -0,0 +1,106 @@ +import functools +from typing import Sequence, Optional + +import datasets +import seqio +from seqio import TaskRegistry, ShardInfo +from t5.data import preprocessors, get_default_vocabulary +import tensorflow as tf + + +VOCABULARY = get_default_vocabulary() + +class HuggingFaceDatasetsSource(seqio.DataSource): + def __init__( + self, + dataset_name: str, + subset_name: str, + num_shards: int, + caching_permitted: bool = True + ): + """HuggingFaceDatasetsSource constructor. + Args: + dataset_name: HF dataset name. + subset_name: HF dataset subset. + num_shards: The number of shards, this is useful when processing large files in parallel. + caching_permitted: indicates whether this data source may be cached. + Default True. + """ + self._dataset_fn = dataset_name + self._subset_name = subset_name + self._num_shards = num_shards + + # Get dataset information + info = datasets.get_dataset_infos(dataset_name) + subset_name = subset_name + splits = list(info[subset_name].splits.keys()) + num_input_examples = {split_name: split_info.num_examples for split_name, split_info in info[subset_name].splits.items()} + self._columns = list(info[subset_name].features.keys()) + + super().__init__( + splits=splits, + num_input_examples=num_input_examples, + caching_permitted=caching_permitted) + + @property + def supports_arbitrary_sharding(self) -> bool: + return False + + def get_dataset( + self, + split: str, + shuffle: bool = True, + seed: Optional[int] = None, + shard_info: Optional[ShardInfo] = None + ) -> tf.data.Dataset: + dataset = datasets.load_dataset( + self._dataset_fn, + self._subset_name, + split=split, + ) + dataset = dataset.shard(num_shards=shard_info.num_shards, index=shard_info.index) + if shuffle: + dataset = dataset.shuffle(seed) + return dataset.to_tf_dataset( + columns=self._columns, + batch_size=1, + shuffle=False + ) + + def list_shards(self, split: str) -> Sequence[str]: + return [str(i) for i in range(self._num_shards)] + +TaskRegistry.add( + "oscar_fr_lm_objective", + source=HuggingFaceDatasetsSource( + "oscar", + "unshuffled_deduplicated_fr", + num_shards=1024 + ), + preprocessors=[ + functools.partial( + preprocessors.rekey, key_map={ + "inputs": None, + "targets": "text" + }), + seqio.preprocessors.tokenize, + seqio.CacheDatasetPlaceholder(), + preprocessors.targets_for_prefix_lm_objective, + preprocessors.pack_prefix_lm_encoder_decoder, + ], + output_features={ + "encoder_input_tokens": seqio.Feature(vocabulary=VOCABULARY, add_eos=False), + "decoder_target_tokens": seqio.Feature(vocabulary=VOCABULARY, add_eos=False), + "decoder_input_tokens": seqio.Feature(vocabulary=VOCABULARY, add_eos=False), + "encoder_segment_ids": seqio.Feature(vocabulary=VOCABULARY, add_eos=False), + "encoder_positions": seqio.Feature(vocabulary=VOCABULARY, add_eos=False), + "decoder_segment_ids": seqio.Feature(vocabulary=VOCABULARY, add_eos=False), + "decoder_positions": seqio.Feature(vocabulary=VOCABULARY, add_eos=False), + "decoder_loss_weights": seqio.Feature(vocabulary=VOCABULARY, add_eos=False), + # All but the last stage of the preprocessing uses "targets" as the key, + # so this output feature is necessary. It is not marked required because + # the final preprocessor drops it. + "targets": seqio.Feature(vocabulary=VOCABULARY, add_eos=True), + }, + metric_fns=[] +) \ No newline at end of file diff --git a/bigscience/gins/task.py b/bigscience/gins/task.py index 80049cb57..65e77d7a0 100644 --- a/bigscience/gins/task.py +++ b/bigscience/gins/task.py @@ -1,18 +1,11 @@ import functools -from typing import Optional, Sequence -import datasets import seqio -import t5 -from seqio.dataset_providers import _validate_args, ShardInfo from t5.data import preprocessors, get_default_vocabulary from t5.data.preprocessors import select_random_chunk, reduce_concat_tokens, split_tokens -import tensorflow as tf from promptsource import seqio_tasks -from t5x.partitioning import LogicalAxisRules - # --- Seqio --- seqio.add_global_cache_dirs([ 'gs://bigscience-t5x/seqio_cached_tasks', @@ -173,101 +166,6 @@ def full_lm(dataset, sequence_length, output_features): default_rate=functools.partial(seqio.mixing_rate_num_examples, maximum=500_000), ) # eval mixture does not need to be capped - -class HuggingFaceDatasetsSource(seqio.DataSource): - def __init__( - self, - dataset_name: str, - subset_name: str, - num_shards: int, - caching_permitted: bool = True - ): - """HuggingFaceDatasetsSource constructor. - Args: - dataset_name: HF dataset name. - subset_name: HF dataset subset. - num_shards: The number of shards, this is useful when processing large files in parallel. - caching_permitted: indicates whether this data source may be cached. - Default True. - """ - self._dataset_fn = dataset_name - self._subset_name = subset_name - self._num_shards = num_shards - - # Get dataset information - info = datasets.get_dataset_infos(dataset_name) - subset_name = subset_name - splits = list(info[subset_name].splits.keys()) - num_input_examples = {split_name: split_info.num_examples for split_name, split_info in info[subset_name].splits.items()} - self._columns = list(info[subset_name].features.keys()) - - super().__init__( - splits=splits, - num_input_examples=num_input_examples, - caching_permitted=caching_permitted) - - @property - def supports_arbitrary_sharding(self) -> bool: - return False - - def get_dataset( - self, - split: str, - shuffle: bool = True, - seed: Optional[int] = None, - shard_info: Optional[ShardInfo] = None - ) -> tf.data.Dataset: - dataset = datasets.load_dataset( - self._dataset_fn, - self._subset_name, - split=split, - ) - dataset = dataset.shard(num_shards=shard_info.num_shards, index=shard_info.index) - if shuffle: - dataset = dataset.shuffle(seed) - return dataset.to_tf_dataset( - columns=self._columns, - batch_size=1, - shuffle=False - ) - - def list_shards(self, split: str) -> Sequence[str]: - return [str(i) for i in range(self._num_shards)] - -TaskRegistry.add( - "oscar_fr_lm_objective", - source=HuggingFaceDatasetsSource( - "oscar", - "unshuffled_deduplicated_fr", - num_shards=50 - ), - preprocessors=[ - functools.partial( - preprocessors.rekey, key_map={ - "inputs": None, - "targets": "text" - }), - seqio.preprocessors.tokenize, - seqio.CacheDatasetPlaceholder(), - t5.data.preprocessors.targets_for_prefix_lm_objective, - t5.data.preprocessors.pack_prefix_lm_encoder_decoder, - ], - output_features={ - "encoder_input_tokens": seqio.Feature(vocabulary=get_default_vocabulary(), add_eos=False), - "decoder_target_tokens": seqio.Feature(vocabulary=get_default_vocabulary(), add_eos=False), - "decoder_input_tokens": seqio.Feature(vocabulary=get_default_vocabulary(), add_eos=False), - "encoder_segment_ids": seqio.Feature(vocabulary=get_default_vocabulary(), add_eos=False), - "encoder_positions": seqio.Feature(vocabulary=get_default_vocabulary(), add_eos=False), - "decoder_segment_ids": seqio.Feature(vocabulary=get_default_vocabulary(), add_eos=False), - "decoder_positions": seqio.Feature(vocabulary=get_default_vocabulary(), add_eos=False), - "decoder_loss_weights": seqio.Feature(vocabulary=get_default_vocabulary(), add_eos=False), - # All but the last stage of the preprocessing uses "targets" as the key, - # so this output feature is necessary. It is not marked required because - # the final preprocessor drops it. - "targets": seqio.Feature(vocabulary=get_default_vocabulary(), add_eos=True), - }, - metric_fns=[]) - # --- Improve sharding --- # def fully_sharded_logical_axis_rules() -> LogicalAxisRules: From e986ae9b5e0b7b67c36283f5fa748ea68d13d47d Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Wed, 22 Dec 2021 18:16:50 +0100 Subject: [PATCH 3/4] Add a TODO to figure out how the caching works in DataFlow --- bigscience/beam/cache_datasets.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bigscience/beam/cache_datasets.sh b/bigscience/beam/cache_datasets.sh index 67232061b..005bfbaa2 100644 --- a/bigscience/beam/cache_datasets.sh +++ b/bigscience/beam/cache_datasets.sh @@ -7,7 +7,9 @@ JOB_NAME=mt0oscar # the name must consist of only the characters [-a-z0-9], star BUCKET=gs://bigscience-t5x # Don't know is cache needs to be task specific or not ... PROJECT=bigscience REGION=us-central2 # TODO: Check if we can have a generic us region -NUM_WORKERS=32 # TODO: We might need a log more than this +NUM_WORKERS=1000 # TODO: We might need a log more than this + +# TODO: One thing we need to figure out is how does it handle HF datasets cache. If all workers need to download it, it's a big no no. seqio_cache_tasks \ --module_import=$MODULE_IMPORT \ From 8151c3348b028f80cdc908e6701c277140648c5f Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Wed, 22 Dec 2021 18:50:24 +0100 Subject: [PATCH 4/4] Update gcp region --- bigscience/beam/cache_datasets.sh | 2 +- bigscience/beam/task.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bigscience/beam/cache_datasets.sh b/bigscience/beam/cache_datasets.sh index 005bfbaa2..b1066b26f 100644 --- a/bigscience/beam/cache_datasets.sh +++ b/bigscience/beam/cache_datasets.sh @@ -6,7 +6,7 @@ TASK_NAME=mt0.oscar JOB_NAME=mt0oscar # the name must consist of only the characters [-a-z0-9], starting with a letter and ending with a letter or number BUCKET=gs://bigscience-t5x # Don't know is cache needs to be task specific or not ... PROJECT=bigscience -REGION=us-central2 # TODO: Check if we can have a generic us region +REGION=us-central1 # TODO: Check if we can have a generic us region NUM_WORKERS=1000 # TODO: We might need a log more than this # TODO: One thing we need to figure out is how does it handle HF datasets cache. If all workers need to download it, it's a big no no. diff --git a/bigscience/beam/task.py b/bigscience/beam/task.py index 20b78c9c8..3da2c3f48 100644 --- a/bigscience/beam/task.py +++ b/bigscience/beam/task.py @@ -23,7 +23,7 @@ def __init__( dataset_name: HF dataset name. subset_name: HF dataset subset. num_shards: The number of shards, this is useful when processing large files in parallel. - caching_permitted: indicates whether this data source may be cached. + caching_permitted: indicates whether this data source may be cached by seqio. Default True. """ self._dataset_fn = dataset_name