diff --git a/.pylintrc b/.pylintrc
index d6f8a5d6c..222bdf6cb 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -638,7 +638,7 @@ callbacks=cb_,
 dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_

 # Argument names that match this expression will be ignored.
-ignored-argument-names=_.*|^ignored_|^unused_
+ignored-argument-names=_.*|^ignored_|^unused_|kwargs

 # Tells whether we should check for unused import in __init__ files.
 init-import=no
diff --git a/architecture_records/003-generic-tracker-framework.md b/architecture_records/003-generic-tracker-framework.md
index f70b2df15..5e3f63eae 100644
--- a/architecture_records/003-generic-tracker-framework.md
+++ b/architecture_records/003-generic-tracker-framework.md
@@ -1,4 +1,4 @@
-# Resource Scanner
+# Generic Tracker Framework

 **Deciders(s)**: Sukriti Sharma (sukriti.sharma4@ibm.com), Alexander Brooks (alex.brooks@ibm.com), Raghu Ganti (rganti@us.ibm.com), Dushyant Behl (dushyantbehl@in.ibm.com), Ashok Pon Kumar (ashokponkumar@in.ibm.com)
diff --git a/architecture_records/004-datapreprocessor.md b/architecture_records/004-datapreprocessor.md
new file mode 100644
index 000000000..df9c1267b
--- /dev/null
+++ b/architecture_records/004-datapreprocessor.md
@@ -0,0 +1,422 @@
# Data Pre Processor Design For fms-hf-tuning

**Deciders(s)**: Sukriti Sharma (sukriti.sharma4@ibm.com), Will Johnson (Will.Johnson@ibm.com), Abhishek Maurya (maurya.abhishek@ibm.com), Yu Chin Fabian Lim (flim@sg.ibm.com), Dushyant Behl (dushyantbehl@in.ibm.com), Ashok Pon Kumar (ashokponkumar@in.ibm.com)

**Date (YYYY-MM-DD)**: 2024-03-06

**Obsoletes ADRs**: NA

**Modified By ADRs**: NA

**Relevant Issues**: [1]

- [Summary and Objective](#summary-and-objective)
  - [Motivation](#motivation)
  - [User Benefit](#user-benefit)
- [Decision](#decision)
  - [Alternatives Considered](#alternatives-considered)
- [Consequences](#consequences)
- [Detailed Design](#detailed-design)

## Summary and Objective

The primary objective of the `DataPreProcessor` design for fms-hf-tuning is to provide a unified yet powerful interface for handling diverse data formats and configurations.
This interface should cater to various user expertise levels, enabling basic users to easily load and process data, while allowing advanced users to customize their data handling.

### Key Goals:
1. **Broad Data Format Support**: Allow datasets in formats such as Arrow, Parquet, and CSV.
1. **Compatibility with Multiple Datasets and Files**: Enable multiple files per dataset and interleaving or mixing of datasets.
1. **Support for Different Data Modalities**: Include images, audio, and text data, along with modality-specific preprocessing options.
1. **User-Focused Configurations**: Provide simple data loading for regular users, while enabling advanced configurations for expert users.
1. **Template-Based Preprocessing**: Support jinja template rendering, where necessary, for template-dependent preprocessing requirements.

### Motivation

The main motivation for this ADR stems from the fact that fms-hf-tuning is being used by many teams for a diverse set of use cases which are not currently supported in the library. To be precise, for data preprocessing the library currently takes two primary arguments, `training_data_path` and `validation_data_path`, each of which accepts a single file location for a dataset.
A user can currently pass in
1. a pretokenized json(l) dataset via
   ```
   --training_data_path
   ```
1. a preprocessed json(l) dataset with a single sequence and a specified `response_template` to use for masking on completion, via
   ```
   --training_data_path --dataset_text_field --response_template
   ```
1. a json(l) dataset and a `data_formatter_template` to apply the formatting function on the fly, via
   ```
   --training_data_path --data_formatter_template <'template'>
   ```
1. a json(l) dataset with `input` and `output` fields, whose names are hardcoded and cannot be changed, via
   ```
   --training_data_path .json
   ```

The first motivation for a change is the requirement from users for different data formats; the current code only supports json, while there are teams training with Parquet and Arrow data who therefore need additional data format support.

Use cases from teams also require multiple datasets, and even multiple data files within a dataset.

A further requirement from teams is a way to interleave datasets at run time by specifying static weights to mix different datasets, which is also not supported by the code yet.

Finally, other requirements are to have preprocessing support for multiple modalities of data (starting with image data first) and to support advanced preprocessing such as jinja based template rendering of the dataset before consumption.

All these requirements are new and currently not supported by the library, which motivated us to propose a change in the design of data preprocessing in this library to incorporate these, and potentially any new changes, in one go.

### User Benefit

Users will benefit from the additional argument which allows them to pass a single [`data_config`](#our-considerations-for-the-design) file specifying how to preprocess their dataset.
Our data config file will give users the capability to,
1. Pass multiple data files and multiple datasets.
1. Specify static weights in the configuration to interleave datasets.
1. Define preprocessing routines to apply on the data, and in which order.

This will make the process of handling custom datasets which might require rendering a jinja template or processing image data much easier.

We do not require users to learn the specification of the additional `data_config` file, as the existing arguments to process datasets which are present in the code [`tuning.config.configs.DataArguments`](https://github.com/foundation-model-stack/fms-hf-tuning/blob/398c2a8fe26d734344240555585d95e05299faa8/tuning/config/configs.py#L67) will not be deprecated in this version, and users can keep using the same data arguments for use cases already served by the library.

At the very least, a user not well versed with the `data_config` will be able to pass in, e.g., a pre-tokenized pyarrow dataset

```
--training_data_path .pyarrow
```

And at full length they can specify a multi-file, multi-dataset configuration which processes the datasets according to the config specified, like


```
...
  - name: dataset1
    sampling:
      ratio: 0.3
    data_paths:
      - /data/stackoverflow-kubectl_posts
      - /data/stackoverflow-kubernetes_posts
      - /data/stackoverflow-openshift_posts
    data_handlers:
      - name: render_template
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            jinja_template: "{}"
```
This is a small part of the `data-config` spec, which is provided in detail in the section below.

## Decision

Some terminology before we move ahead

| User Persona | Description |
| --- | --- |
| Simple User | A user who uses this library to train models using a single dataset, passing it via a single command line argument. |
| Advanced User | A user with a deep understanding of datasets, who knows how to apply specific preprocessing and mixing techniques during training. |
| Intermediate User | A user who works with custom datasets but lacks full knowledge of data processing, relying on advanced users for guidance to fulfill their use case. |

Please note that most of the users of the product here would fall into the simple user category, while advanced and intermediate users are researchers looking to use our library for a diverse set of use cases.

### Our considerations for the design

1. Allow advanced users to use the full power of the HuggingFace library as much as possible without recreating it.
1. Allow advanced users to specify a custom data preprocessor pipeline in an easy way.
1. Ensure the single design can handle these and many more use cases without major changes.
1. Design for advanced users while simplifying for simple users.

We propose to allow advanced users to specify a full spec which exposes the data preprocessing API provided by the HF library directly to them, so they are able to fully utilize the interface.

The proposed input spec, which a user specifies as `data_config` to describe how to preprocess the data, is

```
dataprocessor:
  type: default
datasets:
  - name: dataset1
    sampling:
      ratio: 0.3
    data_paths:
      - /data/stackoverflow-kubectl_posts
      - /data/stackoverflow-kubernetes_posts
      - /data/stackoverflow-openshift_posts
    data_handlers:
      - name: tokenize
        arguments:
          remove_columns: all
          batched: false
  - name: dataset2
    sampling:
      ratio: 0.4
    data_paths:
      - /data/stackoverflow-kubectl_posts
      - /data/stackoverflow-kubernetes_posts
      - /data/stackoverflow-openshift_posts
    data_handlers:
      - name: render_template
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            jinja_template: "{}"
      - name: tokenize
        arguments:
          remove_columns: all
          batched: false
```

To iterate again, our goal here is not to re-implement the functionality provided by HuggingFace, but rather to have a clean interface using a config where advanced users can use things like Iterable datasets or interleaving datasets and perform custom preprocessing, such as applying jinja templates, in an easy way.

In this spec, at the top level we have the `dataprocessor` config, which contains just one field, `type`, set to `default`. This is done to ensure any future top level `dataprocessor` configs will go into this block. Users need not touch or provide this, as the `default` is automatically selected.

The second block is where users list multiple `datasets`, and each dataset contains information on how to process it. We allow arguments like `sampling` for users to specify sampling ratios while [`interleaving datasets`](https://huggingface.co/docs/datasets/en/process#interleave), which uses the [`interleave_datasets`](https://huggingface.co/docs/datasets/v3.1.0/en/package_reference/main_classes#datasets.interleave_datasets) API from HuggingFace.

The most powerful feature of this block is `data_handlers`. Here we allow users to specify a list of routines to apply on the dataset at preprocessing time. A `data_handler` is a [`map`](https://huggingface.co/docs/datasets/en/process#map) operation performed on the dataset, to which a user can further pass informational arguments. We expose the full set of arguments of the HF [`Dataset.map`](https://huggingface.co/docs/datasets/v3.1.0/en/package_reference/main_classes#datasets.Dataset.map) operation here to the user as `kwargs` of a handler.
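To make the handler-to-`map` relationship concrete, below is a minimal sketch (an illustration, not the final implementation) of how one handler entry from the spec could be resolved by name and forwarded to `Dataset.map`. The toy `tokenize` handler body, the registry dict, the `gpt2` tokenizer, and the file path are assumptions made for the example.

```
# Minimal sketch: resolving one data handler entry by name and applying it via
# Dataset.map. Registry, handler body, and file path are illustrative only.
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

def tokenize(example, tokenizer=None, dataset_text_field="text"):
    # A handler is just a callable that map() applies to each sample.
    return tokenizer(example[dataset_text_field])

DATA_HANDLERS = {"tokenize": tokenize}  # name -> callable registry

# One entry from a dataset's `data_handlers` list in the data_config.
handler = {
    "name": "tokenize",
    "arguments": {
        "remove_columns": "all",
        "batched": False,
        "fn_kwargs": {"tokenizer": tokenizer, "dataset_text_field": "output"},
    },
}

raw = load_dataset("json", data_files="train.jsonl", split="train")
kwargs = dict(handler["arguments"])
if kwargs.get("remove_columns") == "all":
    kwargs["remove_columns"] = raw.column_names  # "all" -> every existing column
# Everything else is forwarded untouched as Dataset.map kwargs.
processed = raw.map(DATA_HANDLERS[handler["name"]], **kwargs)
```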
As an example, in `dataset2` the data handlers request applying a `render_template` function before tokenization; this processes the dataset and renders the `jinja template` specified as `fn_kwargs.jinja_template`, while the rest of the arguments like `remove_columns` and `batched` are just HF Map API arguments.

```
- name: dataset2
  sampling:
    ratio: 0.4
  data_paths:
    - /data/stackoverflow-kubectl_posts
    - /data/stackoverflow-kubernetes_posts
    - /data/stackoverflow-openshift_posts
  data_handlers:
    - name: render_template
      arguments:
        remove_columns: all
        batched: false
        fn_kwargs:
          jinja_template: "{}"
    - name: tokenize
      arguments:
        remove_columns: all
        batched: false
```

By allowing users to specify data handlers like this, we let them use the full Hugging Face API while specifying preprocessing routines in a fixed order. The handlers list specifies a [`DAG`](https://en.wikipedia.org/wiki/Directed_acyclic_graph) of operations to apply on the dataset, and it will be executed by the code in that order.

Furthermore, this design is flexible enough to be extended to any upcoming use case, because any operation to be executed on the dataset can be broken down into function executions implemented as data handlers.

This makes our spec a complete solution for advanced users of the library who have custom preprocessing needs, allowing them to specify complete preprocessing operations to be applied to the dataset via a config file.

Finally, with this spec we do not want to break the functionality for the simple users of the library. A simple user who just wants to use the library with a single dataset, like today, can pass the same dataset via the `--training_data_path --validation_data_path` arguments.

In fact, we do not change the behavior currently supported by any of the `tuning.config.configs.DataArguments` arguments, hence allowing the simple users of the library to continue using the library as is.

### Performance Considerations

Since this design allows complex preprocessing of the dataset on the fly, it should incorporate performance measures to ensure that the system does not spend so much time preprocessing the dataset that it affects tuning/training time.

The goal here is to not be slower than the HuggingFace library which our whole design is based upon; in this sense we also intend for any performance improvements that we come across to be contributed back to the HF library, to keep our design simple and avoid reimplementing things.

#### Handling Large Dataset

Our main reason for using HF [Map](https://huggingface.co/docs/datasets/en/process#map) heavily for data preprocessing is that for large datasets, which are generally loaded as `IterableDatasets`, the Map API automatically performs [`lazy map operations`](https://huggingface.co/docs/datasets/en/about_mapstyle_vs_iterable#eager-data-processing-and-lazy-data-processing) and hence doesn't produce too much overhead while training.

#### Caching intermediate dataset

Hugging Face caches intermediate map operations, which makes replay of our data preprocessor easier if the same map parameters and operations are applied. If the file system is an issue, we have two considerations,
1. Keep intermediate datasets in memory while preprocessing using [`keep_in_memory=True`](https://huggingface.co/docs/datasets/v3.1.0/en/package_reference/main_classes#datasets.Dataset.map.keep_in_memory); for large datasets and Iterable datasets we assume this applies to mini batches.
1. Disable caching in the file system and make it configurable via the [`disable_caching()`](https://huggingface.co/docs/datasets/en/cache#enable-or-disable-caching) API from HuggingFace.

These considerations can be made dynamically or can be passed in by users, as we allow any `kwargs` to be passed to the map operations.

## Alternatives Considered

### Letting users process their own data and pass file(s) directly to this library.

A simple alternative to avoid all this is to have the users process their own data; this is also in line with the fact that most workloads contain preprocessed data which is used by simple users as is for their tuning/training.

The reason to have this design is that many users coming to this library have an advanced set of use cases. As stated in the motivation, we are getting ever increasing demand from researchers who want features like `jinja template` rendering, image data processing, and mixing and merging datasets. While this can be done at the user level, most users are not looking to write code to do all this preprocessing, but to use tools which implement it for them.
Leaving all users to write their own preprocessing logic can also lead to code duplication across many teams, which is something we want to avoid.

More importantly, as stated in the motivation, we are getting ever increasing demand from users who want to use this library directly with their dataset and have a quick roundtrip for testing. This design allows users to specify simple parameters in the config and test complex use cases easily.

### Passing all datasets we take to the HuggingFace SFTTrainer API and let it handle them without preprocessing at our end.

Another alternative is to take the `dataset` input to this library and pass it directly to the trainer, `SFTTrainer` in our case, and let it handle loading and preprocessing the dataset.

[SFTTrainer](https://huggingface.co/docs/trl/v0.12.1/en/sft_trainer#trl.SFTTrainer) supports specifying the `train_dataset` and `eval_dataset`, both of which support Iterable datasets along with normal datasets, allowing us to pass a large dataset via streaming.

Please note that even in this case users will need to tell us that the dataset is large and is to be loaded via `streaming=True`, because the argument which tells HF to load the dataset in Iterable mode or standard mode is passed to [`load_dataset`](https://huggingface.co/docs/datasets/v3.1.0/en/package_reference/loading_methods#datasets.load_dataset)

```
from datasets import load_dataset
train_ds = load_dataset('imdb', split='train', streaming=True)
```

Additionally, `SFTTrainer` has support for a [data formatting function](https://huggingface.co/docs/trl/en/sft_trainer#dataset-format-support).
Users can pass a formatting function directly to `SFTTrainer` (via `formatting_func`), which formats the dataset for them,

```
def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['question'])):
        text = f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}"
        output_texts.append(text)
    return output_texts

trainer = SFTTrainer(
    model,
    args=training_args,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
)

trainer.train()
```
Taken from [HuggingFace docs](https://huggingface.co/docs/trl/en/sft_trainer#dataset-format-support)

As our library is a wrapper on top of HF, we cannot directly allow users to pass a custom formatting function. Our `data_handler` design can, however, support formatting a dataset in a way similar to a `formatting function`, where users specify just the name of the handler and we apply the formatting on our end. The `data_handler` design that we have is a superset of this feature, more flexible and able to support many more use cases.

## Consequences

### Arguments Required
In this design, apart from the `data_config` spec, users will also need to pass the `--response_template` argument. This is because the `DataCollator` functionality of this library is not being touched by our design.

Also, users who process a JSON dataset via our interface need to specify `--dataset_text_field`, which is inferred from the `DataArguments` for now and not passed inside the data_config, to ensure the simple interface remains the same.

We also plan to add a new argument to `tuning.config.configs.DataArguments` which takes in the `data_config` file as input, like

```
@dataclass
class DataArguments:
...
    data_config_file: str = field(
        default=None,
        metadata={
            "help": "data_config file which specifies the data preprocessing logic to apply.\
                Supports both JSON and YAML based config files"
        },
    )
```
A short illustration of passing this alongside the existing arguments is given at the end of this section.

### Understanding the spec
With this design we have tried to keep our spec simple and as close to the HF library as possible, e.g. exposing the same map `kwargs` that HF has in our `data_handlers`.

Despite this, advanced users will need to understand the spec to be able to write it properly.

Advanced users will also need to educate themselves on the data handlers already present in the code. Since the data handlers are selected based on their name, we need to ensure the documentation contains complete information on which data handlers are present and how to use them in the `data_config`.

### Sharing config files
We currently do not propose anything on how advanced users share the `data_config` files created by them with intermediate and simple users. This is left outside the scope of our library.

### Simple User Perspective

As mentioned above, we are retaining the full functionality supported by `tuning.config.configs.DataArguments`, which means simple users can continue using the library by passing a simple dataset via `--training_data_path` and use case specific arguments like `--data_formatter_template` as they please, and the code will internally handle how to map these to the `data_config` spec.

### Intermediate User Perspective
Our perspective is that the advanced users will create config files for data preprocessing, and the intermediate users can use these existing configs and modify them according to their preference to get the desired result.
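As referenced above under Arguments Required, here is a rough illustration of how the new field could sit next to the existing `DataArguments` options; it is a sketch against the proposed dataclass, not a finalized interface, and the concrete values are placeholders.

```
# Sketch only: combining the proposed data_config_file field with the existing
# DataArguments options mentioned above. Values are placeholders.
from tuning.config import configs

data_args = configs.DataArguments(
    data_config_file="data_config.yaml",  # new: points at the preprocessing spec
    response_template="\n### Label:",     # still required, collator path unchanged
    dataset_text_field="output",          # still read from DataArguments for now
)
```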
## Detailed Design

### The proposed design to implement support for this spec is as follows

Data Pre Processor abstract class

```
# Imports shown for completeness; DataConfig is the parsed data_config spec.
from abc import ABC, abstractmethod
from typing import Callable, Dict

class DataPreProcessor(ABC):

    tokenizer = None
    model_name_or_path = None
    block_size = None
    data_config: DataConfig = None
    data_handlers: Dict[str, Callable] = None

    def __init__(self, dataconfig: DataConfig, tokenizer, model_name_or_path, block_size):
        self.data_config = dataconfig
        self.tokenizer = tokenizer
        self.model_name_or_path = model_name_or_path
        self.block_size = block_size
        self.data_handlers = {}

    def register_data_handler(self, name: str, d: Callable):
        self.data_handlers[name] = d

    @abstractmethod
    def process_data_config(self, data_config: DataConfig):
        pass
```

At the top level, we propose to have this `class DataPreProcessor`, which is an abstract class and requires functions to process the data config proposed above.

The data preprocessor needs to support custom data handlers. In the library, for simple use cases, we will provide predefined data handlers which need to be registered with the top level class using the call `DataPreProcessor.register_data_handler`.

The simple use cases will be handled using these data handlers, and which data handler to choose will depend on the use case chosen from the data args (same as the current code).

## How are handlers provided and registered

Data handlers are Python callables which can be called on single/few samples of data and can perform things like tokenizing the data, applying tools like a jinja template, or even encoding or decoding multi modal formats like images/audio for processing by the model.

The abstract datapreprocessor class provides a way to register a data handler against a `name`, which is a string. The data handler config `DataHandlerConfig`, taken by `execute_data_handlers`, represents a DAG of data handling routines which are to be executed on the data.

In terms of the standard HF API, you can think of these as the HF processing routines, which could be Map/Filter/Select operations. We implement most of the routines as map based operations. The current code also implements functionality like tokenization of data or data formatting via map, e.g. `tuning/utils/preprocessing_utils.py::get_preprocessed_dataset`; such functionality can be retained as predefined data handlers.

The implementation is flexible enough for very advanced users to specify their own implementation of data handling routines by importing fms-hf-tuning and extending the preprocessing by calling `register_data_handler` on the preprocessor. This is left for advanced users of the library and not for simple users.

# Implementation of the default Data Preprocessor.

The default data preprocessor, implemented as an instance of the `DataPreProcessor` class, uses HF APIs wherever possible
to minimize reimplementation of code.

The HF datapreprocessor processes different types of files via its `load_dataset` factory. If a type is not supported automatically via this, we can look to extend the factory to load another type of interest via the
`Dataset.from_generator()` functionality.

This also means that any implementation like `get_json_object`, which loads `json(l)` and then returns a custom json dict,
can be implemented as a data handler.

### Interleaving datasets

In case of multiple datasets the user can request how the datasets are to be interleaved.
The probabilities specified by users in the config `sampling.ratio` can be collected from the individual datasets and passed to
[`datasets.interleave_datasets`](https://huggingface.co/docs/datasets/v3.0.1/en/package_reference/main_classes#datasets.interleave_datasets).

### Streaming datasets

In HuggingFace the `streaming` argument can be handled by using `IterableDatasets` instead of standard `Datasets`.
HF provides the same APIs, like `datasets.interleave_datasets`, over the `Iterable` datasets as well.

A further important thing to note is that, in the HF case, the `map` functionality which we use to implement data handling is executed lazily on iterable datasets, meaning we don't need to handle the data handlers in a different way for streaming data. [More information on the HF page.](https://huggingface.co/docs/datasets/en/about_mapstyle_vs_iterable#eager-data-processing-and-lazy-data-processing) A minimal sketch combining interleaving and streaming is included at the end of this document.

## Handling data collators.

Data collators for TRL use cases, like chat based interactions which apply chat templates and proper attention masking on the tokenized data, as in the case of `DataCollatorForCompletionOnlyLM`, handle a specific piece of functionality on the data.

In this design our approach is to pass data collators from the Hugging Face API directly to SFTTrainer.

In the current code path, collators are collected by the `get_data_collator` functionality and passed to `SFTTrainer`. We can retain the same functionality and keep the design simpler.

The job of the data preprocessor is to provide a single interface over the multiple datasets in the config; keeping a collator like this means the collator stays the same across all datasets, but it keeps the design simpler.

## Handling Multi Modal Data.

HF does provide support for handling [image datasets](https://huggingface.co/docs/datasets/en/image_process) and [audio datasets](https://huggingface.co/docs/datasets/en/audio_load), which can be utilized by us in our HF datapreprocessor.

The functionality listed by HF for working with image and audio datasets consists of `map` based functions to perform resizing, encoding and other such operations on the dataset (see the links above).

This means the image and audio multi modal datasets will be compatible with our data handler routines. Once we implement the data handler routine processing, we will allow users to train with multi modal datasets too.

# Implementation stages.
1. Stage 1:
    * Refactoring the code in `fms-hf-tuning` into the abstract data class and adding support for preliminary data handling routines.
      This will automatically enable support for multi modal data, which is our priority.
      Note that at this stage it might be wise to have two side by side implementations, i.e. not deleting the existing implementation.
1. Stage 2:
    * Implementing `streaming` data or `iterable` dataset support for the HF datapreprocessor implementation.
    * Data handling support for streaming data.
1. Stage 3:
    * Identify and add any other required predefined data handlers.
    * Phase out the old implementation in favor of the new one.
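### Appendix: interleaving and streaming sketch

As referenced in the interleaving and streaming sections above, this is a minimal sketch using only public `datasets` APIs; the file paths, ratios, and seed are placeholders, not values the library prescribes.

```
# Sketch: interleave two streamed datasets using static sampling ratios collected
# from the data_config. Paths, ratios and the seed are placeholders.
from datasets import interleave_datasets, load_dataset

ds1 = load_dataset("json", data_files="dataset1/*.jsonl", split="train", streaming=True)
ds2 = load_dataset("json", data_files="dataset2/*.jsonl", split="train", streaming=True)

# sampling.ratio values from the config, normalized to probabilities.
probabilities = [0.3, 0.7]
mixed = interleave_datasets([ds1, ds2], probabilities=probabilities, seed=42)

# map() on an IterableDataset is lazy, so handlers add no upfront preprocessing cost.
mixed = mixed.map(lambda example: example)
```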
diff --git a/tests/acceleration/test_acceleration_framework.py b/tests/acceleration/test_acceleration_framework.py index b6acf7eb3..d25554fe6 100644 --- a/tests/acceleration/test_acceleration_framework.py +++ b/tests/acceleration/test_acceleration_framework.py @@ -54,13 +54,13 @@ from tuning.utils.import_utils import is_fms_accelerate_available # for some reason the CI will raise an import error if we try to import -# these from tests.data +# these from tests.artifacts.testdata TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join( - os.path.dirname(__file__), "../data/twitter_complaints_json.json" + os.path.dirname(__file__), "../artifacts/testdata/twitter_complaints_json.json" ) TWITTER_COMPLAINTS_TOKENIZED = os.path.join( os.path.dirname(__file__), - "../data/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json", + "../artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json", ) # pylint: disable=import-error diff --git a/tests/artifacts/predefined_data_configs/__init__.py b/tests/artifacts/predefined_data_configs/__init__.py new file mode 100644 index 000000000..f9b766be6 --- /dev/null +++ b/tests/artifacts/predefined_data_configs/__init__.py @@ -0,0 +1,30 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpful datasets for configuring individual unit tests. 
+""" +# Standard +import os + +### Constants used for data +PREDEFINED_DATA_CONFIGS = os.path.join(os.path.dirname(__file__)) +APPLY_CUSTOM_TEMPLATE_YAML = os.path.join( + PREDEFINED_DATA_CONFIGS, "apply_custom_template.yaml" +) +PRETOKENIZE_JSON_DATA_YAML = os.path.join( + PREDEFINED_DATA_CONFIGS, "pretokenized_json_data.yaml" +) +TOKENIZE_AND_APPLY_INPUT_MASKING_YAML = os.path.join( + PREDEFINED_DATA_CONFIGS, "tokenize_and_apply_input_masking.yaml" +) diff --git a/tests/artifacts/predefined_data_configs/apply_custom_template.yaml b/tests/artifacts/predefined_data_configs/apply_custom_template.yaml new file mode 100644 index 000000000..4aab0d76a --- /dev/null +++ b/tests/artifacts/predefined_data_configs/apply_custom_template.yaml @@ -0,0 +1,14 @@ +dataprocessor: + type: default +datasets: + - name: apply_custom_data_template + data_paths: + - "FILE_PATH" + data_handlers: + - name: apply_custom_data_formatting_template + arguments: + remove_columns: all + batched: false + fn_kwargs: + dataset_text_field: "dataset_text_field" + dataset_template: "dataset_template" \ No newline at end of file diff --git a/tests/artifacts/predefined_data_configs/pretokenized_json_data.yaml b/tests/artifacts/predefined_data_configs/pretokenized_json_data.yaml new file mode 100644 index 000000000..833173dea --- /dev/null +++ b/tests/artifacts/predefined_data_configs/pretokenized_json_data.yaml @@ -0,0 +1,6 @@ +dataprocessor: + type: default +datasets: + - name: pretokenized_dataset + data_paths: + - "FILE_PATH" \ No newline at end of file diff --git a/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml b/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml new file mode 100644 index 000000000..d8fc16eec --- /dev/null +++ b/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml @@ -0,0 +1,14 @@ +dataprocessor: + type: default +datasets: + - name: text_dataset_input_output_masking + data_paths: + - "FILE_PATH" + data_handlers: + - name: tokenize_and_apply_input_masking + arguments: + remove_columns: all + batched: false + fn_kwargs: + input_field: "INPUT" + output_field: "OUTPUT" \ No newline at end of file diff --git a/tests/data/__init__.py b/tests/artifacts/testdata/__init__.py similarity index 100% rename from tests/data/__init__.py rename to tests/artifacts/testdata/__init__.py diff --git a/tests/data/empty_data.json b/tests/artifacts/testdata/empty_data.json similarity index 100% rename from tests/data/empty_data.json rename to tests/artifacts/testdata/empty_data.json diff --git a/tests/data/malformatted_data.json b/tests/artifacts/testdata/malformatted_data.json similarity index 100% rename from tests/data/malformatted_data.json rename to tests/artifacts/testdata/malformatted_data.json diff --git a/tests/data/trainercontroller/__init__.py b/tests/artifacts/testdata/trainercontroller/__init__.py similarity index 100% rename from tests/data/trainercontroller/__init__.py rename to tests/artifacts/testdata/trainercontroller/__init__.py diff --git a/tests/data/trainercontroller/epoch-level-eval-loss-patience.yaml b/tests/artifacts/testdata/trainercontroller/epoch-level-eval-loss-patience.yaml similarity index 100% rename from tests/data/trainercontroller/epoch-level-eval-loss-patience.yaml rename to tests/artifacts/testdata/trainercontroller/epoch-level-eval-loss-patience.yaml diff --git a/tests/data/trainercontroller/epoch-level-eval-loss.yaml b/tests/artifacts/testdata/trainercontroller/epoch-level-eval-loss.yaml similarity index 100% 
rename from tests/data/trainercontroller/epoch-level-eval-loss.yaml rename to tests/artifacts/testdata/trainercontroller/epoch-level-eval-loss.yaml diff --git a/tests/data/trainercontroller/epoch-level-training-loss.yaml b/tests/artifacts/testdata/trainercontroller/epoch-level-training-loss.yaml similarity index 100% rename from tests/data/trainercontroller/epoch-level-training-loss.yaml rename to tests/artifacts/testdata/trainercontroller/epoch-level-training-loss.yaml diff --git a/tests/data/trainercontroller/exposed_metrics.yaml b/tests/artifacts/testdata/trainercontroller/exposed_metrics.yaml similarity index 100% rename from tests/data/trainercontroller/exposed_metrics.yaml rename to tests/artifacts/testdata/trainercontroller/exposed_metrics.yaml diff --git a/tests/data/trainercontroller/incorrect_source_event_exposed_metrics.yaml b/tests/artifacts/testdata/trainercontroller/incorrect_source_event_exposed_metrics.yaml similarity index 100% rename from tests/data/trainercontroller/incorrect_source_event_exposed_metrics.yaml rename to tests/artifacts/testdata/trainercontroller/incorrect_source_event_exposed_metrics.yaml diff --git a/tests/data/trainercontroller/log_controller.yaml b/tests/artifacts/testdata/trainercontroller/log_controller.yaml similarity index 100% rename from tests/data/trainercontroller/log_controller.yaml rename to tests/artifacts/testdata/trainercontroller/log_controller.yaml diff --git a/tests/data/trainercontroller/loss_custom_metric.yaml b/tests/artifacts/testdata/trainercontroller/loss_custom_metric.yaml similarity index 100% rename from tests/data/trainercontroller/loss_custom_metric.yaml rename to tests/artifacts/testdata/trainercontroller/loss_custom_metric.yaml diff --git a/tests/data/trainercontroller/loss_custom_operation.yaml b/tests/artifacts/testdata/trainercontroller/loss_custom_operation.yaml similarity index 100% rename from tests/data/trainercontroller/loss_custom_operation.yaml rename to tests/artifacts/testdata/trainercontroller/loss_custom_operation.yaml diff --git a/tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml b/tests/artifacts/testdata/trainercontroller/loss_custom_operation_invalid_action.yaml similarity index 100% rename from tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml rename to tests/artifacts/testdata/trainercontroller/loss_custom_operation_invalid_action.yaml diff --git a/tests/data/trainercontroller/loss_invalid_metric.yaml b/tests/artifacts/testdata/trainercontroller/loss_invalid_metric.yaml similarity index 100% rename from tests/data/trainercontroller/loss_invalid_metric.yaml rename to tests/artifacts/testdata/trainercontroller/loss_invalid_metric.yaml diff --git a/tests/data/trainercontroller/loss_invalid_operation.yaml b/tests/artifacts/testdata/trainercontroller/loss_invalid_operation.yaml similarity index 100% rename from tests/data/trainercontroller/loss_invalid_operation.yaml rename to tests/artifacts/testdata/trainercontroller/loss_invalid_operation.yaml diff --git a/tests/data/trainercontroller/loss_invalid_operation_action.yaml b/tests/artifacts/testdata/trainercontroller/loss_invalid_operation_action.yaml similarity index 100% rename from tests/data/trainercontroller/loss_invalid_operation_action.yaml rename to tests/artifacts/testdata/trainercontroller/loss_invalid_operation_action.yaml diff --git a/tests/data/trainercontroller/loss_invalid_trigger.yaml b/tests/artifacts/testdata/trainercontroller/loss_invalid_trigger.yaml similarity index 100% rename from 
tests/data/trainercontroller/loss_invalid_trigger.yaml rename to tests/artifacts/testdata/trainercontroller/loss_invalid_trigger.yaml diff --git a/tests/data/trainercontroller/loss_on_threshold.yaml b/tests/artifacts/testdata/trainercontroller/loss_on_threshold.yaml similarity index 100% rename from tests/data/trainercontroller/loss_on_threshold.yaml rename to tests/artifacts/testdata/trainercontroller/loss_on_threshold.yaml diff --git a/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml b/tests/artifacts/testdata/trainercontroller/loss_on_threshold_with_trainer_state.yaml similarity index 100% rename from tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml rename to tests/artifacts/testdata/trainercontroller/loss_on_threshold_with_trainer_state.yaml diff --git a/tests/data/trainercontroller/loss_unavailable_metric.yaml b/tests/artifacts/testdata/trainercontroller/loss_unavailable_metric.yaml similarity index 100% rename from tests/data/trainercontroller/loss_unavailable_metric.yaml rename to tests/artifacts/testdata/trainercontroller/loss_unavailable_metric.yaml diff --git a/tests/data/trainercontroller/loss_with_invalid_type_rule.yaml b/tests/artifacts/testdata/trainercontroller/loss_with_invalid_type_rule.yaml similarity index 100% rename from tests/data/trainercontroller/loss_with_invalid_type_rule.yaml rename to tests/artifacts/testdata/trainercontroller/loss_with_invalid_type_rule.yaml diff --git a/tests/data/trainercontroller/loss_with_malicious_input_rule.yaml b/tests/artifacts/testdata/trainercontroller/loss_with_malicious_input_rule.yaml similarity index 100% rename from tests/data/trainercontroller/loss_with_malicious_input_rule.yaml rename to tests/artifacts/testdata/trainercontroller/loss_with_malicious_input_rule.yaml diff --git a/tests/data/trainercontroller/loss_with_malicious_os_rule.yaml b/tests/artifacts/testdata/trainercontroller/loss_with_malicious_os_rule.yaml similarity index 100% rename from tests/data/trainercontroller/loss_with_malicious_os_rule.yaml rename to tests/artifacts/testdata/trainercontroller/loss_with_malicious_os_rule.yaml diff --git a/tests/data/trainercontroller/non-decreasing-training-loss.yaml b/tests/artifacts/testdata/trainercontroller/non-decreasing-training-loss.yaml similarity index 100% rename from tests/data/trainercontroller/non-decreasing-training-loss.yaml rename to tests/artifacts/testdata/trainercontroller/non-decreasing-training-loss.yaml diff --git a/tests/data/trainercontroller/on-save.yaml b/tests/artifacts/testdata/trainercontroller/on-save.yaml similarity index 100% rename from tests/data/trainercontroller/on-save.yaml rename to tests/artifacts/testdata/trainercontroller/on-save.yaml diff --git a/tests/data/trainercontroller/thresholded-training-loss.yaml b/tests/artifacts/testdata/trainercontroller/thresholded-training-loss.yaml similarity index 100% rename from tests/data/trainercontroller/thresholded-training-loss.yaml rename to tests/artifacts/testdata/trainercontroller/thresholded-training-loss.yaml diff --git a/tests/data/twitter_complaints_input_output.json b/tests/artifacts/testdata/twitter_complaints_input_output.json similarity index 100% rename from tests/data/twitter_complaints_input_output.json rename to tests/artifacts/testdata/twitter_complaints_input_output.json diff --git a/tests/data/twitter_complaints_input_output.jsonl b/tests/artifacts/testdata/twitter_complaints_input_output.jsonl similarity index 100% rename from tests/data/twitter_complaints_input_output.jsonl rename 
to tests/artifacts/testdata/twitter_complaints_input_output.jsonl diff --git a/tests/data/twitter_complaints_small.json b/tests/artifacts/testdata/twitter_complaints_small.json similarity index 100% rename from tests/data/twitter_complaints_small.json rename to tests/artifacts/testdata/twitter_complaints_small.json diff --git a/tests/data/twitter_complaints_small.jsonl b/tests/artifacts/testdata/twitter_complaints_small.jsonl similarity index 100% rename from tests/data/twitter_complaints_small.jsonl rename to tests/artifacts/testdata/twitter_complaints_small.jsonl diff --git a/tests/data/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json b/tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json similarity index 100% rename from tests/data/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json rename to tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json diff --git a/tests/data/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl b/tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl similarity index 100% rename from tests/data/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl rename to tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl diff --git a/tests/build/test_launch_script.py b/tests/build/test_launch_script.py index e2c37950b..e331a5e9b 100644 --- a/tests/build/test_launch_script.py +++ b/tests/build/test_launch_script.py @@ -26,7 +26,7 @@ # First Party from build.accelerate_launch import main from build.utils import serialize_args, get_highest_checkpoint -from tests.data import TWITTER_COMPLAINTS_DATA_JSONL +from tests.artifacts.testdata import TWITTER_COMPLAINTS_DATA_JSONL from tuning.utils.error_logging import ( USER_ERROR_EXIT_CODE, INTERNAL_ERROR_EXIT_CODE, diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py new file mode 100644 index 000000000..d2a390fe9 --- /dev/null +++ b/tests/data/test_data_handlers.py @@ -0,0 +1,110 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Third Party +from transformers import AutoTokenizer +import datasets +import pytest + +# First Party +from tests.artifacts.testdata import MODEL_NAME, TWITTER_COMPLAINTS_DATA_JSONL + +# Local +from tuning.data.data_handlers import ( + apply_custom_data_formatting_template, + combine_sequence, +) + + +def test_apply_custom_formatting_template(): + json_dataset = datasets.load_dataset( + "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL + ) + template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + formatted_dataset_field = "formatted_data_field" + formatted_dataset = json_dataset.map( + apply_custom_data_formatting_template, + fn_kwargs={ + "tokenizer": tokenizer, + "dataset_text_field": formatted_dataset_field, + "template": template, + }, + ) + # First response from the data file that is read. + expected_response = ( + "### Input: @HMRCcustomers No this is my first job" + + " \n\n ### Response: no complaint" + + tokenizer.eos_token + ) + + # a new dataset_text_field is created in Dataset + assert formatted_dataset_field in formatted_dataset["train"][0] + assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response + + +def test_apply_custom_formatting_template_gives_error_with_wrong_keys(): + """Tests that the formatting function will throw error if wrong keys are passed to template""" + json_dataset = datasets.load_dataset( + "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL + ) + template = "### Input: {{not found}} \n\n ### Response: {{text_label}}" + formatted_dataset_field = "formatted_data_field" + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + with pytest.raises(KeyError): + json_dataset.map( + apply_custom_data_formatting_template, + fn_kwargs={ + "tokenizer": tokenizer, + "dataset_text_field": formatted_dataset_field, + "template": template, + }, + ) + + +@pytest.mark.parametrize( + "input_element,output_element,expected_res", + [ + ("foo ", "bar", "foo bar"), + ("foo\n", "bar", "foo\nbar"), + ("foo\t", "bar", "foo\tbar"), + ("foo", "bar", "foo bar"), + ], +) +def test_combine_sequence(input_element, output_element, expected_res): + """Ensure that input / output elements are combined with correct whitespace handling.""" + comb_seq = combine_sequence(input_element, output_element) + assert isinstance(comb_seq, str) + assert comb_seq == expected_res + + +@pytest.mark.parametrize( + "input_element,output_element,expected_res", + [ + ("foo ", "bar", "foo bar"), + ("foo\n", "bar", "foo\nbar"), + ("foo\t", "bar", "foo\tbar"), + ("foo", "bar", "foo bar"), + ], +) +def test_combine_sequence_adds_eos(input_element, output_element, expected_res): + """Ensure that input / output elements are combined with correct whitespace handling.""" + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + comb_seq = combine_sequence(input_element, output_element, tokenizer.eos_token) + expected_res += tokenizer.eos_token + assert isinstance(comb_seq, str) + assert comb_seq == expected_res diff --git a/tests/utils/test_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py similarity index 50% rename from tests/utils/test_preprocessing_utils.py rename to tests/data/test_data_preprocessing_utils.py index 2fbbc38ff..02308b2f5 100644 --- a/tests/utils/test_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -1,13 +1,36 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed 
under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +import json +import tempfile + # Third Party from datasets import Dataset -from datasets.exceptions import DatasetGenerationError from transformers import AutoTokenizer, DataCollatorForSeq2Seq from trl import DataCollatorForCompletionOnlyLM +import datasets import pytest +import yaml # First Party -from tests.data import ( - MALFORMATTED_DATA, +from tests.artifacts.predefined_data_configs import ( + APPLY_CUSTOM_TEMPLATE_YAML, + PRETOKENIZE_JSON_DATA_YAML, + TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, +) +from tests.artifacts.testdata import ( MODEL_NAME, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, @@ -19,137 +42,149 @@ # Local from tuning.config import configs -from tuning.utils.preprocessing_utils import ( - combine_sequence, - format_dataset, - get_data_collator, - get_formatted_dataset_with_single_sequence, - get_preprocessed_dataset, +from tuning.data.data_config import DataPreProcessorConfig, DataSetConfig +from tuning.data.data_preprocessing_utils import get_data_collator +from tuning.data.data_processors import DataPreProcessor, get_datapreprocessor +from tuning.data.setup_dataprocessor import ( + _process_dataconfig_file, is_pretokenized_dataset, - load_hf_dataset_from_file, - validate_data_args, + process_dataargs, ) @pytest.mark.parametrize( - "input_element,output_element,expected_res", + "datafile, column_names", [ - ("foo ", "bar", "foo bar"), - ("foo\n", "bar", "foo\nbar"), - ("foo\t", "bar", "foo\tbar"), - ("foo", "bar", "foo bar"), + ( + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, + set(["ID", "Label", "input", "output"]), + ), + ( + TWITTER_COMPLAINTS_TOKENIZED_JSONL, + set( + [ + "Tweet text", + "ID", + "Label", + "text_label", + "output", + "input_ids", + "labels", + ] + ), + ), + ( + TWITTER_COMPLAINTS_DATA_JSONL, + set(["Tweet text", "ID", "Label", "text_label", "output"]), + ), ], ) -def test_combine_sequence(input_element, output_element, expected_res): - """Ensure that input / output elements are combined with correct whitespace handling.""" - comb_seq = combine_sequence(input_element, output_element) - assert isinstance(comb_seq, str) - assert comb_seq == expected_res +def test_load_dataset_with_datafile(datafile, column_names): + """Ensure that both dataset is loaded with datafile.""" + processor = get_datapreprocessor( + processor_config=DataPreProcessorConfig(), tokenizer=None + ) + load_dataset = processor.load_dataset( + datasetconfig=None, splitName="train", datafile=datafile + ) + assert set(load_dataset.column_names) == column_names @pytest.mark.parametrize( - "input_element,output_element,expected_res", + "datafile, column_names, datasetconfigname", [ - ("foo ", "bar", "foo bar"), - ("foo\n", "bar", "foo\nbar"), - ("foo\t", "bar", "foo\tbar"), - ("foo", "bar", "foo bar"), + ( + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, + set(["ID", "Label", "input", "output"]), + "text_dataset_input_output_masking", + ), + ( + TWITTER_COMPLAINTS_TOKENIZED_JSONL, + set( + [ + "Tweet text", 
+ "ID", + "Label", + "text_label", + "output", + "input_ids", + "labels", + ] + ), + "pretokenized_dataset", + ), + ( + TWITTER_COMPLAINTS_DATA_JSONL, + set(["Tweet text", "ID", "Label", "text_label", "output"]), + "apply_custom_data_template", + ), ], ) -def test_combine_sequence_adds_eos(input_element, output_element, expected_res): - """Ensure that input / output elements are combined with correct whitespace handling.""" - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - comb_seq = combine_sequence(input_element, output_element, tokenizer.eos_token) - expected_res += tokenizer.eos_token - assert isinstance(comb_seq, str) - assert comb_seq == expected_res +def test_load_dataset_with_datasetconfig(datafile, column_names, datasetconfigname): + """Ensure that both dataset is loaded with datafile.""" + datasetconfig = DataSetConfig(name=datasetconfigname, data_paths=[datafile]) + processor = get_datapreprocessor( + processor_config=DataPreProcessorConfig(), tokenizer=None + ) + load_dataset = processor.load_dataset( + datasetconfig=datasetconfig, splitName="train", datafile=None + ) + assert set(load_dataset.column_names) == column_names -# Tests for loading the dataset from disk @pytest.mark.parametrize( - "dataset_path", - [TWITTER_COMPLAINTS_DATA_JSONL, TWITTER_COMPLAINTS_DATA_JSON], + "datafile, datasetconfigname", + [ + ( + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, + "text_dataset_input_output_masking", + ), + (TWITTER_COMPLAINTS_TOKENIZED_JSONL, "pretokenized_dataset"), + (TWITTER_COMPLAINTS_DATA_JSONL, "apply_custom_data_template"), + ], ) -def test_load_hf_dataset_from_file(dataset_path): - input_field_name = "Tweet text" - output_field_name = "text_label" - data = load_hf_dataset_from_file( - dataset_path, - input_field_name=input_field_name, - output_field_name=output_field_name, +def test_load_dataset_with_dataconfig_and_datafile(datafile, datasetconfigname): + """Ensure that both datasetconfig and datafile cannot be passed.""" + datasetconfig = DataSetConfig(name=datasetconfigname, data_paths=[datafile]) + processor = get_datapreprocessor( + processor_config=DataPreProcessorConfig(), tokenizer=None ) - # Our dataset should contain dicts that contain the input / output field name types - next_data = next(iter(data)) - assert input_field_name in next_data - assert output_field_name in next_data - - -def test_load_hf_dataset_from_jsonl_file_wrong_keys(): - """Ensure that we explode if the keys are not in the jsonl file.""" - with pytest.raises(DatasetGenerationError): - load_hf_dataset_from_file( - TWITTER_COMPLAINTS_DATA_JSONL, - input_field_name="foo", - output_field_name="bar", - ) - - -def test_load_hf_dataset_from_malformatted_data(): - """Ensure that we explode if the data is not properly formatted.""" - # NOTE: The actual keys don't matter here - with pytest.raises(DatasetGenerationError): - load_hf_dataset_from_file( - MALFORMATTED_DATA, input_field_name="foo", output_field_name="bar" + with pytest.raises(ValueError): + processor.load_dataset( + datasetconfig=datasetconfig, splitName="train", datafile=datafile ) -def test_load_hf_dataset_from_jsonl_file_duplicate_keys(): - """Ensure we cannot have the same key for input / output.""" +def test_load_dataset_without_dataconfig_and_datafile(): + """Ensure that both datasetconfig and datafile cannot be None.""" + processor = get_datapreprocessor( + processor_config=DataPreProcessorConfig(), tokenizer=None + ) with pytest.raises(ValueError): - load_hf_dataset_from_file( - TWITTER_COMPLAINTS_DATA_JSONL, - 
input_field_name="Tweet text", - output_field_name="Tweet text", - ) + processor.load_dataset(datasetconfig=None, splitName="train", datafile=None) -# Tests for custom masking / preprocessing logic @pytest.mark.parametrize( - "dataset_path, max_sequence_length", + "data, result", [ - (TWITTER_COMPLAINTS_DATA_JSONL, 1), - (TWITTER_COMPLAINTS_DATA_JSONL, 10), - (TWITTER_COMPLAINTS_DATA_JSONL, 100), - (TWITTER_COMPLAINTS_DATA_JSONL, 1000), - (TWITTER_COMPLAINTS_DATA_JSON, 1), - (TWITTER_COMPLAINTS_DATA_JSON, 10), - (TWITTER_COMPLAINTS_DATA_JSON, 100), - (TWITTER_COMPLAINTS_DATA_JSON, 1000), + (TWITTER_COMPLAINTS_DATA_JSONL, False), + ( + Dataset.from_list( + [ + { + "input_ids": [9437, 29, 210], + "attention_mask": [1, 1, 1], + "labels": [1, 20, 30], + } + ] + ), + True, + ), ], ) -def test_get_preprocessed_dataset(dataset_path, max_sequence_length): - """Ensure we can handle preprocessed datasets with different max_sequence_lengths - to ensure proper tokenization and truncation. - """ - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - preprocessed_data = get_preprocessed_dataset( - data_path=dataset_path, - tokenizer=tokenizer, - max_sequence_length=max_sequence_length, - input_field_name="Tweet text", - output_field_name="text_label", - ) - for tok_res in preprocessed_data: - # Since the padding is left to the collator, there should be no 0s in the attention mask yet - assert sum(tok_res["attention_mask"]) == len(tok_res["attention_mask"]) - # If the source text isn't empty, we start with masked inputs - assert tok_res["labels"][0] == -100 - # All keys in the produced record must be the same length - key_lengths = {len(tok_res[k]) for k in tok_res.keys()} - assert len(key_lengths) == 1 - # And also that length should be less than or equal to the max length depending on if we - # are going up to / over the max size and truncating - padding is handled separately - assert key_lengths.pop() <= max_sequence_length +def test_is_pretokenized_data(data, result): + """Ensure that the correct collator type is fetched based on the data args""" + assert is_pretokenized_dataset(data=data) == result @pytest.mark.parametrize( @@ -158,10 +193,10 @@ def test_get_preprocessed_dataset(dataset_path, max_sequence_length): ( False, "\n### Label:", - load_hf_dataset_from_file( - TWITTER_COMPLAINTS_DATA_JSONL, - input_field_name="Tweet text", - output_field_name="text_label", + datasets.load_dataset( + "json", + data_files=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + split="train", ), 1024, DataCollatorForCompletionOnlyLM, @@ -195,35 +230,12 @@ def test_get_data_collator( packing, response_template, AutoTokenizer.from_pretrained(MODEL_NAME), - formatted_train_dataset, + is_pretokenized_dataset(formatted_train_dataset), max_seq_length, ) assert isinstance(collator, expected_collator) -@pytest.mark.parametrize( - "data, result", - [ - (TWITTER_COMPLAINTS_DATA_JSONL, False), - ( - Dataset.from_list( - [ - { - "input_ids": [9437, 29, 210], - "attention_mask": [1, 1, 1], - "labels": [1, 20, 30], - } - ] - ), - True, - ), - ], -) -def test_is_pretokenized_data(data, result): - """Ensure that the correct collator type is fetched based on the data args""" - assert is_pretokenized_dataset(data=data) == result - - # Tests for validating data args # Invalid args return ValueError @pytest.mark.parametrize( @@ -310,63 +322,75 @@ def test_is_pretokenized_data(data, result): ), ], ) -def test_validate_args(data_args, packing): +def test_process_data_args_throws_error_where_needed(data_args, packing): """Ensure that 
respective errors are thrown for incorrect data arguments""" with pytest.raises(ValueError): - validate_data_args(data_args, packing) - - -@pytest.mark.parametrize( - "data_args, packing", - [ - # pretokenized train dataset and no validation dataset passed - ( - configs.DataArguments( - training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL, - ), - False, - ), - # pretokenized train and validation datasets - ( - configs.DataArguments( - training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL, - validation_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL, - ), - False, - ), - ], -) -def test_validate_args_pretokenized(data_args, packing): - """Ensure that supported data args do not error out when passing pretokenized datasets""" - validate_data_args(data_args, packing) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + TRAIN_ARGS = configs.TrainingArguments( + packing=packing, + max_seq_length=1024, + output_dir="tmp", # Not needed but positional + ) + (_, _, _, _, _, _) = process_dataargs(data_args, tokenizer, TRAIN_ARGS) @pytest.mark.parametrize( - "data_path, dataset_text_field, data_formatter_template", + "data_config_path, data_path", [ - (TWITTER_COMPLAINTS_DATA_JSON, "output", None), - (TWITTER_COMPLAINTS_DATA_JSONL, "output", None), + (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON), + (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL), + (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON), + (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL), ( + TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, - "formatted_field", - "### Text:{{input}} \n\n### Label: {{output}}", ), ( + TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, - "formatted_field", - "### Text:{{input}} \n\n### Label: {{output}}", ), ], ) -def test_get_formatted_dataset_with_single_sequence( - data_path, dataset_text_field, data_formatter_template -): +def test_process_dataconfig_file(data_config_path, data_path): + """Ensure that datasets are formatted and validated correctly based on the arguments passed in config file.""" + with open(data_config_path, "r") as f: + yaml_content = yaml.safe_load(f) + yaml_content["datasets"][0]["data_paths"][0] = data_path + datasets_name = yaml_content["datasets"][0]["name"] + + # Modify input_field_name and output_field_name according to dataset + if datasets_name == "text_dataset_input_output_masking": + yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { + "input_field_name": "input", + "output_field_name": "output", + } + + # Modify dataset_text_field and template according to dataset + formatted_dataset_field = "formatted_data_field" + if datasets_name == "apply_custom_data_template": + template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" + yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { + "dataset_text_field": formatted_dataset_field, + "template": template, + } + + with tempfile.NamedTemporaryFile( + "w", delete=False, suffix=".yaml" + ) as temp_yaml_file: + yaml.dump(yaml_content, temp_yaml_file) + temp_yaml_file_path = temp_yaml_file.name + data_args = configs.DataArguments(data_config_path=temp_yaml_file_path) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - formatted_dataset = get_formatted_dataset_with_single_sequence( - data_path, dataset_text_field, tokenizer, data_formatter_template - ) - assert isinstance(formatted_dataset, Dataset) - assert dataset_text_field in 
formatted_dataset.column_names + (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer) + assert isinstance(train_set, Dataset) + if datasets_name == "text_dataset_input_output_masking": + column_names = set(["input_ids", "attention_mask", "labels"]) + assert set(train_set.column_names) == column_names + elif datasets_name == "pretokenized_dataset": + assert set(["input_ids", "labels"]).issubset(set(train_set.column_names)) + elif datasets_name == "apply_custom_data_template": + assert formatted_dataset_field in set(train_set.column_names) @pytest.mark.parametrize( @@ -395,8 +419,8 @@ def test_get_formatted_dataset_with_single_sequence( configs.DataArguments( training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, - dataset_text_field="formatted_field", data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}", + response_template="\n### Label:", ) ), # data formatter template with input/output JSONL @@ -404,8 +428,8 @@ def test_get_formatted_dataset_with_single_sequence( configs.DataArguments( training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, - dataset_text_field="formatted_field", data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}", + response_template="\n### Label:", ) ), # input/output JSON with masking on input @@ -424,11 +448,16 @@ def test_get_formatted_dataset_with_single_sequence( ), ], ) -def test_format_dataset(data_args): +def test_process_dataargs(data_args): """Ensure that the train/eval data are properly formatted based on the data args / text field""" tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - train_set, eval_set, dataset_text_field = format_dataset( - data_args, tokenizer, max_seq_length=1024 + TRAIN_ARGS = configs.TrainingArguments( + packing=False, + max_seq_length=1024, + output_dir="tmp", # Not needed but positional + ) + (train_set, eval_set, dataset_text_field, _, _, _) = process_dataargs( + data_args, tokenizer, TRAIN_ARGS ) assert isinstance(train_set, Dataset) assert isinstance(eval_set, Dataset) @@ -472,9 +501,17 @@ def test_format_dataset(data_args): ), ], ) -def test_format_dataset_pretokenized(data_args): +def test_process_dataargs_pretokenized(data_args): """Ensure that pretokenized datasets are loaded and returned as is""" - train_set, eval_set, _ = format_dataset(data_args, None, max_seq_length=1024) + TRAIN_ARGS = configs.TrainingArguments( + packing=False, + max_seq_length=1024, + output_dir="tmp", # Not needed but positional + ) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + (train_set, eval_set, _, _, _, _) = process_dataargs( + data_args, tokenizer, TRAIN_ARGS + ) assert isinstance(train_set, Dataset) if eval_set: assert isinstance(eval_set, Dataset) @@ -482,3 +519,52 @@ def test_format_dataset_pretokenized(data_args): assert set(["input_ids", "labels"]).issubset(set(train_set.column_names)) if eval_set: assert set(["input_ids", "labels"]).issubset(set(eval_set.column_names)) + + +@pytest.mark.parametrize( + "datafile, column_names, datasetconfigname", + [ + ( + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + set(["ID", "Label", "input", "output"]), + "text_dataset_input_output_masking", + ), + ( + TWITTER_COMPLAINTS_TOKENIZED_JSON, + set( + [ + "Tweet text", + "ID", + "Label", + "text_label", + "output", + "input_ids", + "labels", + ] + ), + "pretokenized_dataset", + ), + ( + TWITTER_COMPLAINTS_DATA_JSON, + set(["Tweet text", "ID", "Label", 
"text_label", "output"]), + "apply_custom_data_template", + ), + ], +) +def test_process_dataset_configs(datafile, column_names, datasetconfigname): + """Test process_dataset_configs for expected output.""" + dataprocessor_config = DataPreProcessorConfig() + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + processor = DataPreProcessor( + processor_config=dataprocessor_config, + tokenizer=tokenizer, + ) + datasetconfig = [DataSetConfig(name=datasetconfigname, data_paths=[datafile])] + train_dataset = processor.process_dataset_configs(dataset_configs=datasetconfig) + + assert isinstance(train_dataset, Dataset) + assert set(train_dataset.column_names) == column_names + + with open(datafile, "r") as file: + data = json.load(file) + assert len(train_dataset) == len(data) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index a800ed6f6..69ccbf4fa 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -31,7 +31,7 @@ # First Party from build.utils import serialize_args from scripts.run_inference import TunedCausalLM -from tests.data import ( +from tests.artifacts.testdata import ( EMPTY_DATA, MALFORMATTED_DATA, MODEL_NAME, @@ -300,7 +300,7 @@ def test_run_train_fails_training_data_path_not_exist(): """Check fails when data path not found.""" updated_data_path_args = copy.deepcopy(DATA_ARGS) updated_data_path_args.training_data_path = "fake/path" - with pytest.raises(FileNotFoundError): + with pytest.raises(ValueError): sft_trainer.train(MODEL_ARGS, updated_data_path_args, TRAIN_ARGS, None) @@ -906,15 +906,12 @@ def test_empty_data(): def test_data_path_is_a_directory(): - """Ensure that we get FileNotFoundError if we point the data path at a dir, not a file.""" + """Ensure that we get ValueError if we point the data path at a dir, not a file.""" with tempfile.TemporaryDirectory() as tempdir: data_args = copy.deepcopy(DATA_ARGS) data_args.training_data_path = tempdir - # Confusingly, if we pass a directory for our data path, it will throw a - # FileNotFoundError saying "unable to find ''", since it can't - # find a matchable file in the path. - with pytest.raises(FileNotFoundError): + with pytest.raises(ValueError): sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS) diff --git a/tests/trainercontroller/test_tuning_trainercontroller.py b/tests/trainercontroller/test_tuning_trainercontroller.py index ba1a05808..2326e8e8c 100644 --- a/tests/trainercontroller/test_tuning_trainercontroller.py +++ b/tests/trainercontroller/test_tuning_trainercontroller.py @@ -30,7 +30,7 @@ from tests.trainercontroller.custom_operation_invalid_action import ( CustomOperationInvalidAction, ) -import tests.data.trainercontroller as td +import tests.artifacts.testdata.trainercontroller as td # Local import tuning.config.configs as config diff --git a/tests/utils/test_data_utils.py b/tests/utils/test_data_utils.py deleted file mode 100644 index f89736657..000000000 --- a/tests/utils/test_data_utils.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright The FMS HF Tuning Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -# SPDX-License-Identifier: Apache-2.0 -# https://spdx.dev/learn/handling-license-info/ - -# Third Party -import datasets -import pytest - -# First Party -from tests.data import TWITTER_COMPLAINTS_DATA_JSONL - -# Local -from tuning.utils import data_utils - - -def test_apply_custom_formatting_template(): - json_dataset = datasets.load_dataset( - "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL - ) - template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" - # First response from the data file that is read. - expected_response = ( - "### Input: @HMRCcustomers No this is my first job" - + " \n\n ### Response: no complaint" - ) - formatted_dataset_field = "formatted_data_field" - formatted_dataset = data_utils.apply_custom_formatting_template( - json_dataset, template, formatted_dataset_field - ) - # a new dataset_text_field is created in Dataset - assert formatted_dataset_field in formatted_dataset["train"][0] - assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response - - -def test_apply_custom_formatting_template_adds_eos_token(): - json_dataset = datasets.load_dataset( - "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL - ) - template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" - # First response from the data file that is read. - expected_response = ( - "### Input: @HMRCcustomers No this is my first job" - + " \n\n ### Response: no complaintEOS" - ) - formatted_dataset_field = "formatted_data_field" - formatted_dataset = data_utils.apply_custom_formatting_template( - json_dataset, template, formatted_dataset_field, "EOS" - ) - # a new dataset_text_field is created in Dataset - assert formatted_dataset_field in formatted_dataset["train"][0] - assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response - - -def test_apply_custom_formatting_template_gives_error_with_wrong_keys(): - """Tests that the formatting function will throw error if wrong keys are passed to template""" - json_dataset = datasets.load_dataset( - "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL - ) - template = "### Input: {{not found}} \n\n ### Response: {{text_label}}" - formatted_dataset_field = "formatted_data_field" - with pytest.raises(KeyError): - data_utils.apply_custom_formatting_template( - json_dataset, template, formatted_dataset_field, "EOS" - ) diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py index 7b7aa1a2a..88a38c839 100644 --- a/tests/utils/test_logging.py +++ b/tests/utils/test_logging.py @@ -28,33 +28,32 @@ from tuning.utils.logging import set_log_level -@mock.patch.dict(os.environ, {}, clear=True) def test_set_log_level_for_logger_default(): """ Ensure that the correct log level is being set for python native logger and transformers logger when no env var or CLI flag is passed """ - train_args = copy.deepcopy(TRAIN_ARGS) - training_args, logger = set_log_level(train_args) - assert logger.getEffectiveLevel() == logging.WARNING - assert training_args.log_level == "passive" + with mock.patch.dict(os.environ, {}, clear=True): + train_args = copy.deepcopy(TRAIN_ARGS) + training_args, logger = set_log_level(train_args) + assert logger.getEffectiveLevel() == logging.WARNING + assert training_args.log_level == "passive" -@mock.patch.dict(os.environ, {"LOG_LEVEL": "info"}, clear=True) def test_set_log_level_for_logger_with_env_var(): """ Ensure that the correct log level is being set for python native 
logger and transformers logger when env var LOG_LEVEL is used """ - train_args_env = copy.deepcopy(TRAIN_ARGS) - training_args, logger = set_log_level(train_args_env) - assert logger.getEffectiveLevel() == logging.INFO - assert training_args.log_level == "info" + with mock.patch.dict(os.environ, {"LOG_LEVEL": "info"}, clear=True): + train_args_env = copy.deepcopy(TRAIN_ARGS) + training_args, logger = set_log_level(train_args_env) + assert logger.getEffectiveLevel() == logging.INFO + assert training_args.log_level == "info" -@mock.patch.dict(os.environ, {"TRANSFORMERS_VERBOSITY": "info"}, clear=True) def test_set_log_level_for_logger_with_set_verbosity_and_cli(): """ Ensure that the correct log level is being set for python native logger and @@ -62,14 +61,14 @@ def test_set_log_level_for_logger_with_set_verbosity_and_cli(): and CLI flag is passed """ - train_args = copy.deepcopy(TRAIN_ARGS) - train_args.log_level = "error" - training_args, logger = set_log_level(train_args) - assert logger.getEffectiveLevel() == logging.ERROR - assert training_args.log_level == "error" + with mock.patch.dict(os.environ, {"TRANSFORMERS_VERBOSITY": "info"}, clear=True): + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.log_level = "error" + training_args, logger = set_log_level(train_args) + assert logger.getEffectiveLevel() == logging.ERROR + assert training_args.log_level == "error" -@mock.patch.dict(os.environ, {"LOG_LEVEL": "info"}, clear=True) def test_set_log_level_for_logger_with_env_var_and_cli(): """ Ensure that the correct log level is being set for python native logger and @@ -77,8 +76,9 @@ def test_set_log_level_for_logger_with_env_var_and_cli(): In this case, CLI arg takes precedence over the set env var LOG_LEVEL. """ - train_args = copy.deepcopy(TRAIN_ARGS) - train_args.log_level = "error" - training_args, logger = set_log_level(train_args) - assert logger.getEffectiveLevel() == logging.ERROR - assert training_args.log_level == "error" + with mock.patch.dict(os.environ, {"LOG_LEVEL": "info"}, clear=True): + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.log_level = "error" + training_args, logger = set_log_level(train_args) + assert logger.getEffectiveLevel() == logging.ERROR + assert training_args.log_level == "error" diff --git a/tests/utils/test_tokenizer_data_utils.py b/tests/utils/test_tokenizer_data_utils.py index 1afd34d4d..e24c90099 100644 --- a/tests/utils/test_tokenizer_data_utils.py +++ b/tests/utils/test_tokenizer_data_utils.py @@ -3,7 +3,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer # First Party -from tests.data import MODEL_NAME +from tests.artifacts.testdata import MODEL_NAME # Local # First party diff --git a/tuning/config/configs.py b/tuning/config/configs.py index 4bff99f19..222bf4424 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -95,6 +95,13 @@ class DataArguments: or data_formatter_template needs to be supplied." }, ) + data_config_path: str = field( + default=None, + metadata={ + "help": "data config file which specifies the data preprocessing logic to apply.\ + Supports both JSON and YAML based config files." + }, + ) @dataclass diff --git a/tuning/data/__init__.py b/tuning/data/__init__.py new file mode 100644 index 000000000..38a9531ef --- /dev/null +++ b/tuning/data/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tuning/data/data_config.py b/tuning/data/data_config.py new file mode 100644 index 000000000..7e3ccd83b --- /dev/null +++ b/tuning/data/data_config.py @@ -0,0 +1,134 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from dataclasses import dataclass +from typing import Dict, List, Optional +import logging +import os + +# Local +from tuning.utils.utils import load_yaml_or_json + + +@dataclass +class DataHandlerConfig: + name: str + arguments: Optional[Dict] + + +@dataclass +class DataSetConfig: + name: str + data_paths: List[str] + sampling: Optional[Dict] = None + data_handlers: Optional[List[DataHandlerConfig]] = None + + +@dataclass +class DataPreProcessorConfig: + type: Optional[str] = "default" + + +@dataclass +class DataConfig: + dataprocessor: DataPreProcessorConfig + datasets: List[DataSetConfig] + + +def _validate_data_handler_config(data_handler) -> DataHandlerConfig: + kwargs = data_handler + assert isinstance(kwargs, dict), "data_handlers in data_config needs to be a dict" + assert "name" in kwargs and isinstance( + kwargs["name"], str + ), "data_handlers need to have a name with type str" + assert "arguments" in kwargs, "data handlers need to have arguments" + assert isinstance( + kwargs["arguments"], dict + ), "data handler arguments should be of the type dict" + return DataHandlerConfig(**kwargs) + + +def _validate_dataset_config(dataset_config) -> DataSetConfig: + kwargs = dataset_config + assert isinstance(kwargs, dict), "dataset_config in data_config needs to be a dict" + + c = DataSetConfig(name=kwargs.get("name", ""), data_paths=[]) + + if "name" in kwargs: + assert isinstance(kwargs["name"], str), "dataset name should be string" + if "data_paths" not in kwargs: + raise ValueError("data_paths should be specified for each dataset") + data_paths = kwargs["data_paths"] + # TODO: Support that data_paths can be a directory or directories + assert isinstance(data_paths, list), "data_paths should be an array of files" + c.data_paths = [] + for p in data_paths: + assert isinstance(p, str), f"path {p} should be of the type string" + assert os.path.exists(p), f"data_paths {p} does not exist" + if not os.path.isabs(p): + _p = os.path.abspath(p) + logging.warning( + " Provided path %s is not absolute changing it to %s", p, _p + ) + p = _p + c.data_paths.append(p) + if "sampling" in kwargs: + sampling_kwargs = kwargs["sampling"] + assert isinstance( + dict, sampling_kwargs + ), "sampling arguments should be of the type dict" + if "ratio" in sampling_kwargs: + ratio = sampling_kwargs["ratio"] + assert 
isinstance(ratio, float) and ( + 0 <= ratio <= 1.0 + ), f"sampling ratio: {ratio} should be float and in range [0.0,1.0]" + c.sampling = sampling_kwargs + if "data_handlers" in kwargs: + c.data_handlers = [] + for handler in kwargs["data_handlers"]: + c.data_handlers.append(_validate_data_handler_config(handler)) + return c + + +def _validate_dataprocessor_config(dataprocessor_config) -> DataPreProcessorConfig: + kwargs = dataprocessor_config + c = DataPreProcessorConfig() + assert isinstance(kwargs, dict), "dataprocessor in data_config needs to be a dict" + return c + + +def validate_data_config(dataconfig: DataConfig): + _validate_dataprocessor_config(dataconfig.dataprocessor) + for d in dataconfig.datasets: + _validate_dataset_config(d) + + +def load_and_validate_data_config(data_config_file: str) -> DataConfig: + raw_data = load_yaml_or_json(data_config_file) + assert isinstance( + raw_data, dict + ), f"The provided data_config file is invalid: {data_config_file}" + assert "datasets" in raw_data, "datasets should be provided in data config" + assert isinstance( + raw_data["datasets"], list + ), "datasets should be provided as a list" + datasets = [] + for d in raw_data["datasets"]: + datasets.append(_validate_dataset_config(d)) + if "dataprocessor" in raw_data: + dataprocessor = _validate_dataprocessor_config(raw_data["dataprocessor"]) + + data_config = DataConfig(dataprocessor=dataprocessor, datasets=datasets) + return data_config diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py new file mode 100644 index 000000000..f0100072b --- /dev/null +++ b/tuning/data/data_handlers.py @@ -0,0 +1,142 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Definition of some predefined data preprocessing functions that we need. + +# Standard +from typing import Dict, List +import re + +# Third Party +from transformers import AutoTokenizer + + +### Utils for custom masking / manipulating input / output strs, etc +def combine_sequence(input_element: str, output_element: str, eos_token: str = ""): + """Combines / concatenates input & output element. + + Args: + input_element: str + Input component of the combined sequence. + output_element: str + Output component of the combined sequence. + eos_token: str + EOS token associated with the tokenizer. \ + If passed, it will be concatenated at end + + Returns: + str + Sequence combined with whitespace. 
+ """ + if not input_element.endswith((" ", "\n", "\t")) and not output_element.startswith( + (" ", "\n", "\t") + ): + return input_element + " " + output_element + eos_token + return input_element + output_element + eos_token + + +def tokenize_and_apply_input_masking( + element: Dict[str, str], + tokenizer: AutoTokenizer, + column_names: List[str], + input_field_name: str, + output_field_name: str, + **tokenizer_kwargs, +): + if (input_field_name or output_field_name) not in column_names: + raise ValueError( + f"Dataset should contain {input_field_name} \ + and {output_field_name} field if \ + no dataset_text_field or data_formatter_template specified" + ) + + input_text = element[input_field_name] + output_text = element[output_field_name] + + combined = combine_sequence(input_text, output_text, eos_token=tokenizer.eos_token) + + fn_kwargs = tokenizer_kwargs.get("fn_kwargs", {}) + tokenizer_inner_kwargs = fn_kwargs.get("tokenizer_kwargs", {}) + + tokenized_comb_seqs = tokenizer(combined, **tokenizer_inner_kwargs) + tokenized_input = tokenizer(input_text, **tokenizer_inner_kwargs) + + masked_labels = [-100] * len( + tokenized_input.input_ids + ) + tokenized_comb_seqs.input_ids[len(tokenized_input.input_ids) :] + + # Any benefit of retaining the old columns? + return { + "input_ids": tokenized_comb_seqs.input_ids, + "labels": masked_labels, + "attention_mask": tokenized_comb_seqs.attention_mask, + } + + +def apply_dataset_formatting( + element: Dict[str, str], + tokenizer: AutoTokenizer, + dataset_text_field: str, + **kwargs, +): + return { + f"{dataset_text_field}": element[f"{dataset_text_field}"] + tokenizer.eos_token + } + + +def apply_custom_data_formatting_template( + element: Dict[str, str], + tokenizer: AutoTokenizer, + dataset_text_field: str, + template: str, + **kwargs, +): + """Function to format datasets with Alpaca style / other templates. + Expects to be run as a HF Map API function. + Args: + element: the HF Dataset element loaded from a JSON or DatasetDict object. + template: Template to format data with. Features of Dataset + should be referred to by {{key}} + formatted_dataset_field: Dataset_text_field + eos_token: string EOS token to be appended while formatting data to a single sequence. + Defaults to empty + Returns: + Formatted HF Dataset + """ + + template += tokenizer.eos_token + + def replace_text(match_obj): + captured_groups = match_obj.groups() + if len(captured_groups) != 1: + raise ValueError( + "Unexpectedly captured multiple groups in template formatting" + ) + + index_object = captured_groups[0] + if index_object not in element: + raise KeyError("Requested template string is not a valid key in dict") + + return element[index_object] + + return { + dataset_text_field: re.sub(r"{{([\s0-9a-zA-Z_\-\.]+)}}", replace_text, template) + } + + +AVAILABLE_DATA_HANDLERS = { + "tokenize_and_apply_input_masking": tokenize_and_apply_input_masking, + "apply_dataset_formatting": apply_dataset_formatting, + "apply_custom_data_formatting_template": apply_custom_data_formatting_template, +} diff --git a/tuning/data/data_preprocessing_utils.py b/tuning/data/data_preprocessing_utils.py new file mode 100644 index 000000000..589e4c9ef --- /dev/null +++ b/tuning/data/data_preprocessing_utils.py @@ -0,0 +1,74 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Standard +from typing import Callable, Optional + +# Third Party +from transformers import AutoTokenizer, DataCollatorForSeq2Seq +from trl import DataCollatorForCompletionOnlyLM + +# Local +from tuning.config import configs + + +def get_data_collator( + packing: bool, + response_template: Optional[str], + tokenizer: AutoTokenizer, + is_traindata_tokenized: bool, + max_seq_length: int, +) -> Callable: + """Create and return the the appropriate collator type based on the configuration for packing, + response_template, and dataset_text_field. + + Args: + packing: bool + Whether or not we should apply packing or not. + response_template: Optional[str] + Response template to be used for formatting by TRL. + tokenizer: AutoTokenizer + Loaded tokenizer object to be used by the collator. + is_traindata_tokenized: bool + Whether train Dataset is tokenized or not + max_seq_length: int + Max sequence length expected + + Returns: + Callable + Callable collator to be leveraged by the trainer. + """ + + if not packing: + # TODO: near term - how response template ids are parsed out needs to be cleaned. + # The [2:] here applies if response template has \n prefix, it is needed to strip \n, + # otherwise template is not found. We will create issue to clean this out after we discuss + # data formats and collators we will support. + if response_template: + response_template_ids = tokenizer.encode( + response_template, add_special_tokens=False + )[2:] + return DataCollatorForCompletionOnlyLM( + response_template=response_template_ids, + tokenizer=tokenizer, + ignore_index=configs.IGNORE_INDEX, + ) + # Note that this automatically pads labels with -100 + # TODO check if this is sufficient for preprocessed + if is_traindata_tokenized: + return DataCollatorForSeq2Seq( + tokenizer=tokenizer, padding=True, max_length=max_seq_length + ) + raise ValueError( + "Could not pick a data collator. Please refer to supported data formats" + ) diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py new file mode 100644 index 000000000..f6f3b0ec9 --- /dev/null +++ b/tuning/data/data_processors.py @@ -0,0 +1,213 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
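
The data handlers added in `tuning/data/data_handlers.py` above are written as plain Hugging Face `map` functions, so they can also be exercised directly, outside the `DataPreProcessor`. A minimal sketch under stated assumptions: the tokenizer name and the sample record below are placeholders, not part of this change.

```
from datasets import Dataset
from transformers import AutoTokenizer

from tuning.data.data_handlers import apply_custom_data_formatting_template

# Placeholder tokenizer; any causal-LM tokenizer with an eos_token should work.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

ds = Dataset.from_list(
    [{"Tweet text": "@HMRCcustomers No this is my first job", "text_label": "no complaint"}]
)

# The DataPreProcessor normally injects `tokenizer` into fn_kwargs; here it is passed by hand.
formatted = ds.map(
    apply_custom_data_formatting_template,
    fn_kwargs={
        "tokenizer": tokenizer,
        "dataset_text_field": "formatted_data_field",
        "template": "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}",
    },
)

# The rendered template (with the tokenizer's EOS token appended) lands in the new column.
print(formatted[0]["formatted_data_field"])
```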
+ +# Standard +from typing import Dict, List, Union +import logging +import os + +# Third Party +from datasets import Dataset, DatasetDict, IterableDataset +from datasets.exceptions import DatasetNotFoundError +from transformers import AutoTokenizer +import datasets +import torch + +# Local +from tuning.data.data_config import DataConfig, DataPreProcessorConfig, DataSetConfig +from tuning.data.data_handlers import AVAILABLE_DATA_HANDLERS +from tuning.utils.utils import get_extension, get_loader_for_filepath + + +class DataPreProcessor: + + tokenizer = None + data_config: DataConfig = None + processor_config: DataPreProcessorConfig = None + registered_handlers: Dict[str, callable] = None + + def __init__( + self, processor_config: DataPreProcessorConfig, tokenizer: AutoTokenizer + ): + self.tokenizer = tokenizer + self.processor_config = processor_config + + # Initialize other objects + self.registered_handlers = {} + + def register_data_handler(self, name: str, func: callable): + self.registered_handlers[name] = func + + def load_dataset( + self, + datasetconfig: DataSetConfig, + splitName: str, + datafile: str = None, + **kwargs, + ): + + if datafile and datasetconfig: + raise ValueError("Both datafile and datasetconfig should not be set") + if (not datafile) and (not datasetconfig): + raise ValueError("Either datafile or datasetconfig must be set") + + if datafile: + files = [datafile] + loader = get_loader_for_filepath(file_path=datafile) + elif datasetconfig: + files = datasetconfig.data_paths + name = datasetconfig.name + # simple check to make sure all files are of same type. + extns = [get_extension(f) for f in files] + assert extns.count(extns[0]) == len( + extns + ), f"All files in the dataset {name} should have the same extension" + loader = get_loader_for_filepath(file_path=files[0]) + + if loader in (None, ""): + raise ValueError(f"data path is invalid [{', '.join(files)}]") + + try: + return datasets.load_dataset( + loader, + data_files=files, + split=splitName, + **kwargs, + ) + except DatasetNotFoundError as e: + raise e + except FileNotFoundError as e: + raise ValueError(f"data path is invalid [{', '.join(files)}]") from e + + def _process_dataset_configs( + self, dataset_configs: List[DataSetConfig], **extra_kwargs + ) -> Union[Dataset, IterableDataset]: + train_dataset = None + final_datasets = None + splitName = "train" # default + + logging.info("Starting DataPreProcessor...") + # Iterate over the multiple datasets provided to us + for d in dataset_configs: + logging.info("Loading %s", d.name) + + # In future the streaming etc go as kwargs of this function + raw_dataset = self.load_dataset(d, splitName) + + logging.info("Loaded raw dataset : {raw_datasets}") + + raw_datasets = DatasetDict() + + # Assume all is train split + if isinstance(raw_dataset, Dataset): + raw_datasets[splitName] = raw_dataset + else: + raw_datasets = raw_dataset + + if d.sampling: + logging.warning("Sampling multiple datasets is not supported yet") + + if d.data_handlers: # Execute the datahandlers + for data_handler in d.data_handlers: + handler_name: str = data_handler.name + handler: callable = self.registered_handlers[handler_name] + kwargs: Dict = data_handler.arguments + + if "batched" not in kwargs: + kwargs["batched"] = False + + column_names = raw_datasets[splitName].column_names + + # remove __content__ from all processing + if "__content__" in column_names: + column_names.remove("__content__") + + if "remove_columns" not in kwargs: + kwargs["remove_columns"] = None + if 
kwargs["remove_columns"] == "all": + kwargs["remove_columns"] = column_names + + if "num_proc" not in kwargs: + kwargs["num_proc"] = os.cpu_count() + + if "fn_kwargs" not in kwargs: + kwargs["fn_kwargs"] = {} + + kwargs["fn_kwargs"]["tokenizer"] = self.tokenizer + kwargs["fn_kwargs"]["column_names"] = column_names + + kwargs["fn_kwargs"] = dict(kwargs["fn_kwargs"], **extra_kwargs) + + logging.info("Applying Handler: %s Args: %s", data_handler, kwargs) + + raw_datasets = raw_datasets.map(handler, **kwargs) + + if final_datasets is None: + final_datasets = raw_datasets + else: + for k in raw_datasets.keys(): + if k in final_datasets: + final_datasets[k] = datasets.concatenate_datasets( + [final_datasets[k], raw_datasets[k]] + ) + else: + final_datasets[k] = raw_datasets[k] + + if "train" in final_datasets: + train_dataset = final_datasets["train"] + + return train_dataset + + def process_dataset_configs( + self, dataset_configs: List[DataSetConfig], **kwargs + ) -> Union[Dataset, IterableDataset]: + train_dataset = None + + if torch.distributed.is_available() and torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + logging.info("Processing data on rank 0...") + train_dataset = self._process_dataset_configs(dataset_configs, **kwargs) + else: + train_dataset = None + + # Use broadcast_object_list to share the dataset object across ranks + # TODO: Check if torch.distributed.barrier() is called in broadcast_object_list() + # See https://github.com/pytorch/pytorch/issues/56142 + # for why the list is shared like this + to_share = [train_dataset] + torch.distributed.broadcast_object_list(to_share, src=0) + train_dataset = to_share[0] + else: + logging.info("Processing data...") + train_dataset = self._process_dataset_configs(dataset_configs, **kwargs) + + return train_dataset + + +def autoregister_available_handlers(processor: DataPreProcessor): + if processor is None: + return + for name, func in AVAILABLE_DATA_HANDLERS.items(): + processor.register_data_handler(name=name, func=func) + + +def get_datapreprocessor( + processor_config: DataPreProcessorConfig, tokenizer: AutoTokenizer +) -> DataPreProcessor: + processor = DataPreProcessor( + processor_config=processor_config, + tokenizer=tokenizer, + ) + autoregister_available_handlers(processor) + return processor diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py new file mode 100644 index 000000000..5db8e0aee --- /dev/null +++ b/tuning/data/setup_dataprocessor.py @@ -0,0 +1,322 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
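
`get_datapreprocessor` above wires the registered handlers into a `DataPreProcessor`, and dataset configs can then be processed programmatically without a YAML file. A minimal sketch, assuming a hypothetical JSONL file and a placeholder tokenizer name:

```
from transformers import AutoTokenizer

from tuning.data.data_config import DataHandlerConfig, DataPreProcessorConfig, DataSetConfig
from tuning.data.data_processors import get_datapreprocessor

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # placeholder model

processor = get_datapreprocessor(
    processor_config=DataPreProcessorConfig(), tokenizer=tokenizer
)

dataset_config = DataSetConfig(
    name="twitter_complaints",                      # hypothetical dataset name
    data_paths=["/data/twitter_complaints.jsonl"],  # hypothetical path; must exist at runtime
    data_handlers=[
        DataHandlerConfig(
            name="apply_custom_data_formatting_template",
            arguments={
                "batched": False,
                "fn_kwargs": {
                    "dataset_text_field": "formatted_data_field",
                    "template": "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}",
                },
            },
        )
    ],
)

# Loads the file(s), runs each handler via datasets.map, and returns the train split.
train_dataset = processor.process_dataset_configs([dataset_config])
```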
+ +# Standard +from typing import Union +import logging + +# Third Party +from datasets import Dataset, IterableDataset + +# Third +from transformers import AutoTokenizer + +# Local +from tuning.config.configs import DataArguments, TrainingArguments +from tuning.data.data_config import ( + DataHandlerConfig, + DataPreProcessorConfig, + DataSetConfig, + load_and_validate_data_config, +) +from tuning.data.data_preprocessing_utils import get_data_collator +from tuning.data.data_processors import get_datapreprocessor + +# In future we may make the fields configurable +DEFAULT_JSON_INPUT_KEY = "input" +DEFAULT_JSON_OUTPUT_KEY = "output" + +# check if the provided dataset is pretokenized or not +# the check is taken from trl +# https://github.com/huggingface/trl/blob/ddf4c8dc3ecf6d9ee2b24f94c62182ffd682c808/trl/trainer/sft_trainer.py#L498-L509 +def is_pretokenized_dataset(data: Union[str, Dataset, IterableDataset]): + if not data: + return False + if isinstance(data, str): + # Create a data processor with default processor config + processor = get_datapreprocessor( + processor_config=DataPreProcessorConfig(), tokenizer=None + ) + data = processor.load_dataset(None, splitName="train[:1]", datafile=data) + + return ("input_ids" in data.column_names) and ("labels" in data.column_names) + + +# TODO: For now assume only training dataset is passed via data config file. +# This is very limited but is done to keep first implementation minimal +def _process_dataconfig_file(data_args: DataArguments, tokenizer: AutoTokenizer): + data_config = load_and_validate_data_config(data_args.data_config_path) + processor = get_datapreprocessor( + processor_config=data_config.dataprocessor, tokenizer=tokenizer + ) + train_dataset = processor.process_dataset_configs(data_config.datasets) + + return (train_dataset, None, data_args.dataset_text_field) + + +# Data Format 1: Pretokenized Data +def _get_pretokenized_dataset_handlers(data_args, packing, is_eval_tokenized): + + # if the provided train dataset is pretokenized + # however user provides formatting flags, error out + if ( + data_args.response_template + or data_args.data_formatter_template + or data_args.dataset_text_field + ): + raise ValueError( + "fields response_template, data_formatter_template, and dataset_text_field \ + are not applicable for pretokenized datasets" + ) + + # if the train dataset is pretokenized + # ensure validation dataset is pretokenized otherwise error out + if is_eval_tokenized: + raise ValueError( + "validation data should be pretokenized to be used \ + along with pretokenized train data" + ) + + # Support for packing pretokenized datasets has been merged in trl library + # see: https://github.com/huggingface/trl/pull/2011 + # but we wait till a new transformers version is released to remove this check. + if packing: + raise ValueError("packing will not be used when datasets are pretokenized") + + # We do not need a handler here as this is tokenized dataset + return [], None + + +### Data format 2 +def _get_dataset_formatting_handlers(data_args, packing): + + if data_args.response_template is None: + if packing is False: + raise ValueError( + "Since dataset_text_field or data_formatter_template \ + is provided and packing is disabled, \ + needs a corresponding response template for masking" + ) + + if data_args.response_template: + # To use Response template, pass datasets with single sequence instances \ + # or a formatter template to create single sequence on the fly. 
+ if not (data_args.dataset_text_field or data_args.data_formatter_template): + raise ValueError( + "dataset_text_field and data_formatter_template are None. \ + One of them needs to be set to use response_template" + ) + # Only one of dataset_text_field or data_formatter_template should be set. + if data_args.dataset_text_field and data_args.data_formatter_template: + raise ValueError( + "dataset_text_field and data_formatter_template are both set,\ + but are mutually exclusive options" + ) + + fn_kwargs = {} + dataset_text_field = data_args.dataset_text_field + + if dataset_text_field is None: + dataset_text_field = "new_formatted_field" + + fn_kwargs["dataset_text_field"] = dataset_text_field + if data_args.data_formatter_template is None: + handler = DataHandlerConfig( + "apply_dataset_formatting", + arguments={"fn_kwargs": fn_kwargs, "batched": False}, + ) + else: + fn_kwargs["template"] = data_args.data_formatter_template + handler = DataHandlerConfig( + "apply_custom_data_formatting_template", + arguments={"fn_kwargs": fn_kwargs, "batched": False}, + ) + return [handler], dataset_text_field + + +### Data format 3 +def _get_default_json_dataset_handlers(data_args, tokenizer_kwargs): + + fn_kwargs = {} + fn_kwargs["input_field_name"] = DEFAULT_JSON_INPUT_KEY + fn_kwargs["output_field_name"] = DEFAULT_JSON_OUTPUT_KEY + fn_kwargs["tokenizer_kwargs"] = tokenizer_kwargs + + kwargs = { + "fn_kwargs": fn_kwargs, + "batched": False, + "remove_columns": "all", + } + + handler = DataHandlerConfig("tokenize_and_apply_input_masking", arguments=kwargs) + return [handler], data_args.dataset_text_field + + +# Process raw dataargs for various usecases. +# Data Format 1: Pretokenized Data +# Use pretokenized data as-is without preprocessing. +# No handlers are needed for this format. +# Data Format 2: Single Sequence Dataset +# If a text field is specified, append the tokenizer's EOS token to it. +# If a formatter template is provided, apply it and save the result. +# Data remains un-tokenized. +# Data Format 3: JSON Dataset with Input/Output Fields +# Combine input and output fields, tokenize the data, and apply input attention masking. +# Requires both input and output fields; throws an error if missing. +def _process_raw_data_args( + data_args: DataArguments, + tokenizer: AutoTokenizer, + packing: bool, + max_seq_length: int, +): + + # Create a data processor with default processor config + default_processor_config = DataPreProcessorConfig() + data_processor = get_datapreprocessor( + processor_config=default_processor_config, tokenizer=tokenizer + ) + + assert isinstance( + data_args.training_data_path, str + ), "Training data path has to be set and str" + + is_eval_dataset_present = False + if data_args.validation_data_path: + is_eval_dataset_present = True + + # TODO: This check loads first slice of the dataset to view its columns + # Since this load is not done via processor it is redundant + is_traindata_tokenized = is_pretokenized_dataset(data_args.training_data_path) + is_evaldata_tokenized = is_pretokenized_dataset(data_args.validation_data_path) + + train_dataset_config = DataSetConfig( + name="training_data", + data_paths=[data_args.training_data_path], + data_handlers=None, + ) + if is_eval_dataset_present: + eval_dataset_config = DataSetConfig( + name="validation_data", + data_paths=[data_args.validation_data_path], + data_handlers=None, + ) + + # Setup some tokenizer kwargs for when we need a tokenizer + # TODO: Figure out a way to not hardcode this. 
+ tokenizer_kwargs = {} + tokenizer_kwargs["max_length"] = max_seq_length + tokenizer_kwargs["truncation"] = True + tokenizer_kwargs["padding"] = False + + handlers = None + dataset_text_field = None + if is_traindata_tokenized: + # Data Format 1: Pretokenized Data + handlers, dataset_text_field = _get_pretokenized_dataset_handlers( + data_args, packing, (is_eval_dataset_present and not is_evaldata_tokenized) + ) + elif data_args.data_formatter_template or data_args.dataset_text_field: + # Data Format 2: Single Sequence Dataset + handlers, dataset_text_field = _get_dataset_formatting_handlers( + data_args, packing + ) + else: + # Data Format 3: JSON Dataset with Input/Output Fields + handlers, dataset_text_field = _get_default_json_dataset_handlers( + data_args, tokenizer_kwargs + ) + + # Now set handlers in the dataset configs + train_dataset_config.data_handlers = handlers + if is_eval_dataset_present: + eval_dataset_config.data_handlers = handlers + + # And let processor handle the logic + train_dataset = data_processor.process_dataset_configs([train_dataset_config]) + + eval_dataset = None + if is_eval_dataset_present: + eval_dataset = data_processor.process_dataset_configs([eval_dataset_config]) + + return (train_dataset, eval_dataset, dataset_text_field) + + +# If a data config file is provided, load it to get the training dataset. +# - Assumes only the training dataset is specified in the config file. +# - Expects a complete and valid data config file from the user. +# +# If no data config file is specified, process the remaining data arguments +# to determine the use case based on their presence, as explained in _process_raw_data_args. +def process_dataargs( + data_args: DataArguments, tokenizer: AutoTokenizer, train_args: TrainingArguments +): + """ + Args: + data_args: tuning.config.configs.DataArguments + tokenizer: AutoTokenizer + train_args: TrainingArguments + Training arguments passed to the library + Used for packing and max_seq_length + Returns: + Tuple(Dataset, Dataset, str, DataCollator, int, Dict) + tuple containing train_dataset, eval_dataset, dataset_text_field, + data_collator, max_seq_length and dataset_kwargs + + """ + + max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length) + logging.info("Max sequence length is %s", max_seq_length) + if train_args.max_seq_length > tokenizer.model_max_length: + logging.warning( + "max_seq_length %s exceeds tokenizer.model_max_length \ + %s, using tokenizer.model_max_length %s", + train_args.max_seq_length, + tokenizer.model_max_length, + tokenizer.model_max_length, + ) + + train_dataset = eval_dataset = dataset_text_field = None + + if data_args.data_config_path: + train_dataset, eval_dataset, dataset_text_field = _process_dataconfig_file( + data_args, tokenizer + ) + else: + train_dataset, eval_dataset, dataset_text_field = _process_raw_data_args( + data_args, tokenizer, train_args.packing, max_seq_length + ) + + data_collator = get_data_collator( + train_args.packing, + data_args.response_template, + tokenizer, + # Note: This check should not be removed. + # Its important to recompute this post handling to + # check if we already tokenized the dataset or not. 
+ is_pretokenized_dataset(train_dataset), + max_seq_length, + ) + + dataset_kwargs = {} + if is_pretokenized_dataset(train_dataset or eval_dataset): + dataset_kwargs["skip_prepare_dataset"] = True + + return ( + train_dataset, + eval_dataset, + dataset_text_field, + data_collator, + max_seq_length, + dataset_kwargs, + ) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index fa7d0875c..c02d73781 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -53,6 +53,7 @@ FileLoggingTrackerConfig, TrackerConfigFactory, ) +from tuning.data.setup_dataprocessor import process_dataargs from tuning.trackers.tracker_factory import FILE_LOGGING_TRACKER, get_tracker from tuning.trainercontroller import TrainerControllerCallback from tuning.utils.config_utils import get_hf_peft_config, get_json_config @@ -63,12 +64,6 @@ write_termination_log, ) from tuning.utils.logging import set_log_level -from tuning.utils.preprocessing_utils import ( - format_dataset, - get_data_collator, - is_pretokenized_dataset, - validate_data_args, -) from tuning.utils.tokenizer_data_utils import tokenizer_and_embedding_resize @@ -257,17 +252,6 @@ def train( elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)): special_tokens_dict["pad_token"] = "" - max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length) - logger.info("Max sequence length is %s", max_seq_length) - if train_args.max_seq_length > tokenizer.model_max_length: - logger.warning( - "max_seq_length %s exceeds tokenizer.model_max_length \ - %s, using tokenizer.model_max_length %s", - train_args.max_seq_length, - tokenizer.model_max_length, - tokenizer.model_max_length, - ) - # add special tokens only when a custom tokenizer is not passed if not model_args.tokenizer_name_or_path: # TODO: we need to change this, perhaps follow what open instruct does? 
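
The `sft_trainer.py` hunk below replaces the old `format_dataset`/`get_data_collator` calls in `train()` with a single `process_dataargs` call. For reference, a minimal sketch of invoking it directly, mirroring the updated tests; the data path and model name are placeholders.

```
from transformers import AutoTokenizer

from tuning.config import configs
from tuning.data.setup_dataprocessor import process_dataargs

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # placeholder model

data_args = configs.DataArguments(
    training_data_path="twitter_complaints_input_output.json",  # hypothetical file
    data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}",
    response_template="\n### Label:",
)
train_args = configs.TrainingArguments(
    packing=False,
    max_seq_length=1024,
    output_dir="tmp",  # not used here, but required by TrainingArguments
)

# Returns the 6-tuple consumed by sft_trainer.train().
(train_set, eval_set, text_field, collator, max_seq_len, dataset_kwargs) = process_dataargs(
    data_args, tokenizer, train_args
)
```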
@@ -302,28 +286,20 @@ def train( ) # Configure the collator and validate args related to packing prior to formatting the dataset - if train_args.packing: - logger.info("Packing is set to True") - data_collator = None - packing = True - else: - logger.info("Packing is set to False") - packing = False - - # Validate if data args are set properly - validate_data_args(data_args, packing) + data_collator = None + logger.info("Packing is set to %s ", train_args.packing) + data_preprocessing_time = time.time() ( formatted_train_dataset, formatted_validation_dataset, dataset_text_field, - ) = format_dataset(data_args, tokenizer, max_seq_length) - data_collator = get_data_collator( - packing, - data_args.response_template, - tokenizer, - formatted_train_dataset, + data_collator, max_seq_length, + dataset_kwargs, + ) = process_dataargs(data_args, tokenizer, train_args) + additional_metrics["data_preprocessing_time"] = ( + time.time() - data_preprocessing_time ) if framework is not None and framework.requires_agumentation: @@ -348,17 +324,12 @@ def train( } training_args = SFTConfig(**transformer_kwargs) - dataset_kwargs = {} - if is_pretokenized_dataset( - data_args.training_data_path or data_args.validation_data_path - ): - dataset_kwargs["skip_prepare_dataset"] = True trainer = SFTTrainer( model=model, tokenizer=tokenizer, train_dataset=formatted_train_dataset, eval_dataset=formatted_validation_dataset, - packing=packing, + packing=train_args.packing, data_collator=data_collator, dataset_text_field=dataset_text_field, args=training_args, diff --git a/tuning/utils/data_utils.py b/tuning/utils/data_utils.py deleted file mode 100644 index db5ff0f0f..000000000 --- a/tuning/utils/data_utils.py +++ /dev/null @@ -1,47 +0,0 @@ -# Standard -import re - - -def apply_custom_formatting_template( - dataset, template, formatted_dataset_field, eos_token="" -): - """Function to format datasets with Alpaca style / other templates. - Args: - dataset: the HF Dataset element loaded from a JSON or DatasetDict object. - template: Template to format data with. Features of Dataset - should be referred to by {{key}} - formatted_dataset_field: Dataset_text_field - eos_token: string EOS token to be appended while formatting data to a single sequence. - Defaults to empty - Returns: - Formatted HF Dataset - """ - - template += eos_token - - if not formatted_dataset_field: - raise ValueError( - "Unable to apply custom formatting because the formatted_dataset_field was not provided" - ) - - def formatter(element): - def replace_text(match_obj): - captured_groups = match_obj.groups() - if len(captured_groups) != 1: - raise ValueError( - "Unexpectedly captured multiple groups in template formatting" - ) - - index_object = captured_groups[0] - if index_object not in element: - raise KeyError("Requested template string is not a valid key in dict") - - return element[index_object] - - return { - formatted_dataset_field: re.sub( - r"{{([\s0-9a-zA-Z_\-\.]+)}}", replace_text, template - ) - } - - return dataset.map(formatter) diff --git a/tuning/utils/preprocessing_utils.py b/tuning/utils/preprocessing_utils.py deleted file mode 100644 index a07e99a4e..000000000 --- a/tuning/utils/preprocessing_utils.py +++ /dev/null @@ -1,451 +0,0 @@ -# Copyright The FMS HF Tuning Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Standard -from typing import Any, Callable, Dict, Optional, Union -import json -import logging -import os - -# Third Party -from datasets import Dataset, IterableDataset -from datasets.exceptions import DatasetGenerationError -from transformers import AutoTokenizer, DataCollatorForSeq2Seq -from trl import DataCollatorForCompletionOnlyLM -import datasets - -# Local -from tuning.config import configs -from tuning.utils.data_utils import apply_custom_formatting_template - -# In future we may make the fields configurable -JSON_INPUT_KEY = "input" -JSON_OUTPUT_KEY = "output" - - -# check if the provided dataset is pretokenized or not -# the check is taken from trl -# https://github.com/huggingface/trl/blob/ddf4c8dc3ecf6d9ee2b24f94c62182ffd682c808/trl/trainer/sft_trainer.py#L498-L509 -def is_pretokenized_dataset(data: Union[str, Dataset, IterableDataset]): - if not data: - return False - if isinstance(data, str): - try: - data = datasets.load_dataset("json", data_files=data, split="train[:1]") - except DatasetGenerationError as e: - raise DatasetGenerationError("failed to load the provided dataset") from e - - return ("input_ids" in data.column_names) and ("labels" in data.column_names) - - -def validate_data_args(data_args: configs.DataArguments, packing: bool): - - assert isinstance( - data_args.training_data_path, str - ), "Training data path has to be set and str" - - is_train_data_pretokenized = is_pretokenized_dataset(data_args.training_data_path) - is_eval_data_pretokenized = is_pretokenized_dataset(data_args.validation_data_path) - - ### Data format 1 - # if the provided train dataset is pretokenized - # however user provides formatting flags, error out - if is_train_data_pretokenized: - if ( - data_args.response_template - or data_args.data_formatter_template - or data_args.dataset_text_field - ): - raise ValueError( - "fields response_template, data_formatter_template, and dataset_text_field \ - are not applicable for pretokenized datasets" - ) - - # if the train dataset is pretokenized - # ensure validation dataset is pretokenized otherwise error out - if data_args.validation_data_path and not is_eval_data_pretokenized: - raise ValueError( - "validation data should be pretokenized to be used \ - along with pretokenized train data" - ) - - # packing wont be available for pretokenized datasets in trl library - # see: https://github.com/huggingface/trl/issues/1848 - if packing: - raise ValueError("packing will not be used when datasets are pretokenized") - return - - ### Data format 2 - # Dataset containing single sequence needs a response template for masking - if data_args.dataset_text_field or data_args.data_formatter_template: - if data_args.response_template is None: - if packing is False: - raise ValueError( - "Since dataset_text_field or data_formatter_template \ - is provided and packing is disabled, \ - needs a corresponding response template for masking" - ) - - if data_args.response_template: - # To use Response template, pass datasets with single sequence instances \ - # or a formatter template to create single sequence on the fly. 
- if not (data_args.dataset_text_field or data_args.data_formatter_template): - raise ValueError( - "dataset_text_field and data_formatter_template are None. \ - One of them needs to be set to use response_template" - ) - # Only one of dataset_text_field or data_formatter_template should be set. - if data_args.dataset_text_field and data_args.data_formatter_template: - raise ValueError( - "dataset_text_field and data_formatter_template are both set,\ - but are mutually exclusive options" - ) - - ### Data format 3 - # If not single sequence, JSON should contain input/output fields - if not (data_args.dataset_text_field or data_args.data_formatter_template): - json_dataset = datasets.load_dataset( - "json", data_files=data_args.training_data_path - ) - if JSON_INPUT_KEY not in json_dataset["train"].column_names: - raise ValueError( - "JSON should contain input field if no dataset_text_field or \ - data_formatter_template specified" - ) - if JSON_OUTPUT_KEY not in json_dataset["train"].column_names: - raise ValueError( - "JSON should contain output field if no dataset_text_field or \ - data_formatter_template specified" - ) - - -def get_data_collator( - packing: bool, - response_template: Optional[str], - tokenizer: AutoTokenizer, - formatted_train_dataset: Dataset, - max_seq_length: int, -) -> Callable: - """Create and return the the appropriate collator type based on the configuration for packing, - response_template, and dataset_text_field. - - Args: - packing: bool - Whether or not we should apply packing or not. - response_template: Optional[str] - Response template to be used for formatting by TRL. - tokenizer: AutoTokenizer - Loaded tokenizer object to be used by the collator. - formatted_train_dataset: Dataset - Train Dataset formatted for tuning - max_seq_length: int - Max sequence length expected - - Returns: - Callable - Callable collator to be leveraged by the trainer. - """ - is_train_data_pretokenized = is_pretokenized_dataset(formatted_train_dataset) - - if not packing: - # TODO: near term - how response template ids are parsed out needs to be cleaned. - # The [2:] here applies if response template has \n prefix, it is needed to strip \n, - # otherwise template is not found. We will create issue to clean this out after we discuss - # data formats and collators we will support. - if response_template: - response_template_ids = tokenizer.encode( - response_template, add_special_tokens=False - )[2:] - return DataCollatorForCompletionOnlyLM( - response_template=response_template_ids, - tokenizer=tokenizer, - ignore_index=configs.IGNORE_INDEX, - ) - # Note that this automatically pads labels with -100 - # TODO check if this is sufficient for preprocessed - if is_train_data_pretokenized: - return DataCollatorForSeq2Seq( - tokenizer=tokenizer, padding=True, max_length=max_seq_length - ) - raise ValueError( - "Could not pick a data collator. 
Please refer to supported data formats" - ) - - -def format_dataset( - data_args: configs.DataArguments, tokenizer: AutoTokenizer, max_seq_length: int -): - """ - Args: - data_args: tuning.config.configs.DataArguments - tokenizer: AutoTokenizer - max_seq_length: int - Max sequence length expected - Returns: - Tuple(Dataset, Dataset, str) - tuple containing train_dataset, eval_dataset and dataset_text_field - """ - eval_dataset = None - is_train_data_pretokenized = is_pretokenized_dataset(data_args.training_data_path) - - if is_train_data_pretokenized: - train_dataset = datasets.load_dataset( - "json", data_files=data_args.training_data_path, split="train" - ) - if data_args.validation_data_path: - eval_dataset = datasets.load_dataset( - "json", data_files=data_args.validation_data_path, split="train" - ) - # dataset_text_field is irrelevant to pretokenized datasets - return train_dataset, eval_dataset, None - - dataset_text_field = data_args.dataset_text_field - if data_args.data_formatter_template or dataset_text_field: - if dataset_text_field is None: - dataset_text_field = "new_formatted_field" - train_dataset = get_formatted_dataset_with_single_sequence( - data_args.training_data_path, - dataset_text_field, - tokenizer, - data_args.data_formatter_template, - ) - logging.info("Training dataset length is %s", len(train_dataset)) - if data_args.validation_data_path: - (eval_dataset) = get_formatted_dataset_with_single_sequence( - data_args.validation_data_path, - dataset_text_field, - tokenizer, - data_args.data_formatter_template, - ) - logging.info("Validation dataset length is %s", len(eval_dataset)) - else: - # This is for JSON containing input/output fields - train_dataset = get_preprocessed_dataset( - data_args.training_data_path, - tokenizer, - max_seq_length, - input_field_name=JSON_INPUT_KEY, - output_field_name=JSON_OUTPUT_KEY, - ) - if data_args.validation_data_path: - eval_dataset = get_preprocessed_dataset( - data_args.validation_data_path, - tokenizer, - max_seq_length, - input_field_name=JSON_INPUT_KEY, - output_field_name=JSON_OUTPUT_KEY, - ) - - return train_dataset, eval_dataset, dataset_text_field - - -def get_formatted_dataset_with_single_sequence( - data_path: str, - dataset_text_field: str, - tokenizer: AutoTokenizer, - data_formatter_template: Optional[str] = None, -) -> Dataset: - """Applies formatting to the loaded dataset instance; does NOT pretokenize data. - - Args: - data_path: str - Path to the file to be loaded. - dataset_text_field: str - Dataset text field to be used for formatting. - If data_formatter_template specified, \ - this will be the new field creating single sequence. - tokenizer: AutoTokenizer - Loaded tokenizer object to be used by the collator. - data_formatter_template: str - Template to apply to create single sequence and store it in dataset_text_field - - Returns: - Dataset - HF Dataset with formatted [str] data. 
- """ - - json_dataset = datasets.load_dataset("json", data_files=data_path) - format_dataset_EOS = ( - lambda example: { # pylint: disable=unnecessary-lambda-assignment - f"{dataset_text_field}": example[f"{dataset_text_field}"] - + tokenizer.eos_token - } - ) - if data_formatter_template: - formatted_train_dataset = apply_custom_formatting_template( - json_dataset["train"], - data_formatter_template, - dataset_text_field, - tokenizer.eos_token, - ) - else: - formatted_train_dataset = json_dataset.map(format_dataset_EOS)[ - "train" - ] # HACK - for now, we just do both datasets separately; train is the default split - return formatted_train_dataset - - -def get_preprocessed_dataset( - data_path: str, - tokenizer: AutoTokenizer, - max_sequence_length: int, - input_field_name: str, - output_field_name: str, -) -> Dataset: - """Loads the dataset and applies the tokenizer + custom masking logic. - - Args: - data_path: str - Path to the file to be loaded. - tokenizer: AutoTokenizer - Loaded tokenizer object to be used by the collator. - max_sequence_length: int - Max sequence length to be used for sequence tokenization. - input_field_name: str - Name of the input field in the data. - output_field_name: str - Name of the output field in the data. - - Returns: - Dataset - HF Dataset with the pretokenized data. - """ - dataset = load_hf_dataset_from_file(data_path, input_field_name, output_field_name) - return dataset.map( - preprocess_and_tokenize, - fn_kwargs={ - "tokenizer": tokenizer, - "max_seq_length": max_sequence_length, - "input_field_name": input_field_name, - "output_field_name": output_field_name, - }, - remove_columns=[input_field_name, output_field_name], - ) - - -### Utils for loading the data from disk in supported formats [currently only jsonl] -def load_hf_dataset_from_file( - data_path: str, input_field_name: str, output_field_name: str -) -> Dataset: - """Loads the HuggingFace dataset from JSON or JSONL file. - - Args: - data_path: str - Path to the file to be loaded. - input_field_name: str - Name of the input field in the data. - output_field_name: str - Name of the output field in the data. - - Returns: - Dataset - HF Dataset with the data to be tokenized. - """ - if input_field_name == output_field_name: - raise ValueError("Input field name and output field name should not match!") - - def get_json_object(): - with open(data_path, "r", encoding="utf-8") as json_file: - file_extension = os.path.splitext(data_path)[-1].lower() - if file_extension == ".jsonl": - data_stream = (json.loads(line) for line in json_file) - elif file_extension == ".json": - data_stream = json.load(json_file) - else: - raise ValueError("Unsupported file format! Use 'json' or 'jsonl'.") - - for data in data_stream: - yield { - input_field_name: data[input_field_name], - output_field_name: data[output_field_name], - } - - return Dataset.from_generator(get_json_object) - - -### Utils for custom masking / manipulating input / output strs, etc -def combine_sequence(input_element: str, output_element: str, eos_token: str = ""): - """Combines / concatenates input & output element. - - Args: - input_element: str - Input component of the combined sequence. - output_element: str - Output component of the combined sequence. - eos_token: str - EOS token associated with the tokenizer. \ - If passed, it will be concatenated at end - - Returns: - str - Sequence combined with whitespace. 
- """ - if not input_element.endswith((" ", "\n", "\t")) and not output_element.startswith( - (" ", "\n", "\t") - ): - return input_element + " " + output_element + eos_token - return input_element + output_element + eos_token - - -def preprocess_and_tokenize( - element: Dict[str, str], - tokenizer: AutoTokenizer, - max_seq_length: int, - input_field_name: str, - output_field_name: str, -) -> Dict[str, Any]: - """Loads the dataset and applies the tokenizer + custom masking logic. - NOTE: Truncation is done in this step, but padding is not, and generally - handled by the collator. - - Args: - element: Dict[str, str] - A single element of the raw Dataset of strings, whose data we would like to apply - custom masking + tokenization logic to. - tokenizer: AutoTokenizer - Loaded tokenizer object to be used by the collator. - max_sequence_length: int - Max sequence length to be used for sequence tokenization. - input_field_name: str - Name of the input field in the data. - output_field_name: str - Name of the output field in the data. - - Returns: - Dict[str, Any] - Dictionary containing the input IDs/labels/attention mask for this record. - """ - combined_seq = combine_sequence( - element[input_field_name], element[output_field_name], tokenizer.eos_token - ) - - tokenized_comb_seqs = tokenizer( - combined_seq, max_length=max_seq_length, truncation=True, padding=False - ) - tokenized_input = tokenizer( - element[input_field_name], - max_length=max_seq_length, - truncation=True, - padding=False, - ) - - # mask the prompt part for avoiding loss - masked_labels = [-100] * len( - tokenized_input.input_ids - ) + tokenized_comb_seqs.input_ids[len(tokenized_input.input_ids) :] - - return { - "input_ids": tokenized_comb_seqs.input_ids, - "labels": masked_labels, - "attention_mask": tokenized_comb_seqs.attention_mask, - } diff --git a/tuning/utils/utils.py b/tuning/utils/utils.py new file mode 100644 index 000000000..9def53df9 --- /dev/null +++ b/tuning/utils/utils.py @@ -0,0 +1,44 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +import json +import os + +# Third Party +import yaml + + +def get_extension(file_path: str) -> str: + _, ext = os.path.splitext(file_path) + return ext.lower() + + +def get_loader_for_filepath(file_path: str) -> str: + ext = get_extension(file_path) + if ext in (".txt", ".md"): + return "text" + if ext in (".json", ".jsonl"): + return "json" + return ext + + +def load_yaml_or_json(file_path: str) -> dict: + with open(file_path, "r", encoding="utf-8") as f: + ext = get_extension(file_path) + if ext in (".yaml", ".yml"): + return yaml.safe_load(f) + if ext == ".json": + return json.load(f) + return None