diff --git a/.github/workflows/build-and-publish.yml b/.github/workflows/build-and-publish.yml index b95bee75..c9080c55 100644 --- a/.github/workflows/build-and-publish.yml +++ b/.github/workflows/build-and-publish.yml @@ -8,14 +8,15 @@ jobs: strategy: matrix: python-version: - - setup: "3.10" - tox: "py310" + - setup: "3.11" + tox: "py311" plugin_name: - "framework" - "accelerated-peft" - "fused-ops-and-kernels" - "attention-and-distributed-packing" - "accelerated-moe" + - "online-data-mixing" permissions: id-token: write # IMPORTANT: this permission is mandatory for trusted publishing diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 8f25a613..d1ee538c 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -31,13 +31,14 @@ jobs: - "fused-ops-and-kernels" - "attention-and-distributed-packing" - "accelerated-moe" + - "online-data-mixing" steps: - uses: actions/checkout@v4 - - name: Set up Python 3.9 + - name: Set up Python 3.11 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11 - name: Install dependencies run: | python -m pip install --upgrade pip @@ -60,10 +61,10 @@ jobs: steps: - uses: actions/checkout@v4 - - name: Set up Python 3.9 + - name: Set up Python 3.11 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/plugins/framework/src/fms_acceleration/constants.py b/plugins/framework/src/fms_acceleration/constants.py index d568ec13..252842e0 100644 --- a/plugins/framework/src/fms_acceleration/constants.py +++ b/plugins/framework/src/fms_acceleration/constants.py @@ -21,4 +21,4 @@ # and activated. # - hence the plugins that have model loaders should be on top of this list -PLUGINS = ["peft", "foak", "aadp", "moe"] +PLUGINS = ["peft", "foak", "aadp", "moe", "odm"] diff --git a/plugins/online-data-mixing/.isort.cfg b/plugins/online-data-mixing/.isort.cfg new file mode 100644 index 00000000..7d3762ec --- /dev/null +++ b/plugins/online-data-mixing/.isort.cfg @@ -0,0 +1,10 @@ +[settings] +profile=black +from_first=true +import_heading_future=Future +import_heading_stdlib=Standard +import_heading_thirdparty=Third Party +import_heading_firstparty=First Party +import_heading_localfolder=Local +known_firstparty= +known_localfolder=tuning \ No newline at end of file diff --git a/plugins/online-data-mixing/.pylintrc b/plugins/online-data-mixing/.pylintrc new file mode 100644 index 00000000..4dc16dbc --- /dev/null +++ b/plugins/online-data-mixing/.pylintrc @@ -0,0 +1,649 @@ +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS,protobufs + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +ignore-paths=.*megablocks,.*khd + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +ignore-patterns=^\.# + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.11 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=8 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=100 + +# Maximum number of lines in a module. +max-module-lines=1100 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + # Added messages + use-symbolic-message-instead, + invalid-name, + missing-class-docstring, + missing-module-docstring, + missing-function-docstring, + consider-using-f-string, + inconsistent-return-statements, + no-member, + too-many-arguments, + too-many-locals, + too-many-branches, + too-many-statements, + cyclic-import, + too-few-public-methods, + protected-access, + fixme, + logging-format-interpolation, + logging-too-many-args, + attribute-defined-outside-init, + abstract-method, + pointless-statement, + wrong-import-order, + duplicate-code, + unbalanced-tuple-unpacking, + unused-argument + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages. +reports=yes + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it work, +# install the 'python-enchant' package. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io \ No newline at end of file diff --git a/plugins/online-data-mixing/README.md b/plugins/online-data-mixing/README.md new file mode 100644 index 00000000..57ef136b --- /dev/null +++ b/plugins/online-data-mixing/README.md @@ -0,0 +1,57 @@ +# Online Data Mixing + +This library contains plugin for online dynamic reward (learnable) based data mixing framework that operates on dynamically mixing datasets online during training while being adapted based on the signals (e.g. training loss, gradnorm etc) from training. + +## Plugins + +Plugin | Description | Depends | Loading | Augmentation | Callbacks +--|--|--|--|--|-- +[odm](./src/fms_acceleration_odm/framework_plugin_odm.py) | OnlineMixingDataset PyTorch IterableDataset and custom rewards | | ✅ | ✅ | ✅ + +## Design +![](./artifacts/Design.png) + +## Usage in Custom Training Loop + +![](./artifacts/plot.png) + +`OnlineMixingDataset` can be imported easily and integrated into existing training loops with minimal changes. A sample custom training loop implementation can be found [here](./artifacts/custom_loop_usage.py). Given code sample uses two instruction tuning datasets and trains `ibm-granite/granite-3.1-2b-instruct` model for next token prediction task. + +## Metrics + +All metrics related to the online data mixing will be logged to `odm.jsonl` file in the checkpoint output directory. + +Metric | Description +--|-- +`samples_produced_so_far` | Total samples produced by the dataset so far at the time of logging. +`sampling_interval` | Takes sample count "n" as input. At every "n" steps category/dataset chosen by weighted random sampling where weights are provided by the Multi-Armed Bandit algorithm. +`total_categories` | Total categories or datasets involved in mixing. +`current_sampling_weights` | Current state of the sampling weights at the time of logging. +`current_sampling_ratio` | Current state of the sampling ratios at the time of logging. +`arm_idx` | Last sampled category index. Categories/datasets are sorted in ascending order based on their names and index starts from 0 and each index corresponds to respective category/dataset. +`category_level_counts_so_far` | Split of sample count across datasets so far at the time of logging. +`rewards` | State of the rewards at the time of logging. Essentially are the last provided rewards across datasets. +`action` | Type of action took place at the time logging. It is either "update" or "sample" which correspond to weight update of the MAB algorithm or category sampling. + +## Rewards + +Below are the currently available rewards and we are constantly looking to improve the existing rewards and also add new ones. Further, we encourage users to identify rewards that can help their usecases. + +Rewards | Description +--|-- +`ENTROPY` | Calculation of shannon entropy of the logits averaged across all the tokens. Higher entropy would mean model requires more samples from that datasets/category. +`ENTROPY3_VARENT1` | 3 parts of shannon entropy and 1 part of variance of the entropy. Higher values mean requirement of more samples. +`ENTROPY_LAST_TOKEN` | Shannon entropy of the last token in the sample. Higher values mean requirement of more samples. +`TRAIN_LOSS` | Training loss where loss is maintained across categories and is updated based on the latest loss and sampled dataset/category. Higher values mean requirement of more samples. +`VALIDATION_LOSS` | Validation loss across categories calculated using evaluation datasets from each of the categories. Higher values mean requirement of more samples. +`GRADNORM` | Gradient norm where norms are maintained across categories and are updated based on the latest values and sampled dataset/category. Higher values mean reducing samples from that particular dataset/category. + +### Adding a Custom Reward +Custom rewards can be added to the `compute_reward` function and adding it to the `Reward` enum. If the custom reward requires specific set of information from the training loop then `_extract_information_from_state_for_reward` function has to be extended for extracting such information from trainer state. This is member function of `OnlineMixingDataset`. + + +### Planned TODOs +Please see issue [#153](https://github.com/foundation-model-stack/fms-acceleration/issues/153). + + + diff --git a/plugins/online-data-mixing/artifacts/Design.png b/plugins/online-data-mixing/artifacts/Design.png new file mode 100644 index 00000000..736f5c24 Binary files /dev/null and b/plugins/online-data-mixing/artifacts/Design.png differ diff --git a/plugins/online-data-mixing/artifacts/custom_loop_usage.py b/plugins/online-data-mixing/artifacts/custom_loop_usage.py new file mode 100644 index 00000000..d9802b5e --- /dev/null +++ b/plugins/online-data-mixing/artifacts/custom_loop_usage.py @@ -0,0 +1,145 @@ +# Run commmand +# CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file fms-acceleration/scripts/benchmarks/accelerate.yaml +# --num_processes=2 --main_process_port=29511 custom_loop_usage.py + +# Standard +import json +import os + +# Third Party +from accelerate import Accelerator, DataLoaderConfiguration +from datasets import load_dataset +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + DataCollatorForLanguageModeling, +) +import torch + +# First Party +from fms_acceleration_odm import OnlineMixingDataset + +model_name = "ibm-granite/granite-3.1-2b-instruct" +output_dir = "./odm_custom_use" +max_steps = 125 +batch_size = 12 +log_file = os.path.join(output_dir, "loss.jsonl") + +# odm related +step_idx = 0 +update_interval = 1 # every step + +# model +model = AutoModelForCausalLM.from_pretrained(model_name) + +# tokenizer +tokenizer = AutoTokenizer.from_pretrained(model_name) +tokenizer.pad_token = tokenizer.eos_token + + +# dataset related +def tokenize_fn(examples): + return tokenizer( + examples["text"], truncation=True, padding="max_length", max_length=128 + ) + + +dataset_dict = { + "alpaca": load_dataset("tatsu-lab/alpaca", split="train[:1%]"), + "oasst": load_dataset("hakurei/open-instruct-v1", split="train[:1%]"), +} + + +def format_example(example): + if "instruction" in example: + prompt = f"Instruction: {example['instruction']}\nInput: {example.get('input','')}\nOutput: {example['output']}" + elif "text" in example: + prompt = example["text"] + return {"text": prompt} + + +for name in dataset_dict: + dataset_dict[name] = dataset_dict[name].map(format_example) + + +def tokenize_fn(examples): + return tokenizer( + examples["text"], + truncation=True, + padding="max_length", + max_length=1024, + ) + + +for name in dataset_dict: + dataset_dict[name] = dataset_dict[name].map( + tokenize_fn, + batched=True, + remove_columns=dataset_dict[name].column_names, + ) + +collator_dict = { + name: DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + for name in dataset_dict +} + +# dataset preparation +dataset = OnlineMixingDataset( + dataset_dict=dataset_dict, + collators_dict=collator_dict, + eval_dataset_dict={}, + eval_collators_dict={}, + output_dir=output_dir, + reward_type="train_loss", + sampling_interval=batch_size, +) +dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=None) + +# distributed setup +dataloader_config = DataLoaderConfiguration(split_batches=True, dispatch_batches=True) +accelerator = Accelerator(split_batches=True, dataloader_config=dataloader_config) +model, dataloader = accelerator.prepare(model, dataloader) + +# training setup +optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5) + + +# Trainer state +class State: + log_history: list = [] + + +state = State() + + +# custom training loop +model.train() +for step, batch in enumerate( + tqdm(dataloader, disable=not accelerator.is_local_main_process) +): + step_idx += 1 + outputs = model(**batch) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + loss = accelerator.gather(loss).mean() + if step_idx % 1 == 0: + if torch.isnan(loss): + loss = torch.tensor([10]) # nan -> very high loss + if accelerator.is_main_process: + print(f"Step {step_idx} ||| Loss: {loss.item():.4f}") + with open(log_file, "a") as f: + f.write(json.dumps({"loss": loss.item(), "step": step_idx}) + "\n") + state.log_history.append({"loss": loss.item(), "step": step_idx}) + if step_idx % update_interval == 0: + with torch.no_grad(): + model.eval() + dataloader.dataset.update_sampling_weights(model, accelerator, state) + model.train() + if step_idx > max_steps: + break + +print("Training completed!") diff --git a/plugins/online-data-mixing/artifacts/plot.png b/plugins/online-data-mixing/artifacts/plot.png new file mode 100644 index 00000000..d32a337b Binary files /dev/null and b/plugins/online-data-mixing/artifacts/plot.png differ diff --git a/plugins/online-data-mixing/configs/odm.yaml b/plugins/online-data-mixing/configs/odm.yaml new file mode 100644 index 00000000..5f7f19c4 --- /dev/null +++ b/plugins/online-data-mixing/configs/odm.yaml @@ -0,0 +1,8 @@ +training: + odm: + odm: + update_interval: 1 # update every step + sampling_interval: 1 # sample category for every sample + reward_type: entropy # type of reward to use + gamma: 0.1 # MAB hyper-parameter + eta: 0.1 # MAB hyper-parameter diff --git a/plugins/online-data-mixing/pyproject.toml b/plugins/online-data-mixing/pyproject.toml new file mode 100644 index 00000000..bd97cce3 --- /dev/null +++ b/plugins/online-data-mixing/pyproject.toml @@ -0,0 +1,31 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "fms-acceleration-odm" +version = '0.1.0.dev' +description = "FMS Acceleration plugin for online data mixing" +authors = [ + {name = "Mehant Kammakomati", email = "mehant.kammakomati2@ibm.com"}, + {name = "Romit Jain", email = "romit@ibm.com"}, + {name = "Padmanabha Venkatagiri Seshadri", email = "seshapad@in.ibm.com"}, +] +license = {text = "Apache-2.0"} +readme = "README.md" +requires-python = "~=3.11" +keywords = ['fms-hf-tuning', 'acceleration', 'online-data-mixing'] +classifiers=[ + "License :: OSI Approved :: Apache Software License", + "Development Status :: 4 - Beta", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11", +] + +dependencies = ["datasets"] + +[tool.hatch.build.targets.wheel] +only-include = ["src/fms_acceleration_odm"] + +[tool.hatch.build.targets.wheel.sources] +"src" = "" diff --git a/plugins/online-data-mixing/src/fms_acceleration_odm/__init__.py b/plugins/online-data-mixing/src/fms_acceleration_odm/__init__.py new file mode 100644 index 00000000..8d6e919c --- /dev/null +++ b/plugins/online-data-mixing/src/fms_acceleration_odm/__init__.py @@ -0,0 +1,19 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Local +from .framework_plugin_odm import OnlineDataMixingAccelerationPlugin +from .odm import OnlineMixingDataset, Reward, compute_reward +from .patch import patch_hf_trainer_evaluate diff --git a/plugins/online-data-mixing/src/fms_acceleration_odm/framework_plugin_odm.py b/plugins/online-data-mixing/src/fms_acceleration_odm/framework_plugin_odm.py new file mode 100644 index 00000000..9f1296a5 --- /dev/null +++ b/plugins/online-data-mixing/src/fms_acceleration_odm/framework_plugin_odm.py @@ -0,0 +1,77 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from typing import Dict, Tuple + +# Third Party +from fms_acceleration import AccelerationPlugin +from peft import LoraConfig +from transformers import TrainingArguments +import torch + +# Local +from .patch import patch_hf_trainer_evaluate + + +# pylint: disable=too-many-instance-attributes +class OnlineDataMixingAccelerationPlugin(AccelerationPlugin): + + def __init__(self, configurations: Dict[str, Dict]): + super().__init__(configurations) + + self._update_interval = self._check_config_and_maybe_check_values( + key="training.odm.odm.update_interval", + default=1, + ) + + # data_config file should be there + @property + def requires_augmentation(self): + return True + + def augmentation( + self, + model, + train_args: TrainingArguments, + modifiable_args: Tuple[LoraConfig], + ): + # original user intended eval steps is preserved in the model object + # while we overwrite the training args eval_steps and strategy to 1 and steps + # since that way eval pipeline is always triggered and is patched for controlled + # usage for ODM dataloader update action + model.ta_eval_steps = train_args.eval_steps + train_args.eval_steps = 1 + train_args.eval_strategy = "steps" + + # update_interval information has to be made available in the evaluate HF patch + # function and this seems to be the only reasonable way to do so + model.ta_update_interval = self._update_interval + return model, modifiable_args + + def get_callbacks_and_ready_for_train( + self, model: torch.nn.Module = None, accelerator=None + ): + callbacks = [] + patch_hf_trainer_evaluate() + return callbacks + + +# register +AccelerationPlugin.register_plugin( + OnlineDataMixingAccelerationPlugin, + configuration_and_paths=[ + "training.odm.odm", + ], +) diff --git a/plugins/online-data-mixing/src/fms_acceleration_odm/odm/__init__.py b/plugins/online-data-mixing/src/fms_acceleration_odm/odm/__init__.py new file mode 100644 index 00000000..a66ebfe2 --- /dev/null +++ b/plugins/online-data-mixing/src/fms_acceleration_odm/odm/__init__.py @@ -0,0 +1,18 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Local +from .dataset import OnlineMixingDataset +from .reward import Reward, compute_reward diff --git a/plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py b/plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py new file mode 100644 index 00000000..5d01f2bc --- /dev/null +++ b/plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py @@ -0,0 +1,411 @@ +# Standard +from logging import getLogger +from typing import List, Optional +import json +import math +import os +import random + +# Third Party +from datasets import DatasetDict +from torch.utils.data import DataLoader, IterableDataset +from tqdm import tqdm +import torch + +# Local +from .reward import Reward, compute_reward + +logger = getLogger(__name__) + + +# pylint: disable=too-many-instance-attributes +class OnlineMixingDataset(IterableDataset): + def __init__( + self, + dataset_dict: DatasetDict, + collators_dict: dict, + eval_dataset_dict: DatasetDict, + eval_collators_dict: dict, + sampling_weights: Optional[List[float]] = None, + gamma: float = 0.1, + eta: float = 0.3, + sampling_interval: int = 1, + eval_batch_size: int = 5, + output_dir="odm", + reward_type=Reward.ENTROPY, + ): + """Mixes datasets with sampling ratios learnt using + Multi Armed Bandit (MAB) EXP3 and rewards defined. + Rewards are defined in the compute_reward() function. + + NOTE: In distributed setting, this dataset should be used to + sample on the main process and distribute respective batches + to other worker processes. + + Args: + dataset_dict (DatasetDict): keys are category names and values are HF datasets. + collators_dict (dict): collator corresponding to each dataset + used while constructing torch dataloader. + eval_dataset_dict (DatasetDict): keys are category names and values are HF + eval datasets. + eval_collators_dict (dict): collator corresponding to each dataset + used while constructing torch dataloader. + sampling_weights (Optional[List[float]], optional): Initial + set of sampling weights to start with. Defaults to equal weightage. + gamma (float, optional): MAB hyperparameter. Defaults to 0.1. + eta (float, optional): MAB hyperparameter. Defaults to 0.3. + sampling_interval (int, optional): sample category at every n samples. + Defaults to every sample. + eval_batch_size (int, optional): eval batch size. Defaults to 5. + output_dir (str, optional): output dir to store logs. Defaults to "odm". + reward_type (_type_, optional): type of reward to use, more details can + be found in compute_reward function. Defaults to Reward.ENTROPY. + """ + logger.info( + """Values set to OnlineMixingDataset + dataset_dict: {dataset_dict} + collators_dict: {collators_dict} + eval_dataset_dict: {eval_dataset_dict} + eval_collators_dict:{eval_collators_dict} + sampling_weights: {sampling_weights} + gamma: {gamma} + eta: {eta} + sampling_interval: {sampling_interval} + eval_batch_size: {eval_batch_size} + output_dir: {output_dir} + reward_type: {reward_type} + """.format( + dataset_dict=dataset_dict, + collators_dict=collators_dict, + eval_dataset_dict=eval_dataset_dict, + eval_collators_dict=eval_collators_dict, + sampling_weights=sampling_weights, + gamma=gamma, + eta=eta, + sampling_interval=sampling_interval, + eval_batch_size=eval_batch_size, + output_dir=output_dir, + reward_type=reward_type, + ) + ) + + # gamma and eta are MAB hyper-parameters + self.gamma = gamma + self.eta = eta + self.sampling_interval = sampling_interval + self.collators_dict = collators_dict + self.eval_collators_dict = eval_collators_dict + self.eval_dataset_dict = eval_dataset_dict + self.eval_dataset_dict_dl = {} + self.train_dataset_dict_dl = {} + # prepare torch dataloaders for each of the dataset. + for k, _ in dataset_dict.items(): + dataset_dict[k] = DataLoader( + dataset_dict[k], + 1, + shuffle=False, + num_workers=1, + collate_fn=collators_dict[k] if collators_dict else None, + ) + self.train_dataset_dict_dl[k] = iter(dataset_dict[k]) + self.eval_batch_size = eval_batch_size + self.dataset_dict = dataset_dict + self.category_list = sorted(self.train_dataset_dict_dl.keys()) + self.id2cat = dict(enumerate(self.category_list)) + self.cat2id = {c: i for i, c in enumerate(self.category_list)} + self.total_categories = len(self.category_list) + + # If not starting weights given, then all arms (categories) + # are equally important. Weights based on the size of the datasets + # and other such heuristics should be computed outside and passed + # through sampling_weights while initializing this class. + if sampling_weights is None: + sampling_weights = [1] * self.total_categories + self.sampling_weights = torch.tensor(sampling_weights, dtype=torch.float64) + self.sampling_ratio = [] + self._update_sampling_ratio(self.sampling_weights) + + # curr_cat_count is current sample count per category + self.curr_cat_count = [0] * self.total_categories + + # produced is total samples sampled so far + self.produced = 0 + + # currently active category (arm) + self.arm_idx = 0 + + # should be one of Reward + self.reward_type = reward_type + if isinstance(self.reward_type, str): + self.reward_type = self.reward_type.upper() + self.reward_type = Reward[self.reward_type] + self.output_dir = output_dir + if not os.path.exists(self.output_dir): + os.makedirs(self.output_dir) + self.log_file_path = os.path.join(self.output_dir, "odm.jsonl") + logger.info( + "Logs for online data mixing to be stored at {log_file_path}".format( + log_file_path=self.log_file_path + ) + ) + self.log = { + "samples_produced_so_far": 0, + "sampling_interval": self.sampling_interval, + "total_categories": self.total_categories, + "current_sampling_weights": self.sampling_weights.tolist(), + "current_sampling_ratio": self.sampling_ratio, + "arm_idx": self.arm_idx, + "category_level_counts_so_far": self.curr_cat_count, + "rewards": [0] * self.total_categories, + "count": 0, + "action": "", # one of sample or update + } + + def log_to_file(self, data: dict): + """helper function to log the state to the file + + Args: + data (dict): log state updates + """ + self.log.update(data) + with open(self.log_file_path, "a", encoding="utf-8") as f: + f.write(json.dumps(self.log) + "\n") + + def __iter__(self): + self.produced = 0 + return self + + def __next__(self): + if self.produced % self.sampling_interval == 0: + self.arm_idx = random.choices( + range(self.total_categories), weights=self.sampling_ratio, k=1 + )[0] + sample = None + try: + sample = next(self.train_dataset_dict_dl[self.id2cat[self.arm_idx]]) + except StopIteration: + logger.info( + "{id} dataset exhausted so the iterator is reset.".format( + id=self.id2cat[self.arm_idx] + ) + ) + self.train_dataset_dict_dl[self.id2cat[self.arm_idx]] = iter( + self.dataset_dict[self.id2cat[self.arm_idx]] + ) + sample = next(self.train_dataset_dict_dl[self.id2cat[self.arm_idx]]) + + self.curr_cat_count[self.arm_idx] += 1 + self.produced += 1 + + # dataloader returns a batch of 1 sample + # next should return single sample rather a batch + if isinstance(sample, torch.Tensor): + # (edge case) when no collators are passed + sample = { + "input_ids": sample[0], + "attention_mask": torch.ones_like(sample[0]), + "labels": sample[0], + } + else: + sample = { + "input_ids": sample["input_ids"][0], + "attention_mask": ( + sample["attention_mask"][0] + if "attention_mask" in sample + else torch.ones_like(sample["input_ids"][0]) + ), + "labels": ( + sample["labels"][0] + if "labels" in sample + else sample["input_ids"][0] + ), + } + + self.log_to_file( + { + "arm_idx": self.arm_idx, + "samples_produced_so_far": self.produced, + "category_level_counts_so_far": self.curr_cat_count, + "action": "sample", + } + ) + return sample + + def _reset_eval_dataloaders(self): + """Helper function to reset eval dataloaders since + they would be exhausted in the previous evaluation loop. + """ + self.eval_dataset_dict_dl = {} + for k, _ in self.eval_dataset_dict.items(): + # this can be improved with persistent workers and caching + # dataloaders and resetting them when needed. + self.eval_dataset_dict_dl[k] = ( + iter( + DataLoader( + self.eval_dataset_dict[k], + self.eval_batch_size, + shuffle=False, + num_workers=1, + collate_fn=( + self.eval_collators_dict[k] + if self.eval_collators_dict + else None + ), + ) + ) + if self.eval_dataset_dict[k] + else None + ) + + def _update_sampling_ratio(self, weights) -> list: + """Helper function to convert weights to ratio + + Args: + weights: sampling weights + + Returns: + list: sampling ratio + """ + w = weights + w_sum = w.sum() + K = len(w) + base = (1.0 - self.gamma) * (w / w_sum) + expl = self.gamma / K + self.sampling_ratio = (base + expl).tolist() + return self.sampling_ratio + + def _update_weights(self, count, rewards) -> list: + """Helper function to update MAB weights with rewards + + Args: + count: size of total number of categories with count of samples per category + rewards: same size of count with reward of samples per category + + Returns: + list: sampling ratio + """ + + for arm in range(self.total_categories): + avg_r = rewards[arm] / count[arm] + est_r = avg_r / self.sampling_ratio[arm] + self.sampling_weights[arm] *= math.exp( + self.eta * est_r / self.total_categories + ) + return self._update_sampling_ratio(self.sampling_weights) + + def _extract_information_from_state_for_reward(self, state=None, category=None): + """Helper function to extract exact information that the reward computation + can consume. This function has to be expanded for new rewards. + + Args: + state: HF TrainerState object. Defaults to None. + + Returns: + dict: arguments prepared for compute_reward function + """ + if state is None: + return {} + if self.reward_type.startswith(Reward.ENTROPY): + return {} + if self.reward_type == Reward.TRAIN_LOSS: + return {"train_loss_history": [d for d in state.log_history if "loss" in d]} + if self.reward_type == Reward.VALIDATION_LOSS: + assert category is not None + return { + "eval_loss_history": [ + {"loss": d[f"eval_{category}_loss"], **d} + for d in state.log_history + if f"eval_{category}_loss" in d + ] + } + if self.reward_type == Reward.GRADNORM: + return { + "gradnorm_history": [d for d in state.log_history if "grad_norm" in d] + } + return {} + + def update_sampling_weights(self, model, accelerator, state): + """Function to update MAB weights based on the reward type provided + during the initialization. This function has to be updated if adding + new reward types and based on their information needs from training loop. + + Args: + model: HF model object. Conversion of the model (train to inference mode) + is NOT the responsibility of this function. + accelerator: Accelerate object, used for distributed operations. + Should be None of single GPU runs. + TODO: There is a hard dependency on accelerator which would be relaxed + in future versions. + state: HF TrainerState object (other formats will be supported in the future). + For custom loop, please prepare your state class following TrainerState class. + """ + rewards = [0] * self.total_categories + count = [0] * self.total_categories + eval_dataset_dict = {} + device = accelerator.device if accelerator else torch.device(0) + self._reset_eval_dataloaders() + for c in range(self.total_categories): + # accelerator takes care of preparing the eval dataloaders for distributed inference. + if accelerator: + eval_dataset_dict[self.id2cat[c]] = ( + accelerator.prepare(self.eval_dataset_dict_dl[self.id2cat[c]]) + if self.eval_dataset_dict_dl.get(self.id2cat[c], None) + else None + ) + else: + eval_dataset_dict[self.id2cat[c]] = self.eval_dataset_dict_dl.get( + self.id2cat[c], None + ) + for c in tqdm( + range(self.total_categories), total=self.total_categories, desc="Categories" + ): # for trian loss you dont need to iterate over eval dataset. + if not eval_dataset_dict[self.id2cat[c]]: + rc = compute_reward( + model=model, + batch=None, + vocab_size=32000, + reward_type=self.reward_type, + current_category=c, + total_categories=self.total_categories, + last_sampled_category=self.arm_idx, + **self._extract_information_from_state_for_reward( + state, self.id2cat[c] + ), + ) + rewards[c] += rc + count[c] += 1 + else: + for batch in tqdm( + eval_dataset_dict[self.id2cat[c]], + desc="Reward computation over eval dataset", + ): + rc = compute_reward( + model=model, + batch={k: v.to(device) for k, v in batch.items()}, + vocab_size=32000, + reward_type=self.reward_type, + current_category=c, + total_categories=self.total_categories, + last_sampled_category=self.arm_idx, + **self._extract_information_from_state_for_reward( + state, self.id2cat[c] + ), + ) + rewards[c] += rc + count[c] += batch["input_ids"].shape[0] + rewards = torch.tensor(rewards, device=device) + count = torch.tensor(count, device=device) + if accelerator: + rewards = accelerator.reduce(rewards, reduction="sum") + count = accelerator.reduce(count, reduction="sum") + if accelerator.is_main_process: + self._update_weights(count, rewards) + self.log_to_file( + { + "current_sampling_weights": self.sampling_weights.tolist(), + "current_sampling_ratio": self.sampling_ratio, + "rewards": rewards.tolist(), + "count": count.tolist(), + "action": "update", + } + ) diff --git a/plugins/online-data-mixing/src/fms_acceleration_odm/odm/reward.py b/plugins/online-data-mixing/src/fms_acceleration_odm/odm/reward.py new file mode 100644 index 00000000..3404dd5d --- /dev/null +++ b/plugins/online-data-mixing/src/fms_acceleration_odm/odm/reward.py @@ -0,0 +1,148 @@ +# samples entropy + +# Standard +from enum import StrEnum, auto +from typing import Dict + +# Third Party +from transformers import PreTrainedModel +import torch +import torch.nn.functional as F + + +class Reward(StrEnum): + ENTROPY = auto() + ENTROPY3_VARENT1 = auto() + ENTROPY_LAST_TOKEN = auto() + TRAIN_LOSS = auto() + VALIDATION_LOSS = auto() + GRADNORM = auto() + + +# TODO: +# 1. class implementation to store individual reward states +# 2. sub arguments for rewards for user to pass +# 3. use reward state and expected inputs to control +# flow of information (iteration over eval batch vs not doing so) + +TRAIN_LOSS_DATA = {"buffer": []} + +EVAL_LOSS_DATA = {"buffer": []} + +GRADNORM_DATA = {"buffer": []} + + +def compute_reward( + model: PreTrainedModel, + batch: Dict[str, torch.Tensor], + vocab_size: int, + reward_type: Reward, + train_loss_history=None, + eval_loss_history=None, + gradnorm_history=None, + last_sampled_category=None, + total_categories=None, + current_category=None, +) -> float: + """ + Compute rewards based on the provided reward_type. + You should be extending this function for new rewards. + + Supported rewards: + + Entropy related rewards: ENTROPY, ENTROPY3_VARENT1 & ENTROPY_LAST_TOKEN + Calculates the entropy and variance of entropy of every sequence in the batch. + For every sequence, + 1. The token level metrics are computed + 2. The metrics are averaged per sequence after applying the attention mask + + Train loss reward: TRAIN_LOSS + We maintain a buffer over all the categories capturing their train loss when sampled. + Higher train loss should reward more to choose from that category to bring this loss down. + + Validation loss reward: VALIDATION_LOSS + Similar to TRAIN_LOSS reward here we use individual category validation loss instead. + + Grad norm reward: GRADNORM + Similar to TRAIN_LOSS reward here we use overall gradnorm. However, Higher grad norm + categories should be less priortized. + + Args: + model (PreTrainedModel): HF Model object + batch (torch.Tensor): Batch of samples (input_ids, labels, attention_mask) + vocab_size (int): Maximum vocab size of the model used by ENTROPY rewards + reward_type (Reward): Type of the reward + train_loss_history: list of dicts each holding information on the training loss + eval_loss_history: list of dicts each holding information on the eval loss + gradnorm_history: list of dicts each holding information on the grad_norm + last_sampled_category: index of the last sampled category + total_categories: total number of categories + current_category: currently being reward computed category + Returns: + float + """ + if reward_type.startswith(Reward.ENTROPY): + with torch.inference_mode(): + outputs = model(**batch) + shift_logits = outputs.logits[:, :-1, :] + + log_probs = F.log_softmax(shift_logits, dim=-1) + probs = torch.exp(log_probs) + + entropy = -torch.sum(probs * log_probs, dim=-1) + sum_p_log_sq = torch.sum(probs * (log_probs**2), dim=-1) + varentropy = sum_p_log_sq - (entropy**2) + + entropy_last_token = entropy[:, -1] + + mask = batch["attention_mask"][:, 1:] + + entropy = (entropy * mask).sum(dim=-1) / mask.sum(dim=-1) + varentropy = (varentropy * mask).sum(dim=-1) / mask.sum(dim=-1) + + max_entropy = torch.log( + torch.tensor(vocab_size, dtype=entropy.dtype, device=entropy.device) + ) + + entropy = (entropy / max_entropy).clamp(0.0, 1.0) + varentropy = (varentropy / max_entropy**2).clamp(0.0, 1.0) + entropy_last_token = (entropy_last_token / max_entropy).clamp(0.0, 1.0) + if reward_type == Reward.ENTROPY: + return entropy.sum().item() + if reward_type == Reward.ENTROPY3_VARENT1: + return 0.75 * entropy.sum().item() + 0.25 * varentropy.sum().item() + if reward_type == Reward.ENTROPY_LAST_TOKEN: + return entropy_last_token.sum().item() + if reward_type == Reward.TRAIN_LOSS: + if not train_loss_history: + raise ValueError("train_loss_history cannot be a empty list or None") + if not TRAIN_LOSS_DATA["buffer"]: + TRAIN_LOSS_DATA["buffer"] = [1e-100] * total_categories + TRAIN_LOSS_DATA["buffer"][last_sampled_category] = train_loss_history[-1][ + "loss" + ] + return TRAIN_LOSS_DATA["buffer"][current_category] + if reward_type == Reward.VALIDATION_LOSS: + if not eval_loss_history: + raise ValueError( + "eval_loss_history cannot be a empty list or None." + "Make sure you are using eval_strategy and eval_steps" + "allowing atleast 1 evaluation before reward computation." + ) + if not EVAL_LOSS_DATA["buffer"]: + EVAL_LOSS_DATA["buffer"] = [1e-100] * total_categories + EVAL_LOSS_DATA["buffer"][current_category] = eval_loss_history[-1]["loss"] + return EVAL_LOSS_DATA["buffer"][current_category] + if reward_type == Reward.GRADNORM: + if not gradnorm_history: + raise ValueError( + "gradnorm_history cannot be a empty list or None." + "Make sure grad norm is made available." + ) + if not GRADNORM_DATA["buffer"]: + GRADNORM_DATA["buffer"] = [1e-100] * total_categories + GRADNORM_DATA["buffer"][last_sampled_category] = 1 / ( + gradnorm_history[-1]["grad_norm"] + 0.0001 + ) + return GRADNORM_DATA["buffer"][current_category] + raise TypeError(f"Reward {reward_type} not supported") diff --git a/plugins/online-data-mixing/src/fms_acceleration_odm/patch.py b/plugins/online-data-mixing/src/fms_acceleration_odm/patch.py new file mode 100644 index 00000000..7e52cec1 --- /dev/null +++ b/plugins/online-data-mixing/src/fms_acceleration_odm/patch.py @@ -0,0 +1,107 @@ +# fms-hf-tuning patch +# Standard +from logging import getLogger + +# Third Party +from transformers import Trainer + +logger = getLogger(__name__) + + +def patch_hf_trainer_evaluate(): + # Third Party + # pylint: disable=import-outside-toplevel + from fms_acceleration.model_patcher import patch_target_module + + Trainer._evaluate = _evaluate + patch_target_module("transformers.trainer.Trainer", Trainer) + + +def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False): + # Standard + # pylint: disable=import-outside-toplevel + import time + + # Third Party + # pylint: disable=import-outside-toplevel + import torch + + metrics = None + if ( + self.model.ta_eval_steps + and self.state.global_step % self.model.ta_eval_steps == 0 + ): + metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) + self._report_to_hp_search(trial, self.state.global_step, metrics) + + # Run delayed LR scheduler now that metrics are populated + if ( + isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) + and not skip_scheduler + ): + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + try: + self.lr_scheduler.step(metrics[metric_to_check]) + except KeyError as exc: + raise KeyError( + f"The `metric_for_best_model` training argument is " + f"set to '{metric_to_check}', " + f"which is not found in the evaluation metrics. " + f"The available evaluation metrics are: {list(metrics.keys())}." + f"Please ensure that the `compute_metrics` function returns a " + f"dictionary that includes '{metric_to_check}' or " + f"consider changing the `metric_for_best_model` via the TrainingArguments." + ) from exc + + if self.state.global_step % self.model.ta_update_interval == 0: + # prepare model + # code taken from def evaluation_loop from HF + model = self._wrap_model(self.model, training=False) + args = self.args + if len(self.accelerator._models) == 0 and model is self.model: + start_time = time.time() + model = ( + self.accelerator.prepare(model) + if self.is_deepspeed_enabled + or ( + self.is_fsdp_enabled + and self.accelerator.mixed_precision != "fp8" + and not self.args.torch_compile + ) + else self.accelerator.prepare_model(model, evaluation_mode=True) + ) + self.model_preparation_time = round(time.time() - start_time, 4) + + if self.is_fsdp_enabled: + self.model = model + + # for the rest of this function `model` is the outside model, + # whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # while ``train`` is running, cast it to the right dtype first and then put on device + if not self.is_in_train: + if args.fp16_full_eval: + model = model.to(dtype=torch.float16, device=args.device) + elif args.bf16_full_eval: + model = model.to(dtype=torch.bfloat16, device=args.device) + + if hasattr(model, "eval") and callable(model.eval): + model.eval() + if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval): + self.optimizer.eval() + # Do this before wrapping. + if args.past_index >= 0: + self._past = None + # prepare dataloader + self.train_dataset.update_sampling_weights(model, self.accelerator, self.state) + + return metrics diff --git a/plugins/online-data-mixing/tests/__init__.py b/plugins/online-data-mixing/tests/__init__.py new file mode 100644 index 00000000..38a9531e --- /dev/null +++ b/plugins/online-data-mixing/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/plugins/online-data-mixing/tests/test_compute_reward.py b/plugins/online-data-mixing/tests/test_compute_reward.py new file mode 100644 index 00000000..39d5dd2a --- /dev/null +++ b/plugins/online-data-mixing/tests/test_compute_reward.py @@ -0,0 +1,183 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Third Party +from transformers import AutoModelForCausalLM + +# pylint: disable=import-error +import pytest +import torch + +# First Party +from fms_acceleration_odm import Reward, compute_reward + +PARAMETERS = [ + ("Maykeye/TinyLLama-v0", "ENTROPY", None, None, None, 1, 8, 0, 3), + ("Maykeye/TinyLLama-v0", "ENTROPY3_VARENT1", None, None, None, 1, 8, 0, 2.5), + ("Maykeye/TinyLLama-v0", "ENTROPY_LAST_TOKEN", None, None, None, 1, 8, 0, 3), + ( + "Maykeye/TinyLLama-v0", + "TRAIN_LOSS", + [ + [{"loss": 3}], + [{"loss": 3}, {"loss": 4}], + [{"loss": 3}, {"loss": 4}, {"loss": 2}], + [{"loss": 3}, {"loss": 4}, {"loss": 2}, {"loss": 1}], + ], + None, + None, + [0, 0, 1, 1], + 2, + None, + [[3, 1e-100], [4, 1e-100], [4, 2], [4, 1]], + ), + ( + "Maykeye/TinyLLama-v0", + "VALIDATION_LOSS", + None, + [ + [[{"loss": 3}], [{"loss": 10}]], + [[{"loss": 3}], [{"loss": 5}]], + [[{"loss": 4}], [{"loss": 3}]], + [[{"loss": 2}], [{"loss": 3}]], + ], + None, + [0, 1, 1, 1], + 2, + None, + [[3, 10], [3, 5], [4, 3], [2, 3]], + ), + ( + "Maykeye/TinyLLama-v0", + "GRADNORM", + None, + None, + [ + [{"grad_norm": 3}], + [{"grad_norm": 3}, {"grad_norm": 4}], + [{"grad_norm": 3}, {"grad_norm": 4}, {"grad_norm": 2}], + [{"grad_norm": 3}, {"grad_norm": 4}, {"grad_norm": 2}, {"grad_norm": 1}], + ], + [0, 1, 1, 0], + 2, + None, + [ + [1 / (3 + 0.0001), 1e-100], + [1 / (3 + 0.0001), 1 / (4 + 0.0001)], + [1 / (3 + 0.0001), 1 / (2 + 0.0001)], + [1 / (1 + 0.0001), 1 / (2 + 0.0001)], + ], + ), +] + + +@pytest.mark.parametrize( + ( + "model,reward_type,train_loss_history,eval_loss_history,gradnorm_history," + "last_sampled_category,total_categories,current_category,reward" + ), + PARAMETERS, +) +def test_compute_reward( + model, + reward_type, + train_loss_history, + eval_loss_history, + gradnorm_history, + last_sampled_category, + total_categories, + current_category, + reward, +): + loaded_model = AutoModelForCausalLM.from_pretrained(model) + reward_type = Reward[reward_type] + batch_size = 3 + seq_length = 6 + vocab_size = 50 + input_ids = ( + torch.arange(batch_size * seq_length).reshape(batch_size, seq_length) + % vocab_size + ) + attention_mask = torch.tensor( + [[1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1]] + ) + labels = input_ids + batch = {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask} + if reward_type == Reward.ENTROPY: + reward = compute_reward( + model=loaded_model, + batch=batch, + vocab_size=vocab_size, + reward_type=reward_type, + current_category=current_category, + total_categories=total_categories, + last_sampled_category=last_sampled_category, + ) + assert reward == 3, f"entropy {reward} does not match the expected value." + if reward_type == Reward.ENTROPY3_VARENT1: + reward = compute_reward( + model=loaded_model, + batch=batch, + vocab_size=vocab_size, + reward_type=reward_type, + current_category=current_category, + total_categories=total_categories, + last_sampled_category=last_sampled_category, + ) + assert reward >= 2.5, f"entropy {reward} does not match the expected value." + if reward_type == Reward.ENTROPY_LAST_TOKEN: + reward = compute_reward( + model=loaded_model, + batch=batch, + vocab_size=vocab_size, + reward_type=reward_type, + current_category=current_category, + total_categories=total_categories, + last_sampled_category=last_sampled_category, + ) + assert reward == 3, f"entropy {reward} does not match the expected value." + if reward_type == Reward.TRAIN_LOSS: + for h, cc, r in zip( + train_loss_history, range(len(last_sampled_category)), reward + ): + returned_reward = [1e-100, 1e-100] + for c in range(total_categories): + returned_reward[c] = compute_reward( + model=loaded_model, + batch=batch, + vocab_size=vocab_size, + reward_type=reward_type, + current_category=c, + total_categories=total_categories, + last_sampled_category=last_sampled_category[cc], + train_loss_history=h, + ) + assert returned_reward == r, f"expected {r} but got {returned_reward}" + if reward_type == Reward.GRADNORM: + for h, cc, r in zip( + gradnorm_history, range(len(last_sampled_category)), reward + ): + returned_reward = [1e-100, 1e-100] + for c in range(total_categories): + returned_reward[c] = compute_reward( + model=loaded_model, + batch=batch, + vocab_size=vocab_size, + reward_type=reward_type, + current_category=c, + total_categories=total_categories, + last_sampled_category=last_sampled_category[cc], + gradnorm_history=h, + ) + assert returned_reward == r, f"expected {r} but got {returned_reward}" diff --git a/plugins/online-data-mixing/tests/test_odm_plugin.py b/plugins/online-data-mixing/tests/test_odm_plugin.py new file mode 100644 index 00000000..aabbb3bf --- /dev/null +++ b/plugins/online-data-mixing/tests/test_odm_plugin.py @@ -0,0 +1,34 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +import os + +# Third Party +from fms_acceleration.utils import instantiate_framework, read_configuration + +# First Party +from fms_acceleration_odm import OnlineDataMixingAccelerationPlugin + +# configuration +DIRNAME = os.path.dirname(__file__) +CONFIG_PATH_SCATTERMOE = os.path.join(DIRNAME, "../configs/odm.yaml") + + +def test_framework_installs_odm_plugin(): + with instantiate_framework( + read_configuration(CONFIG_PATH_SCATTERMOE), require_packages_check=False + ) as framework: + for plugin in framework.active_plugins: + assert isinstance(plugin[1], OnlineDataMixingAccelerationPlugin) diff --git a/plugins/online-data-mixing/tests/test_online_data.py b/plugins/online-data-mixing/tests/test_online_data.py new file mode 100644 index 00000000..04589d1e --- /dev/null +++ b/plugins/online-data-mixing/tests/test_online_data.py @@ -0,0 +1,84 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Third Party +# pylint: disable=import-error +import pytest +import torch + +# First Party +from fms_acceleration_odm import OnlineMixingDataset, Reward + +PARAMETERS = [ + ( + [1, 100, 2], + [[1, 100, 1], [1, 200, 1], [1, 100, 1], [1, 1, 1000], [1, 1, 2000]], + 5, + [1, 1, 1, 2, 2], + 3, + ) +] + + +@pytest.mark.parametrize( + "sampling_weights,rewards,batch_size,expected_arm_idx,total_categories", + PARAMETERS, +) +def test_online_data_mix_learning( + sampling_weights, rewards, batch_size, expected_arm_idx, total_categories +): + batch_size = 100 + seq_length = 6 + vocab_size = 50 + input_ids = ( + torch.arange(batch_size * seq_length).reshape(batch_size, seq_length) + % vocab_size + ) + attention_mask = torch.tensor( + [[1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1]] + ) + labels = input_ids + train_data = { + "input_ids": input_ids, + "labels": labels, + "attention_mask": attention_mask, + } + eval_data = { + "input_ids": input_ids, + "labels": labels, + "attention_mask": attention_mask, + } + dataset = OnlineMixingDataset( + train_data, + None, + eval_data, + None, + sampling_weights, + 0.1, + 0.3, + 1, + batch_size, + output_dir="odm", + reward_type=Reward.ENTROPY, + ) + categories_chosen = [] + for reward in rewards: + dataset._update_weights([batch_size] * total_categories, reward) + next(dataset) + categories_chosen.append(dataset.arm_idx) + # we check if atleast half of the choices match since this is probabilistic + # and may fail unit tests randomly + assert sum(x == y for x, y in zip(categories_chosen, expected_arm_idx)) >= ( + len(expected_arm_idx) / 2 + ), "Not even half of the choices were correct" diff --git a/plugins/online-data-mixing/tox.ini b/plugins/online-data-mixing/tox.ini new file mode 100644 index 00000000..1a21a899 --- /dev/null +++ b/plugins/online-data-mixing/tox.ini @@ -0,0 +1,50 @@ +[tox] +envlist = py, lint + +[testenv] +deps = + pytest>=7 + importlib-metadata + -e {toxinidir} +skip_install = true +commands = + + # install the dependencies here to ensure + # the order + pip install -e {toxinidir}/../framework + pytest {posargs:tests} + +[testenv:lint] +description = run linters +skip_install = false +deps = + -e {toxinidir}/../framework + pylint>=2.16.2,<=3.1.0 + datasets() +commands = + pylint src tests +allowlist_externals = pylint + +[testenv:fmt] +description = format +skip_install = true +deps = + black>=22.12 + isort>=5.11 +commands = + black {posargs:.} + isort {posargs:.} + +[testenv:build] +description = build wheel +deps = + build +commands = python -m build -w +skip_install = True + +[testenv:twinecheck] +description = check wheel +deps = + twine +commands = twine check dist/* +skip_install = True \ No newline at end of file