diff --git a/.github/workflows/build-and-publish.yml b/.github/workflows/build-and-publish.yml
index 307ade0e..9592fcfb 100644
--- a/.github/workflows/build-and-publish.yml
+++ b/.github/workflows/build-and-publish.yml
@@ -15,6 +15,7 @@ jobs:
- "accelerated-peft"
- "fused-ops-and-kernels"
- "attention-and-distributed-packing"
+ - "accelerated-moe"
permissions:
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
index 90f7210a..441a84cd 100644
--- a/.github/workflows/format.yml
+++ b/.github/workflows/format.yml
@@ -30,6 +30,7 @@ jobs:
- "accelerated-peft"
- "fused-ops-and-kernels"
- "attention-and-distributed-packing"
+ - "accelerated-moe"
steps:
- uses: actions/checkout@v4
diff --git a/README.md b/README.md
index 1158550c..8bc4b974 100644
--- a/README.md
+++ b/README.md
@@ -34,7 +34,7 @@ Plugin | Description | Depends | License | Status
[accelerated-peft](./plugins/accelerated-peft/README.md) | For PEFT-training, e.g., 4bit QLoRA. | Huggingface<br>AutoGPTQ | Apache 2.0<br>MIT | Alpha
[fused-op-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 [(contains extracted code)](./plugins/fused-ops-and-kernels/README.md#code-extracted-from-unsloth)| Beta
[attention-and-distributed-packing](./plugins/attention-and-distributed-packing/README.md) | Padding-Free Flash Attention Computation | flash-attn | Apache 2.0 | Beta
- MOE-training-acceleration | [MegaBlocks](https://github.com/databricks/megablocks) inspired triton Kernels and acclerations for Mixture-of-Expert models | | Apache 2.0 | Coming Soon
+[accelerated-moe](./plugins/accelerated-moe/README.md) | [MegaBlocks](https://github.com/databricks/megablocks) inspired Triton kernels and accelerations for Mixture-of-Experts models | | Apache 2.0 | Beta
## Usage with FMS HF Tuning
diff --git a/plugins/accelerated-moe/.isort.cfg b/plugins/accelerated-moe/.isort.cfg
new file mode 100644
index 00000000..7d3762ec
--- /dev/null
+++ b/plugins/accelerated-moe/.isort.cfg
@@ -0,0 +1,10 @@
+[settings]
+profile=black
+from_first=true
+import_heading_future=Future
+import_heading_stdlib=Standard
+import_heading_thirdparty=Third Party
+import_heading_firstparty=First Party
+import_heading_localfolder=Local
+known_firstparty=
+known_localfolder=tuning
\ No newline at end of file
diff --git a/plugins/accelerated-moe/.pylintrc b/plugins/accelerated-moe/.pylintrc
new file mode 100644
index 00000000..14a7a572
--- /dev/null
+++ b/plugins/accelerated-moe/.pylintrc
@@ -0,0 +1,649 @@
+[MAIN]
+
+# Analyse import fallback blocks. This can be used to support both Python 2 and
+# 3 compatible code, which means that the block might have code that exists
+# only in one or another interpreter, leading to false positives when analysed.
+analyse-fallback-blocks=no
+
+# Clear in-memory caches upon conclusion of linting. Useful if running pylint
+# in a server-like mode.
+clear-cache-post-run=no
+
+# Load and enable all available extensions. Use --list-extensions to see a list
+# all available extensions.
+#enable-all-extensions=
+
+# In error mode, messages with a category besides ERROR or FATAL are
+# suppressed, and no reports are done by default. Error mode is compatible with
+# disabling specific errors.
+#errors-only=
+
+# Always return a 0 (non-error) status code, even if lint errors are found.
+# This is primarily useful in continuous integration scripts.
+#exit-zero=
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loading into the active Python interpreter and may
+# run arbitrary code.
+extension-pkg-allow-list=
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loading into the active Python interpreter and may
+# run arbitrary code. (This is an alternative name to extension-pkg-allow-list
+# for backward compatibility.)
+extension-pkg-whitelist=
+
+# Return non-zero exit code if any of these messages/categories are detected,
+# even if score is above --fail-under value. Syntax same as enable. Messages
+# specified are enabled, while categories only check already-enabled messages.
+fail-on=
+
+# Specify a score threshold under which the program will exit with error.
+fail-under=10
+
+# Interpret the stdin as a python script, whose filename needs to be passed as
+# the module_or_package argument.
+#from-stdin=
+
+# Files or directories to be skipped. They should be base names, not paths.
+ignore=CVS,protobufs
+
+# Add files or directories matching the regular expressions patterns to the
+# ignore-list. The regex matches against paths and can be in Posix or Windows
+# format. Because '\\' represents the directory delimiter on Windows systems,
+# it can't be used as an escape character.
+# ignore-paths=
+
+# Files or directories matching the regular expression patterns are skipped.
+# The regex matches against base names, not paths. The default value ignores
+# Emacs file locks
+ignore-patterns=^\.#
+
+# List of module names for which member attributes should not be checked
+# (useful for modules/projects where namespaces are manipulated during runtime
+# and thus existing member attributes cannot be deduced by static analysis). It
+# supports qualified module names, as well as Unix pattern matching.
+ignored-modules=
+
+# Python code to execute, usually for sys.path manipulation such as
+# pygtk.require().
+#init-hook=
+
+# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
+# number of processors available to use, and will cap the count on Windows to
+# avoid hangs.
+jobs=1
+
+# Control the amount of potential inferred values when inferring a single
+# object. This can help the performance when dealing with large functions or
+# complex, nested conditions.
+limit-inference-results=100
+
+# List of plugins (as comma separated values of python module names) to load,
+# usually to register additional checkers.
+load-plugins=
+
+# Pickle collected data for later comparisons.
+persistent=yes
+
+# Minimum Python version to use for version dependent checks. Will default to
+# the version used to run pylint.
+py-version=3.9
+
+# Discover python modules and packages in the file system subtree.
+recursive=no
+
+# When enabled, pylint would attempt to guess common misconfiguration and emit
+# user-friendly hints instead of false-positive error messages.
+suggestion-mode=yes
+
+# Allow loading of arbitrary C extensions. Extensions are imported into the
+# active Python interpreter and may run arbitrary code.
+unsafe-load-any-extension=no
+
+# In verbose mode, extra non-checker-related info will be displayed.
+#verbose=
+
+
+[BASIC]
+
+# Naming style matching correct argument names.
+argument-naming-style=snake_case
+
+# Regular expression matching correct argument names. Overrides argument-
+# naming-style. If left empty, argument names will be checked with the set
+# naming style.
+#argument-rgx=
+
+# Naming style matching correct attribute names.
+attr-naming-style=snake_case
+
+# Regular expression matching correct attribute names. Overrides attr-naming-
+# style. If left empty, attribute names will be checked with the set naming
+# style.
+#attr-rgx=
+
+# Bad variable names which should always be refused, separated by a comma.
+bad-names=foo,
+ bar,
+ baz,
+ toto,
+ tutu,
+ tata
+
+# Bad variable names regexes, separated by a comma. If names match any regex,
+# they will always be refused
+bad-names-rgxs=
+
+# Naming style matching correct class attribute names.
+class-attribute-naming-style=any
+
+# Regular expression matching correct class attribute names. Overrides class-
+# attribute-naming-style. If left empty, class attribute names will be checked
+# with the set naming style.
+#class-attribute-rgx=
+
+# Naming style matching correct class constant names.
+class-const-naming-style=UPPER_CASE
+
+# Regular expression matching correct class constant names. Overrides class-
+# const-naming-style. If left empty, class constant names will be checked with
+# the set naming style.
+#class-const-rgx=
+
+# Naming style matching correct class names.
+class-naming-style=PascalCase
+
+# Regular expression matching correct class names. Overrides class-naming-
+# style. If left empty, class names will be checked with the set naming style.
+#class-rgx=
+
+# Naming style matching correct constant names.
+const-naming-style=UPPER_CASE
+
+# Regular expression matching correct constant names. Overrides const-naming-
+# style. If left empty, constant names will be checked with the set naming
+# style.
+#const-rgx=
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=-1
+
+# Naming style matching correct function names.
+function-naming-style=snake_case
+
+# Regular expression matching correct function names. Overrides function-
+# naming-style. If left empty, function names will be checked with the set
+# naming style.
+#function-rgx=
+
+# Good variable names which should always be accepted, separated by a comma.
+good-names=i,
+ j,
+ k,
+ ex,
+ Run,
+ _
+
+# Good variable names regexes, separated by a comma. If names match any regex,
+# they will always be accepted
+good-names-rgxs=
+
+# Include a hint for the correct naming format with invalid-name.
+include-naming-hint=no
+
+# Naming style matching correct inline iteration names.
+inlinevar-naming-style=any
+
+# Regular expression matching correct inline iteration names. Overrides
+# inlinevar-naming-style. If left empty, inline iteration names will be checked
+# with the set naming style.
+#inlinevar-rgx=
+
+# Naming style matching correct method names.
+method-naming-style=snake_case
+
+# Regular expression matching correct method names. Overrides method-naming-
+# style. If left empty, method names will be checked with the set naming style.
+#method-rgx=
+
+# Naming style matching correct module names.
+module-naming-style=snake_case
+
+# Regular expression matching correct module names. Overrides module-naming-
+# style. If left empty, module names will be checked with the set naming style.
+#module-rgx=
+
+# Colon-delimited sets of names that determine each other's naming style when
+# the name regexes allow several styles.
+name-group=
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=^_
+
+# List of decorators that produce properties, such as abc.abstractproperty. Add
+# to this list to register other decorators that produce valid properties.
+# These decorators are taken in consideration only for invalid-name.
+property-classes=abc.abstractproperty
+
+# Regular expression matching correct type variable names. If left empty, type
+# variable names will be checked with the set naming style.
+#typevar-rgx=
+
+# Naming style matching correct variable names.
+variable-naming-style=snake_case
+
+# Regular expression matching correct variable names. Overrides variable-
+# naming-style. If left empty, variable names will be checked with the set
+# naming style.
+#variable-rgx=
+
+
+[CLASSES]
+
+# Warn about protected attribute access inside special methods
+check-protected-access-in-special-methods=no
+
+# List of method names used to declare (i.e. assign) instance attributes.
+defining-attr-methods=__init__,
+ __new__,
+ setUp,
+ __post_init__
+
+# List of member names, which should be excluded from the protected access
+# warning.
+exclude-protected=_asdict,
+ _fields,
+ _replace,
+ _source,
+ _make
+
+# List of valid names for the first argument in a class method.
+valid-classmethod-first-arg=cls
+
+# List of valid names for the first argument in a metaclass class method.
+valid-metaclass-classmethod-first-arg=mcs
+
+
+[DESIGN]
+
+# List of regular expressions of class ancestor names to ignore when counting
+# public methods (see R0903)
+exclude-too-few-public-methods=
+
+# List of qualified class names to ignore when counting class parents (see
+# R0901)
+ignored-parents=
+
+# Maximum number of arguments for function / method.
+max-args=5
+
+# Maximum number of attributes for a class (see R0902).
+max-attributes=7
+
+# Maximum number of boolean expressions in an if statement (see R0916).
+max-bool-expr=5
+
+# Maximum number of branch for function / method body.
+max-branches=12
+
+# Maximum number of locals for function / method body.
+max-locals=15
+
+# Maximum number of parents for a class (see R0901).
+max-parents=7
+
+# Maximum number of public methods for a class (see R0904).
+max-public-methods=20
+
+# Maximum number of return / yield for function / method body.
+max-returns=6
+
+# Maximum number of statements in function / method body.
+max-statements=50
+
+# Minimum number of public methods for a class (see R0903).
+min-public-methods=2
+
+
+[EXCEPTIONS]
+
+# Exceptions that will emit a warning when caught.
+overgeneral-exceptions=builtins.BaseException,builtins.Exception
+
+
+[FORMAT]
+
+# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
+expected-line-ending-format=
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=^\s*(# )?<?https?://\S+>?$
+
+# Number of spaces of indent required inside a hanging or continued line.
+indent-after-paren=4
+
+# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
+# tab).
+indent-string=' '
+
+# Maximum number of characters on a single line.
+max-line-length=100
+
+# Maximum number of lines in a module.
+max-module-lines=1100
+
+# Allow the body of a class to be on the same line as the declaration if body
+# contains single statement.
+single-line-class-stmt=no
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=no
+
+
+[IMPORTS]
+
+# List of modules that can be imported at any level, not just the top level
+# one.
+allow-any-import-level=
+
+# Allow explicit reexports by alias from a package __init__.
+allow-reexport-from-package=no
+
+# Allow wildcard imports from modules that define __all__.
+allow-wildcard-with-all=no
+
+# Deprecated modules which should not be used, separated by a comma.
+deprecated-modules=
+
+# Output a graph (.gv or any supported image format) of external dependencies
+# to the given file (report RP0402 must not be disabled).
+ext-import-graph=
+
+# Output a graph (.gv or any supported image format) of all (i.e. internal and
+# external) dependencies to the given file (report RP0402 must not be
+# disabled).
+import-graph=
+
+# Output a graph (.gv or any supported image format) of internal dependencies
+# to the given file (report RP0402 must not be disabled).
+int-import-graph=
+
+# Force import order to recognize a module as part of the standard
+# compatibility libraries.
+known-standard-library=
+
+# Force import order to recognize a module as part of a third party library.
+known-third-party=enchant
+
+# Couples of modules and preferred modules, separated by a comma.
+preferred-modules=
+
+
+[LOGGING]
+
+# The type of string formatting that logging methods do. `old` means using %
+# formatting, `new` is for `{}` formatting.
+logging-format-style=old
+
+# Logging modules to check that the string format arguments are in logging
+# function parameter format.
+logging-modules=logging
+
+
+[MESSAGES CONTROL]
+
+# Only show warnings with the listed confidence levels. Leave empty to show
+# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE,
+# UNDEFINED.
+confidence=HIGH,
+ CONTROL_FLOW,
+ INFERENCE,
+ INFERENCE_FAILURE,
+ UNDEFINED
+
+# Disable the message, report, category or checker with the given id(s). You
+# can either give multiple identifiers separated by comma (,) or put this
+# option multiple times (only on the command line, not in the configuration
+# file where it should appear only once). You can also use "--disable=all" to
+# disable everything first and then re-enable specific checks. For example, if
+# you want to run only the similarities checker, you can use "--disable=all
+# --enable=similarities". If you want to run only the classes checker, but have
+# no Warning level messages displayed, use "--disable=all --enable=classes
+# --disable=W".
+disable=raw-checker-failed,
+ bad-inline-option,
+ locally-disabled,
+ file-ignored,
+ suppressed-message,
+ useless-suppression,
+ deprecated-pragma,
+ # Added messages
+ use-symbolic-message-instead,
+ invalid-name,
+ missing-class-docstring,
+ missing-module-docstring,
+ missing-function-docstring,
+ consider-using-f-string,
+ inconsistent-return-statements,
+ no-member,
+ too-many-arguments,
+ too-many-locals,
+ too-many-branches,
+ too-many-statements,
+ cyclic-import,
+ too-few-public-methods,
+ protected-access,
+ fixme,
+ logging-format-interpolation,
+ logging-too-many-args,
+ attribute-defined-outside-init,
+ abstract-method,
+ pointless-statement,
+ wrong-import-order,
+ duplicate-code,
+ unbalanced-tuple-unpacking,
+ unused-argument
+
+# Enable the message, report, category or checker with the given id(s). You can
+# either give multiple identifier separated by comma (,) or put this option
+# multiple time (only on the command line, not in the configuration file where
+# it should appear only once). See also the "--disable" option for examples.
+enable=c-extension-no-member
+
+
+[METHOD_ARGS]
+
+# List of qualified names (i.e., library.method) which require a timeout
+# parameter e.g. 'requests.api.get,requests.api.post'
+timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=FIXME,
+ XXX,
+ TODO
+
+# Regular expression of note tags to take in consideration.
+notes-rgx=
+
+
+[REFACTORING]
+
+# Maximum number of nested blocks for function / method body
+max-nested-blocks=5
+
+# Complete name of functions that never returns. When checking for
+# inconsistent-return-statements if a never returning function is called then
+# it will be considered as an explicit return statement and no message will be
+# printed.
+never-returning-functions=sys.exit,argparse.parse_error
+
+
+[REPORTS]
+
+# Python expression which should return a score less than or equal to 10. You
+# have access to the variables 'fatal', 'error', 'warning', 'refactor',
+# 'convention', and 'info' which contain the number of messages in each
+# category, as well as 'statement' which is the total number of statements
+# analyzed. This score is used by the global evaluation report (RP0004).
+evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))
+
+# Template used to display messages. This is a python new-style format string
+# used to format the message information. See doc for all details.
+msg-template=
+
+# Set the output format. Available formats are text, parseable, colorized, json
+# and msvs (visual studio). You can also give a reporter class, e.g.
+# mypackage.mymodule.MyReporterClass.
+output-format=text
+
+# Tells whether to display a full report or only the messages.
+reports=yes
+
+# Activate the evaluation score.
+score=yes
+
+
+[SIMILARITIES]
+
+# Comments are removed from the similarity computation
+ignore-comments=yes
+
+# Docstrings are removed from the similarity computation
+ignore-docstrings=yes
+
+# Imports are removed from the similarity computation
+ignore-imports=yes
+
+# Signatures are removed from the similarity computation
+ignore-signatures=yes
+
+# Minimum lines number of a similarity.
+min-similarity-lines=4
+
+
+[SPELLING]
+
+# Limits count of emitted suggestions for spelling mistakes.
+max-spelling-suggestions=4
+
+# Spelling dictionary name. Available dictionaries: none. To make it work,
+# install the 'python-enchant' package.
+spelling-dict=
+
+# List of comma separated words that should be considered directives if they
+# appear at the beginning of a comment and should not be checked.
+spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
+
+# List of comma separated words that should not be checked.
+spelling-ignore-words=
+
+# A path to a file that contains the private dictionary; one word per line.
+spelling-private-dict-file=
+
+# Tells whether to store unknown words to the private dictionary (see the
+# --spelling-private-dict-file option) instead of raising a message.
+spelling-store-unknown-words=no
+
+
+[STRING]
+
+# This flag controls whether inconsistent-quotes generates a warning when the
+# character used as a quote delimiter is used inconsistently within a module.
+check-quote-consistency=no
+
+# This flag controls whether the implicit-str-concat should generate a warning
+# on implicit string concatenation in sequences defined over several lines.
+check-str-concat-over-line-jumps=no
+
+
+[TYPECHECK]
+
+# List of decorators that produce context managers, such as
+# contextlib.contextmanager. Add to this list to register other decorators that
+# produce valid context managers.
+contextmanager-decorators=contextlib.contextmanager
+
+# List of members which are set dynamically and missed by pylint inference
+# system, and so shouldn't trigger E1101 when accessed. Python regular
+# expressions are accepted.
+generated-members=
+
+# Tells whether to warn about missing members when the owner of the attribute
+# is inferred to be None.
+ignore-none=yes
+
+# This flag controls whether pylint should warn about no-member and similar
+# checks whenever an opaque object is returned when inferring. The inference
+# can return multiple potential results while evaluating a Python object, but
+# some branches might not be evaluated, which results in partial inference. In
+# that case, it might be useful to still emit no-member and other checks for
+# the rest of the inferred objects.
+ignore-on-opaque-inference=yes
+
+# List of symbolic message names to ignore for Mixin members.
+ignored-checks-for-mixins=no-member,
+ not-async-context-manager,
+ not-context-manager,
+ attribute-defined-outside-init
+
+# List of class names for which member attributes should not be checked (useful
+# for classes with dynamically set attributes). This supports the use of
+# qualified names.
+ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
+
+# Show a hint with possible names when a member name was not found. The aspect
+# of finding the hint is based on edit distance.
+missing-member-hint=yes
+
+# The minimum edit distance a name should have in order to be considered a
+# similar match for a missing member name.
+missing-member-hint-distance=1
+
+# The total number of similar names that should be taken in consideration when
+# showing a hint for a missing member.
+missing-member-max-choices=1
+
+# Regex pattern to define which classes are considered mixins.
+mixin-class-rgx=.*[Mm]ixin
+
+# List of decorators that change the signature of a decorated function.
+signature-mutators=
+
+
+[VARIABLES]
+
+# List of additional names supposed to be defined in builtins. Remember that
+# you should avoid defining new builtins when possible.
+additional-builtins=
+
+# Tells whether unused global variables should be treated as a violation.
+allow-global-unused-variables=yes
+
+# List of names allowed to shadow builtins
+allowed-redefined-builtins=
+
+# List of strings which can identify a callback function by name. A callback
+# name must start or end with one of those strings.
+callbacks=cb_,
+ _cb
+
+# A regular expression matching the name of dummy variables (i.e. expected to
+# not be used).
+dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
+
+# Argument names that match this expression will be ignored.
+ignored-argument-names=_.*|^ignored_|^unused_
+
+# Tells whether we should check for unused import in __init__ files.
+init-import=no
+
+# List of qualified module names which can have objects that can redefine
+# builtins.
+redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
\ No newline at end of file
diff --git a/plugins/accelerated-moe/README.md b/plugins/accelerated-moe/README.md
new file mode 100644
index 00000000..3f938096
--- /dev/null
+++ b/plugins/accelerated-moe/README.md
@@ -0,0 +1,68 @@
+# FMS Acceleration for Mixture-of-Experts
+
+This library contains plugins to accelerate finetuning with the following optimizations:
+1. Expert-Parallel MoE with Megablocks
+
+## Plugins
+
+Plugin | Description | Depends | Loading | Augmentation | Callbacks
+--|--|--|--|--|--
+[megablocks](./src/fms_acceleration_moe/framework_plugin_megablocks.py) | MoE Expert Parallel with megablocks | megablocks | ✅ | | ✅
+
+
+## Running Benchmarks
+
+See the benchmark results in [a100_80gb_mb.csv](../../scripts/benchmarks/refs/a100_80gb_mb.csv)
+
+
+Run the following in the top-level directory of this repo:
+- the `megablocks` dependency is not included by default, so the `-x` switch installs it.
+
+```
+tox -e run-benches \
+ -x testenv:run-benches.deps+="-r plugins/accelerated-moe/requirements-mb.txt" \
+ -- \
+ 8 8 benchmark_outputs scenarios.yaml accelerated-moe-megablocks
+
+```
+
+NOTE: if a `FileNotFoundError` is observed on the *triton cache*, similar to issues like these:
+- https://github.com/triton-lang/triton/issues/2688
+
+then `tox` is somehow causing problems with triton and multiprocessing (there is some race condition).
+The workaround is to first *activate the tox env* and then
+run the script manually in `bash`:
+```
+# if FileNotFoundError in the triton cache is observed
+# - then activate the env and run the script manually
+
+source .tox/run-benches/bin/activate
+bash scripts/run_benchmarks.sh \
+ 8 8 benchmark_outputs scenarios.yaml accelerated-moe-megablocks
+```
+
+
+## Expert-Parallel MoE with Megablocks
+
+Currently supports *mixed precision*: when turned on, the router and the sharded experts will be upcast.
+- However, this is hard-coded to off at the moment.
+- The FSDP mixed precision setting works independently of the MoE one.
+
+Not all features of `megablocks` are incorporated; some restrictions of the current integration are listed below:
+- the data-parallel `dp_mesh` is currently not passed to the `FSDP` constructor, so `FSDP` will always shard over the default process group (i.e., over world_size).
+- only loading *sharded* `safetensors` (non-GGUF) MoE checkpoints is supported. This is a reasonable assumption since MoE checkpoints are typically above the size limit that prevents them from being saved into a single checkpoint file.
+- only the *dropless sparse* MLPs in the megablocks package are supported; other variations such as non-dropless and grouped computation are not currently integrated.
+- `shard_moe` may not scale well to larger models, as the current implementation `torch.concat`s all the expert weights together before passing them to `torch.distributed` to be sharded (see the sketch below). This is done redundantly on every device, so it is inefficient.
+- only `StateDictType.SHARDED_STATE_DICT` is currently supported, because the implementation uses `DTensor`s, which have limited support for full state dicts; sharded state dicts are also the most efficient option in any case.
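+
+As a reference for the last two points, below is a minimal sketch (not the plugin API; tensor sizes are hypothetical) of the concat-and-distribute pattern used to shard the experts as `DTensor`s, assuming `torch.distributed` has already been initialized:
+
+```
+# every rank concatenates the per-expert weights on the expert dimension (dim 0),
+# then distributes the full tensor as a DTensor sharded along that dimension
+import torch
+from torch.distributed._tensor import Shard, distribute_tensor
+from torch.distributed._tensor.device_mesh import init_device_mesh
+
+# hypothetical sizes: 8 experts, each with an (ffn_hidden_size, hidden_size) weight
+expert_weights = [torch.randn(1024, 4096) for _ in range(8)]
+
+# 1D device mesh over the world size (i.e., sharing the data-parallel dimension)
+mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))
+
+# the concat is done redundantly on every rank, hence the inefficiency noted above
+full_weight = torch.concat(expert_weights, dim=0)
+
+# Shard(0) splits the expert dimension across the mesh, so each rank keeps
+# num_experts / world_size experts locally
+sharded = distribute_tensor(full_weight, mesh, [Shard(0)])
+```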
+
+### Megablocks Dependencies
+
+Currently databricks megablocks has no PyPI repository and no proper release, so it has to be installed directly from GitHub; refer to the instructions below.
+- This has to be a manual install, as PyPI will complain if a Git dependency is included as an official plugin dependency.
+- Since this is not a binary install, note that the CUDA Toolkit is required to build some of the kernels used by megablocks.
+
+```
+# this will install the megablocks from Github
+# megablocks requires CUDA Toolkit to build.
+pip install -r requirements-mb.txt
+```
\ No newline at end of file
diff --git a/plugins/accelerated-moe/configs/megablocks.yaml b/plugins/accelerated-moe/configs/megablocks.yaml
new file mode 100644
index 00000000..cd63a21c
--- /dev/null
+++ b/plugins/accelerated-moe/configs/megablocks.yaml
@@ -0,0 +1,40 @@
+training:
+
+ # mixture-of-experts configurations
+ moe:
+
+ # expert-parallel for MoE
+ megablocks:
+
+ # The name of the mixture-of-experts class
+ moe_component_class: MixtralSparseMoeBlock
+
+ # The module name of the router in moe_component_class above
+ moe_gate_module_name: gate
+
+ # The module name of the experts in moe_component_class above
+ moe_experts_module_name: experts
+
+ # the mlp version
+ # - for those with only up and down projs, use "v1"
+ # - for those with only up, down and gate projs, use "v2"
+ moe_mlp_impl: v2
+
+ # if True, then we shard experts across data parallel dimension
+ # - only feasible if world_size divides the number of experts
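+ # - e.g., with 8 experts and world_size = 4, each device holds 8 / 4 = 2 experts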
+ shard_along_dp: true
+
+ # to be specified only if shard_along_dp == False. This will influence
+ # the level of sharding, which indicates how many experts per device
+ # - the number of experts per device will be num_experts / ep_size
+ # - we disable the ability to set ep_size=1 since this means no sharding
+ # - NOTE: ep_size=1 does not mean shard_along_dp=True, which would otherwise
+ # be contradictory since ep_size suggests no expert parallel.
+ # ep_size: 2
+
+ # the MoE dropless implementation. Currently we only support "dropless_sparse", but
+ # in the future we may support others
+ moe_implementation: dropless_sparse
+
+ # for load_balancing_loss
+ load_balancing_loss: false
diff --git a/plugins/accelerated-moe/pyproject.toml b/plugins/accelerated-moe/pyproject.toml
new file mode 100644
index 00000000..b100bc1e
--- /dev/null
+++ b/plugins/accelerated-moe/pyproject.toml
@@ -0,0 +1,29 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "fms-acceleration-moe"
+version = '0.0.1'
+description = "FMS Acceleration Plugin for Mixture-of-Experts"
+authors = [
+ {name = "Fabian Lim", email = "flim@sg.ibm.com"},
+]
+license = {text = "Apache-2.0"}
+readme = "README.md"
+requires-python = "~=3.9"
+keywords = ['fms-hf-tuning', 'acceleration', 'mixture-of-experts', 'megablocks']
+classifiers=[
+ "License :: OSI Approved :: Apache Software License",
+ "Development Status :: 4 - Beta",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+]
+
+[tool.hatch.build.targets.wheel]
+only-include = ["src/fms_acceleration_moe"]
+
+[tool.hatch.build.targets.wheel.sources]
+"src" = ""
diff --git a/plugins/accelerated-moe/requirements-mb.txt b/plugins/accelerated-moe/requirements-mb.txt
new file mode 100644
index 00000000..7875fc98
--- /dev/null
+++ b/plugins/accelerated-moe/requirements-mb.txt
@@ -0,0 +1,3 @@
+megablocks @ git+https://github.com/databricks/megablocks.git@bce5d7b2aaf5038bc93b36f76c2baf51c2939bd2
+
+# auto_gptq @ git+https://github.com/AutoGPTQ/AutoGPTQ.git@ea829c7bbe83561c2b1de26795b6592992373ef7
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/__init__.py
new file mode 100644
index 00000000..7a459ffc
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/__init__.py
@@ -0,0 +1,17 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Local
+from .framework_plugin_megablocks import MegablocksMoEAccelerationPlugin
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py
new file mode 100644
index 00000000..f9dcc60a
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py
@@ -0,0 +1,196 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from typing import Dict
+import warnings
+
+# Third Party
+from fms_acceleration import AccelerationPlugin
+from transformers import AutoConfig, AutoModelForCausalLM
+import torch
+
+
+# pylint: disable=too-many-instance-attributes
+class MegablocksMoEAccelerationPlugin(AccelerationPlugin):
+
+ require_packages = {"megablocks"}
+
+ def __init__(self, configurations: Dict[str, Dict]):
+ super().__init__(configurations)
+
+ # arguments for configuring the mixture-of-experts model with defaults
+ # shown below for Mixtral 7x8b
+ # - 1. component class
+ self._moe_component_cls = self._check_config_and_maybe_check_values(
+ key="training.moe.megablocks.moe_component_class",
+ default="MixtralSparseMoeBlock",
+ )
+ # - 2. gate_module_name
+ self._gate_module_name = self._check_config_and_maybe_check_values(
+ key="training.moe.megablocks.moe_gate_module_name", default="gate"
+ )
+ # - 3. experts_module_name
+ self._experts_module_name = self._check_config_and_maybe_check_values(
+ key="training.moe.megablocks.moe_experts_module_name", default="experts"
+ )
+ # - 4. mlp version
+ self._mlp_version = self._check_config_and_maybe_check_values(
+ key="training.moe.megablocks.moe_mlp_impl",
+ values=["v1", "v2"],
+ default="v2",
+ )
+
+ # for controlling the type of sharding
+ self._shard_along_dp = self._check_config_and_maybe_check_values(
+ key="training.moe.megablocks.shard_along_dp",
+ values=[True, False],
+ default=True,
+ )
+
+ # ep_size determines the expert parallel sharding
+ # - ep_size is ignored if _shard_along_dp=True
+ self._ep_size = None
+ if not self._shard_along_dp:
+ self._ep_size = self._check_config_and_maybe_check_values(
+ key="training.moe.megablocks.ep_size",
+ default=1,
+ )
+
+ # for the moe_implementation, currently we only use the megablocks
+ # dropless sparse implementation
+ self._moe_implementation = self._check_config_and_maybe_check_values(
+ key="training.moe.megablocks.moe_implementation",
+ values=["dropless_sparse"],
+ default="dropless_sparse",
+ )
+ self._moe_implementation = self._moe_implementation.split("_")[1]
+
+ self._load_balancing_loss = self._check_config_and_maybe_check_values(
+ key="training.moe.megablocks.load_balancing_loss",
+ values=[True, False],
+ default=False,
+ )
+
+ @property
+ def requires_custom_loading(self):
+ return True
+
+ def model_loader(self, model_name: str, **kwargs):
+ # guarded
+ # Local
+ # pylint: disable=import-outside-toplevel
+ from .megablocks_utils.config_utils import update_mlp_registry
+ from .megablocks_utils.shard_moe_utils import get_moe_kwargs, shard_moe
+
+ # - check the config
+ if self._load_balancing_loss and not hasattr(
+ AutoConfig.from_pretrained(model_name), "output_router_logits"
+ ):
+ warnings.warn(
+ "load_balancing_loss=True but "
+ "the model '{model_name}' config not have 'output_router_logits' "
+ "in its config, hence it might not support load balancing and "
+ "fallback to load_balancing_loss=False."
+ )
+ self._load_balancing_loss = False
+
+ # this one does a forward patching on MLP, but needs to be fixed
+ # properly as the load balancing loss is currently not properly
+ # handled
+ update_mlp_registry(
+ self._moe_implementation, self._mlp_version, self._load_balancing_loss
+ )
+
+ # get additional parameters
+ torch_dtype = kwargs.get("torch_dtype", torch.float32)
+
+ # load the model
+ model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
+
+ # set this in the config, which will be picked up by the forward
+ # function to go into the load_balancing loss
+ model.config.output_router_logits = self._load_balancing_loss
+
+ rank, world_size = 0, 1
+ if torch.distributed.is_initialized():
+ world_size = torch.distributed.get_world_size()
+ rank = torch.distributed.get_rank()
+ else:
+ # NOTE: or should we do a silent fallback
+ raise AssertionError(
+ "Megablocks expert parallel only works for distributed training."
+ )
+
+ # shard the MOE, and store products required for
+ # FSDP configuration
+ # pylint: disable=unused-variable
+ (dp_mesh, self._moe_component_module_names) = shard_moe(
+ model,
+ self._moe_component_cls,
+ checkpoint_name_or_path=model_name,
+ rank=rank,
+ world_size=world_size,
+ ep_size=self._ep_size,
+ moe_kwargs=get_moe_kwargs(
+ model.config,
+ fp16=torch_dtype == torch.float16,
+ bf16=torch_dtype == torch.bfloat16,
+ ),
+ shared_mesh_dim=self._shard_along_dp,
+ router_name=self._gate_module_name,
+ expert_name=self._experts_module_name,
+ mixed_precision=False, # Currently this is hardcoded to OFF
+ )
+
+ # NOTE: there is currently no good way to get the mixed precision
+ # flag from train_args. It will be better to handle this if/when
+ # we move the sharding to augmentation.
+
+ # NOTE: Currently, it is a bit troublesome to pass the device_mesh to
+ # the FSDP constructor, so we do not do that.
+ # - therefore FSDP will always shard on world_size over the default process
+ # group
+
+ return model
+
+ def get_callbacks_and_ready_for_train(
+ self, model: torch.nn.Module = None, accelerator=None
+ ):
+
+ callbacks = []
+ if (
+ accelerator is not None
+ and getattr(accelerator.state, "fsdp_plugin", None) is not None
+ ):
+ # - use an internal function call to get the no split
+ # module names, which are typically layers
+ _layers = model._get_no_split_modules("")
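+ # - the MoE modules were already sharded as DTensors in model_loader,
+ # so exclude them from FSDP wrapping by adding them to ignored_modules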
+ accelerator.state.fsdp_plugin.ignored_modules = [
+ getattr(layer, name)
+ for name in self._moe_component_module_names
+ for layer in model.modules()
+ if layer.__class__.__name__ in _layers
+ ]
+
+ return callbacks
+
+
+# register
+AccelerationPlugin.register_plugin(
+ MegablocksMoEAccelerationPlugin,
+ configuration_and_paths=[
+ "training.moe.megablocks",
+ ],
+)
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/__init__.py
new file mode 100644
index 00000000..38a9531e
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/__init__.py
@@ -0,0 +1,13 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/checkpoint_utils.py
new file mode 100644
index 00000000..d4363e7a
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/checkpoint_utils.py
@@ -0,0 +1,152 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+import os
+
+# Third Party
+from accelerate.logging import get_logger
+from accelerate.utils.constants import FSDP_MODEL_NAME, OPTIMIZER_NAME
+from torch.distributed.checkpoint.default_planner import (
+ DefaultLoadPlanner,
+ DefaultSavePlanner,
+)
+from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
+from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
+import torch.distributed.checkpoint as dcp
+
+logger = get_logger(__name__)
+
+# - variable to capture the model variable
+# in the save/load model calls
+MODEL_INDEX = None
+
+# Below are rewrite of functions for megablocks
+
+
+# rewrite of func from accelerate.utils.fsdp_utils.py
+# - empty function, as main logic is in the optimizer call
+# save_fsdp_optimizer (see below).
+def save_fsdp_model(
+ fsdp_plugin, accelerator, model, output_dir, model_index=0, adapter_only=False
+):
+ # pylint: disable=global-statement
+ global MODEL_INDEX
+ MODEL_INDEX = model_index
+
+
+# rewrite of func from accelerate.utils.fsdp_utils.py
+# - saves both model and optimizer
+def save_fsdp_optimizer(
+ fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0
+):
+
+ if fsdp_plugin.state_dict_type != StateDictType.SHARDED_STATE_DICT:
+ raise NotImplementedError(
+ "Checkpointing for megablocks only enabled for sharded state dict."
+ )
+
+ # get the state dicts for model and optimizer
+ (model_state_dict, optimizer_state_dict) = get_state_dict(model, optimizer)
+
+ # - save model
+ ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
+ os.makedirs(ckpt_model, exist_ok=True)
+ logger.info(f"Saving model to {ckpt_model}")
+ dcp.save(
+ state_dict={"model": model_state_dict},
+ storage_writer=dcp.FileSystemWriter(ckpt_model),
+ planner=DefaultSavePlanner(),
+ )
+ logger.info(f"Model saved to {ckpt_model}")
+
+ # - save optimizer
+ ckpt_opt = os.path.join(output_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
+ os.makedirs(ckpt_opt, exist_ok=True)
+ logger.info(f"Saving Optimizer state to {ckpt_opt}")
+ dcp.save(
+ state_dict={"optimizer": optimizer_state_dict},
+ storage_writer=dcp.FileSystemWriter(ckpt_opt),
+ planner=DefaultSavePlanner(),
+ )
+ logger.info(f"Optimizer state saved in {ckpt_opt}")
+
+
+# rewrite of func from accelerate.utils.fsdp_utils.py
+# - empty function, as main logic is in the optimizer call
+# load_fsdp_optimizer (see below).
+def load_fsdp_model(
+ fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False
+):
+ # pylint: disable=global-statement
+ global MODEL_INDEX
+ MODEL_INDEX = model_index
+
+
+# rewrite of func from accelerate.utils.fsdp_utils.py
+# - loads both model and optimizer
+def load_fsdp_optimizer(
+ fsdp_plugin,
+ accelerator,
+ optimizer,
+ model,
+ input_dir,
+ optimizer_index=0,
+ adapter_only=False,
+):
+
+ accelerator.wait_for_everyone()
+ if fsdp_plugin.state_dict_type != StateDictType.SHARDED_STATE_DICT:
+ raise NotImplementedError(
+ "Checkpointing for megablocks only enabled for sharded state dict."
+ )
+
+ # - get the state dicts
+ model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+
+ # - load the model state dict
+ ckpt_model = os.path.join(input_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
+ dcp.load(
+ state_dict={"model": model_state_dict},
+ storage_reader=dcp.FileSystemReader(ckpt_model),
+ planner=DefaultLoadPlanner(),
+ )
+
+ # - load the optimizer state dict
+ ckpt_opt = os.path.join(input_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
+ dcp.load(
+ state_dict={"optimizer": optimizer_state_dict},
+ storage_reader=dcp.FileSystemReader(ckpt_opt),
+ planner=DefaultLoadPlanner(),
+ )
+
+ # - set the state dicts
+ set_state_dict(
+ model,
+ optimizer,
+ model_state_dict=model_state_dict,
+ optim_state_dict=optimizer_state_dict,
+ )
+
+ # HACK for now
+ # - it seems that if params is empty, then the loading has some
+ # problems
+ # - so for now, we just fill in some default values
+ for group in optimizer.param_groups:
+ if len(group["params"]) == 0:
+ group["betas"] = (0.9, 0.999)
+ group["lr"] = 0.0
+ group["initial_lr"] = 0.0
+ group["eps"] = 1e-8
+ group["weight_decay"] = 0.0
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py
new file mode 100644
index 00000000..1234d383
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py
@@ -0,0 +1,133 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# utilities to update megablocks to register various things
+# e.g., the MLP_v2 that handles gate, up, down projections
+
+# Third Party
+import torch
+import torch.nn.functional as F
+
+
+# this function ensures that the megablocks package is configured to use
+# the correct SparseMLP implementation
+# - at the moment not considering the GroupedMLP implementations
+def update_mlp_registry(
+ mlp_type: str = "sparse",
+ mlp_version: str = "v2",
+ load_balancing_loss: bool = False,
+):
+ # guarded
+ # Third Party
+ # pylint: disable=import-error,import-outside-toplevel
+ from megablocks.layers.dmlp_registry import _REGISTRY
+ from megablocks.layers.mlp import SparseMLP, resolve_dtensor
+ from megablocks.layers.moe import ParallelMLP
+ from megablocks.layers.router import LearnedRouter, _uniform_expert_assignment
+
+ # Local
+ from .sparse_mlp2 import SparseMLPv2
+
+ SPARSE_MLP_IMPL = {
+ "v1": SparseMLP,
+ "v2": SparseMLPv2,
+ }
+
+ # replace the registry to point to the correct sparse implementation
+ if mlp_type == "sparse":
+ assert (
+ mlp_version in SPARSE_MLP_IMPL
+ ), f"Megablocks only support sparse mlp versions: {','.join(SPARSE_MLP_IMPL.keys())}"
+ _REGISTRY["mlp"]["sparse"] = SPARSE_MLP_IMPL[mlp_version]
+ else:
+ raise NotImplementedError("Currently only supports sparse MLP implementations.")
+
+ def forward(self, x, scores, expert_weights, top_experts):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, _ = self.forward_fn(x, expert_weights, top_experts)
+
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+
+ # in this case we should be returning the router
+ # logits out of the MoE forward.
+ if load_balancing_loss:
+ return x, torch.log(scores)
+
+ # otherwise just return None
+ return x, None
+
+ # replace the forward function. Willing to do this because ParallelMLP
+ # is only used here and not anywhere else, hence:
+ # 1. we do not care about reversing the patch
+ # 2. we have control on where this is called, and we know to call it
+ # before our code accesses this function. Hence, we view this as
+ # a hardcoded modification to the megablocks package more than a
+ # patch.
+ ParallelMLP.forward = forward
+
+ # for the router
+ # - need to resolve the dtensor since we had replicated the router
+ # weights
+ def forward_router(self, x):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ _weight = resolve_dtensor(self.layer.weight)
+ _bias = None if self.layer.bias is None else resolve_dtensor(self.layer.bias)
+ # pylint: disable=not-callable
+ scores = F.linear(x.view(-1, x.shape[-1]), _weight, _bias).softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ )
+ if self.args.uniform_expert_assignment
+ else expert_indices
+ )
+ return scores, expert_weights, expert_indices
+
+ # replace the forward function in the router
+ # - same as above
+ LearnedRouter.forward = forward_router
+
+ # Third Party
+ from fms_acceleration.model_patcher import patch_target_module
+
+ # Local
+ from .checkpoint_utils import (
+ load_fsdp_model,
+ load_fsdp_optimizer,
+ save_fsdp_model,
+ save_fsdp_optimizer,
+ )
+
+ patch_target_module("transformers.trainer.save_fsdp_model", save_fsdp_model)
+ patch_target_module("transformers.trainer.save_fsdp_optimizer", save_fsdp_optimizer)
+ patch_target_module("transformers.trainer.load_fsdp_model", load_fsdp_model)
+ patch_target_module("transformers.trainer.load_fsdp_optimizer", load_fsdp_optimizer)
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py
new file mode 100644
index 00000000..8ef95fa2
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py
@@ -0,0 +1,431 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from collections import defaultdict
+from contextlib import ExitStack
+from copy import copy
+from typing import Dict, List, Tuple, Type, Union
+import json
+import os
+import re
+import warnings
+
+# Third Party
+from accelerate import init_empty_weights
+from safetensors import safe_open
+from torch.distributed._tensor import Placement, Replicate, Shard, distribute_tensor
+from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh
+from tqdm import tqdm
+from transformers import PretrainedConfig
+from transformers.activations import ACT2FN
+import torch
+
+FILE_SAFETENSOR_INDEX = "model.safetensors.index.json"
+KEY_DATA_PARALLEL = "data_parallel"
+KEY_EXPERT_PARALLEL = "expert_parallel"
+DIM_EXPERT = 0
+
+# these depend on the namings in the dMOE
+KEY_DMOE_ROUTER = "router.layer.weight"
+KEY_DMOE_EXPERTS = "experts.mlp"
+
+
+def get_moe_kwargs(
+ config: PretrainedConfig,
+ fp16: bool = False,
+ bf16: bool = False,
+):
+ return {
+ "hidden_size": config.hidden_size,
+ "ffn_hidden_size": config.intermediate_size,
+ "moe_num_experts": config.num_local_experts,
+ "moe_top_k": config.num_experts_per_tok,
+ "moe_expert_model_parallelism": True,
+ "memory_optimized_mlp": False,
+ "activation_fn": ACT2FN[config.hidden_act],
+ "moe_normalize_expert_weights": True,
+ "return_bias": False,
+ "fp16": fp16,
+ "bf16": bf16,
+ }
+
+
+# trick to get the resolved cache file to access the safetensors
+# NOTE: this does not work if the config is not loaded via _dict_from_json_file
+# (e.g., GGUF files)
+def get_resolved_checkpoint_location(model_name_or_path: str):
+
+ result = None
+ _old_func = PretrainedConfig._dict_from_json_file
+
+ def _dict_from_json_file(resolved_config_file):
+ nonlocal result
+ result = resolved_config_file
+ return _old_func(resolved_config_file)
+
+ # install a hook and retrieve the resolved file location
+ PretrainedConfig._dict_from_json_file = _dict_from_json_file
+ PretrainedConfig.from_pretrained(model_name_or_path)
+ PretrainedConfig._dict_from_json_file = _old_func
+ return os.path.dirname(result)
+
+
+# This function creates a dictionary of keys and paths into the sharded
+# safetensors checkpoint file, that are relevant to the "prefix" and "instance_name"
+# being passed in.
+# - the keys point to modules found in megablocks.layers.dmoe.dMoE, the distributed
+# expert module provided by megablocks.
+# - the values are tuples pointing to the keys within the checkpoint file.
+#
+# Example: if prefix="module.layers.0" and instance_name="block_sparse_moe", then a dictionary
+# of the following will be returned:
+# {
+# 'experts.mlp.w1': [
+# (
+# 'model.layers.0.block_sparse_moe.experts.0.w1.weight',
+# 'model-00001-of-00019.safetensors'
+# ),
+# (
+# 'model.layers.0.block_sparse_moe.experts.1.w1.weight',
+# 'model-00001-of-00019.safetensors'
+# ),
+# ...
+# ]
+# 'experts.mlp.w2': [...],
+# 'experts.mlp.w3': [...],
+# 'router.layer.weight': [
+# (
+# 'model.layers.0.block_sparse_moe.gate.weight',
+# 'model-00001-of-00019.safetensors'
+# )
+# ]
+# }
+def get_checkpoint_meta_from_sharded_safetensor(
+ weight_map: Dict,
+ prefix: str, # e.g., 'model.layers.0'
+ instance_name: str, # e.g., block_sparse_moe
+ router_name: str = "gate", # e.g., named "gate" within block_sparse_moe
+ expert_name: str = "experts", # e.g., named "experts" within block_sparse_moe
+) -> Dict[str, List[Tuple]]:
+ # insert in order
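+ # - grows the list with None placeholders so that index i is always valid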
+ def _insert(L: List, i: int, v):
+ n = len(L)
+ if i < n:
+ L[i] = v
+ return
+
+ n = i - n + 1
+ while n > 0:
+ L.append(None)
+ n -= 1
+ L[i] = v
+
+ # state dict -> weights
+ # 'router.layer.weight': [(k, file),...]
+ # `experts.mlp.w1`: [...]
+ _map = defaultdict(list)
+ prefix = f"{prefix}.{instance_name}."
+ for k, stfile in weight_map.items():
+ if not k.startswith(prefix):
+ continue
+
+ # e.g. after replacement we get
+ # - gate.weight
+ # - experts.0.w1.weight
+ rel_k = k.replace(prefix, "")
+ # pylint: disable=anomalous-backslash-in-string
+ m = re.match(f"({router_name}|{expert_name})\.?(\d+)?\.?(\w+)?\.weight", rel_k)
+ if m is None:
+ raise ValueError(
+ f"Unable to handle key '{k}' with provided router_name "
+ f"'{router_name}' or expert_name '{expert_name}'"
+ )
+ if m.group(1) == router_name:
+ _map[KEY_DMOE_ROUTER].append((k, stfile))
+ elif m.group(1) == expert_name:
+ index = int(m.group(2))
+ mod = m.group(3)
+ _insert(_map[f"{KEY_DMOE_EXPERTS}.{mod}"], index, (k, stfile))
+
+ if len(_map) == 0:
+ raise ValueError(
+ f"Could not get safetensor map for '{prefix}' and '{instance_name}'"
+ )
+
+ return _map
+
+
+# this function will load the sharded experts onto the device.
+# - this assumes that the "dmoe" module is the megablocks.layers.dmoe.dMoE distributed
+# implementation of the mixture of experts.
+def load_sharded_experts_onto_device(
+ dmoe: torch.nn.Module,
+ directory: str,
+ checkpoint_metadata: Dict[str, List[Tuple]],
+ device_mesh: DeviceMesh,
+ placements: Placement,
+ expert_name: str = "experts", # e.g., named "experts" within block_sparse_moe
+ mixed_precision: bool = False,
+):
+ # typically they should all be in the same file, but to be safe, open the
+ # checkpoint files on cpu first since we may not need all the weights in each file.
+ with ExitStack() as stack:
+ files = {}
+ for _, vs in checkpoint_metadata.items():
+ for _, fi in vs:
+ if fi not in files:
+ files[fi] = stack.enter_context(
+ safe_open(
+ os.path.join(directory, fi), framework="pt", device="cpu"
+ )
+ )
+
+ # go by one weight at a time.
+ # - weight_name: points to megablocks.dmoe
+ upcasted = set()
+ for weight_name, vs in checkpoint_metadata.items():
+ data = []
+ for k, fi in vs:
+ T = files[fi].get_tensor(k)
+ if expert_name in k and k.endswith("weight"):
+ if T.shape[1] > T.shape[0]:
+ T = T.t()
+ data.append(T)
+
+ # get the module we want to shard
+ name = weight_name.split(".")
+ path, name = ".".join(name[:-1]), name[-1]
+ mod = dmoe.get_submodule(path)
+
+ # if mixed_precision and KEY_DMOE_ROUTER not in weight_name:
+ if mixed_precision:
+ mod_dtype = torch.float32
+ upcasted.add(weight_name)
+ else:
+ mod_dtype = getattr(mod, name).dtype
+
+ requires_grad = getattr(mod, name).requires_grad
+
+ # the megablocks dmoe expects the expert features to be on DIM_EXPERT.
+ # - concat on dim 0 and distribute
+ # - cast to the correct dtype for the module
+ # - if mixed precision is enabled, then sharded params are upcast
+ param = torch.concat(data, dim=DIM_EXPERT).to(mod_dtype)
+
+ _placements = placements
+ if KEY_DMOE_ROUTER in weight_name:
+ # - the router needs to be replicated
+ _placements = [Replicate() for _ in range(len(placements))]
+
+ param = torch.nn.Parameter(
+ distribute_tensor(param, device_mesh, _placements),
+ requires_grad=requires_grad,
+ )
+
+ # register the sharded parameter onto the megablocks.dmoe
+ mod.register_parameter(name, param)
+
+ if mixed_precision:
+ upcasted = ", ".join(sorted(upcasted))
+ warnings.warn(f"Mixed precision turned on, upcasted MoE parameters: {upcasted}")
+
+
+def shard_moe(
+ model: torch.nn.Module,
+ moe_cls: Union[str, Type],
+ checkpoint_name_or_path: str,
+ rank: int,
+ world_size: int,
+ moe_kwargs: Dict,
+ device_type: str = "cuda",
+ key_dp: str = KEY_DATA_PARALLEL,
+ key_ep: str = KEY_EXPERT_PARALLEL,
+ router_name: str = "gate",
+ expert_name: str = "experts",
+ shared_mesh_dim: bool = True,
+ ep_size: int = 1,
+ mixed_precision: bool = False,
+):
+ """shard_moe takes a mixture-of-experts huggingface model and shards the experts
+ on the current device. All layers layers that have a MoE module will be sharded.
+
+ The function requires "checkpoint_name_or_path" to point to the checkpoint that
+ the model has been loaded from, because model could have been loaded on the meta
+ device, and in which case would be missing the weights. This function will
+ instialize the sharded weights onto the device.
+
+ The sharding has two modes, and depends on world_size and number_of_experts the model
+ has. This depends on the setting "shared_mesh_dim" to True or False:
+ - if True: then dp and ep will happen on the same device_mesh dimension.
+ This is only possible if world_size divides number_of_experts
+ (which requires world_size < num_of_experts).
+ - if False: then dp and ep will be seperate device_mesh dimensions. The ep_size will be
+ determined by the argument passed in (which needs to be properly set ep_size > 1;
+ the default value will raise an assertion).
+
+ Parameters:
+
+ model (module): A valid mixture-of-experts Huggingface model.
+ moe_cls (str,type): A module class used to identify the MoE components.
+ checkpoint_name_or_path (str): name or path pointing to the weight checkpoint.
+ rank (int): rank of the current device.
+ world_size (int): total world size.
+ moe_kwargs (dict): kwargs used to construct megablocks.layers.arguments.Arguments
+ for building the megablocks.layers.dmoe.dMoE.
+ device_type (str): the current device to load the sharded model into.
+ key_dp (str): name of the data parallel mesh
+ key_ep (str): name of the expert parallel mesh (if initialized).
+ router_name (str): module name of the router in moe_cls (e.g., "gate").
+ expert_name (str): module name of the experts in moe_cls (e.g., "experts").
+ shared_mesh_dim (bool): for the sharding mode, see explanation above.
+ ep_size (int): for shared_mesh_dim=False only, see explanation above.
+ mixed_precision (bool): activates mixed precision and upcasts the sharded params.
+
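+ Example (a minimal sketch; the environment-variable handling and the
+ get_moe_kwargs call are illustrative assumptions):
+
+ dp_mesh, moe_module_names = shard_moe(
+ model,
+ moe_cls="MixtralSparseMoeBlock",
+ checkpoint_name_or_path="mistralai/Mixtral-8x7B-Instruct-v0.1",
+ rank=int(os.environ["RANK"]),
+ world_size=int(os.environ["WORLD_SIZE"]),
+ moe_kwargs=get_moe_kwargs(model.config),
+ )
+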
+ """
+ # guarded import
+ # Third Party
+ # pylint: disable=import-error, import-outside-toplevel
+ from megablocks.layers import arguments, dmoe
+
+ if shared_mesh_dim:
+ # if sharing the mesh with dp, then ep_size must be the world_size
+ # - in this case the ep_size argument passed in is ignored
+ ep_size = world_size
+ else:
+
+ # - moe_kwargs is constructed by get_moe_kwargs above
+ _num_experts = moe_kwargs["moe_num_experts"]
+ assert _num_experts % ep_size == 0, (
+ f"ep_shard factor '{ep_size}' does not divide "
+ f"number of experts '{_num_experts}'."
+ )
+
+ assert ep_size > 1, "expert_parallel dimension must be set larger than 1"
+ assert (
+ world_size % ep_size == 0
+ ), f"world_size ({world_size}) not divisible by ep_size ({ep_size})."
+
+ # this function will shard the MOE on this rank
+ device = torch.device(f"cuda:{rank}")
+
+ if shared_mesh_dim:
+ # in this case we will have a 1D mesh and collapse the
+ # expert parallel with data_parallel
+
+ device_mesh = init_device_mesh(
+ device_type,
+ (ep_size,),
+ mesh_dim_names=(key_dp,),
+ )
+ key_ep = key_dp
+ placements: List[Placement] = [Shard(DIM_EXPERT)]
+ else:
+ # in this case experts are distributed on a different
+ # mesh dimension than dp.
+ # - this allows the expert sharding to be hierarchical
+ # (e.g., it can be placed over a slower network plane since
+ # the communication overhead is lower)
+ dp_size = world_size // ep_size
+ device_mesh = init_device_mesh(
+ device_type,
+ (dp_size, ep_size),
+ mesh_dim_names=(key_dp, key_ep),
+ )
+ # - experts will replicate over the first dimension
+ placements: List[Placement] = [Replicate(), Shard(DIM_EXPERT)]
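+ # - e.g. (hypothetical numbers): with world_size=8 and ep_size=4 the mesh is
+ # (dp=2, ep=4), and each expert shard is replicated across the 2 dp ranks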
+
+ mp_dmoe_args = arguments.Arguments(
+ **moe_kwargs,
+ device=device,
+ expert_parallel_group=device_mesh[key_ep].get_group(0),
+ )
+
+ assert mp_dmoe_args.moe_num_experts % ep_size == 0, (
+ f"number of moe experts ({mp_dmoe_args.moe_num_experts}) "
+ f"not divisible by ep_size ({ep_size})."
+ )
+
+ # for all the MoE related params (e.g., gate, experts),
+ # build a dictionary of
+ # parent_mod: (child_instance_name, [list of fqdn keys], has_bias)
+ found = {}
+ for name, mod in model.named_modules():
+ name = name.split(".")
+ parent, child = ".".join(name[:-1]), name[-1]
+
+ # check the module depending if moe_cls is a str or class
+ if (
+ mod.__class__.__name__ == moe_cls
+ if isinstance(moe_cls, str)
+ else isinstance(mod, moe_cls)
+ ):
+ fqdn_keys = [ # all params, including the children's
+ f"{parent}.{child}.{n}" for n, _ in mod.named_parameters()
+ ]
+
+ # check if there are any biases in any of the experts
+ # Assumption: if one expert has a bias, then the others
+ # will have one too
+ has_bias = any(expert_name in k and k.endswith("bias") for k in fqdn_keys)
+
+ found[parent] = (child, fqdn_keys, has_bias)
+
+ moe_module_names = set()
+
+ # NOTE: for now we only support sharded safetensors
+ # - most MoE models should be using this checkpoint format
+ try:
+ loc = get_resolved_checkpoint_location(checkpoint_name_or_path)
+ with open(os.path.join(loc, FILE_SAFETENSOR_INDEX), encoding="utf-8") as f:
+ index = json.load(f)
+
+ # e.g., prefix: 'model.layers.0',
+ # module_name: 'block_sparse_moe'
+ for prefix, (module_name, _, has_bias) in tqdm(
+ found.items(), disable=(rank > 0), desc="Sharding MoE"
+ ):
+ checkpoint_metadata = get_checkpoint_meta_from_sharded_safetensor(
+ index["weight_map"], prefix, module_name, router_name, expert_name
+ )
+
+ _args = copy(mp_dmoe_args)
+ _args.bias = has_bias
+
+ # - will replace the MoE module with the megablocks sharded dMoE
+ with init_empty_weights():
+ mp_dmoe = dmoe.dMoE(_args) # drop in replacement for now
+
+ load_sharded_experts_onto_device(
+ mp_dmoe,
+ loc,
+ checkpoint_metadata,
+ device_mesh,
+ placements,
+ expert_name,
+ mixed_precision,
+ )
+ parent = model.get_submodule(prefix)
+ setattr(parent, module_name, mp_dmoe)
+
+ # - keep track of the name for returning
+ moe_module_names.add(module_name)
+
+ except ValueError as e:
+ raise ValueError(
+ f"Unable to load checkpoint_path '{checkpoint_name_or_path}'. "
+ "Currently only support non-GGUF safetensor checkpoints. "
+ ) from e
+
+ return device_mesh[key_dp], moe_module_names
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py
new file mode 100644
index 00000000..439eaaf0
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py
@@ -0,0 +1,144 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Third Party
+import torch
+
+try:
+ # definition is guarded, intended only when
+ # megablocks is available
+
+ # Third Party
+ # pylint: disable=import-error
+ from megablocks.layers import common, mpu
+ from megablocks.layers.activation_fn import act_fn
+ from megablocks.layers.arguments import Arguments
+ from megablocks.layers.mlp import (
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+ scale_gradient,
+ )
+ import stk
+
+ # This is an alternative MLP class used for models that have up_proj, down_proj
+ # and gate_proj, such as Mixtral
+ class SparseMLPv2(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(
+ args
+ ) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ )
+ self.w3 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ )
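+ # each of w1/w2/w3 stacks the per-expert weight rows, so the local shard holds
+ # experts_per_rank * features_per_rank rows of size hidden_size
+ # (see _num_rows_per_rank above)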
+
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ )
+ )
+ self.w3.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ )
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1, self._should_set_parallelism_attribute
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2, self._should_set_parallelism_attribute
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w3, self._should_set_parallelism_attribute
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, hidden_states, topo):
+ w1, w2, w3 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.w2),
+ self.scale_grad(self.w3),
+ )
+ w1, w2, w3 = (resolve_dtensor(w1), resolve_dtensor(w2), resolve_dtensor(w3))
+
+ # Perform the expert computation: a gated (SwiGLU-style) block-sparse MLP,
+ # act_fn(x @ w1.T) * (x @ w3.T), projected back down through w2
+ hidden_states = stk.Matrix( # type: ignore
+ topo.size(),
+ act_fn(
+ stk.ops.sdd(hidden_states, w1.t(), topo), self.args.activation_fn
+ ).data
+ * stk.ops.sdd(hidden_states, w3.t(), topo).data,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+ return stk.ops.dsd(hidden_states, w2)
+
+except ImportError:
+ pass
diff --git a/plugins/accelerated-moe/tests/__init__.py b/plugins/accelerated-moe/tests/__init__.py
new file mode 100644
index 00000000..38a9531e
--- /dev/null
+++ b/plugins/accelerated-moe/tests/__init__.py
@@ -0,0 +1,13 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/plugins/accelerated-moe/tests/test_megablocks_plugin.py b/plugins/accelerated-moe/tests/test_megablocks_plugin.py
new file mode 100644
index 00000000..646e0a2b
--- /dev/null
+++ b/plugins/accelerated-moe/tests/test_megablocks_plugin.py
@@ -0,0 +1,34 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+import os
+
+# Third Party
+from fms_acceleration.utils import instantiate_framework, read_configuration
+
+# First Party
+from fms_acceleration_moe import MegablocksMoEAccelerationPlugin
+
+# configuration
+DIRNAME = os.path.dirname(__file__)
+CONFIG_PATH_MEGABLOCKS = os.path.join(DIRNAME, "../configs/megablocks.yaml")
+
+
+def test_framework_installs_moe_megablocks_plugin():
+ with instantiate_framework(
+ read_configuration(CONFIG_PATH_MEGABLOCKS), require_packages_check=False
+ ) as framework:
+ for plugin in framework.active_plugins:
+ assert isinstance(plugin[1], MegablocksMoEAccelerationPlugin)
diff --git a/plugins/accelerated-moe/tox.ini b/plugins/accelerated-moe/tox.ini
new file mode 100644
index 00000000..811f1329
--- /dev/null
+++ b/plugins/accelerated-moe/tox.ini
@@ -0,0 +1,48 @@
+[tox]
+envlist = py, lint
+
+[testenv]
+deps =
+ pytest>=7
+ -e {toxinidir}
+skip_install = true
+commands =
+
+ # install the dependencies here to ensure
+ # the correct install order
+ pip install -e {toxinidir}/../framework
+ pytest {posargs:tests}
+
+[testenv:lint]
+description = run linters
+skip_install = false
+deps =
+ -e {toxinidir}/../framework
+ pylint>=2.16.2,<=3.1.0
+commands =
+ pylint src tests
+allowlist_externals = pylint
+
+[testenv:fmt]
+description = format
+skip_install = true
+deps =
+ black>=22.12
+ isort>=5.11
+commands =
+ black {posargs:.}
+ isort {posargs:.}
+
+[testenv:build]
+description = build wheel
+deps =
+ build
+commands = python -m build -w
+skip_install = True
+
+[testenv:twinecheck]
+description = check wheel
+deps =
+ twine
+commands = twine check dist/*
+skip_install = True
\ No newline at end of file
diff --git a/plugins/framework/src/fms_acceleration/constants.py b/plugins/framework/src/fms_acceleration/constants.py
index 6a81d977..d568ec13 100644
--- a/plugins/framework/src/fms_acceleration/constants.py
+++ b/plugins/framework/src/fms_acceleration/constants.py
@@ -21,4 +21,4 @@
# and activated.
# - hence the plugins that have model loaders should be on top of this list
-PLUGINS = ["peft", "foak", "aadp"]
+PLUGINS = ["peft", "foak", "aadp", "moe"]
diff --git a/sample-configurations/CONTENTS.yaml b/sample-configurations/CONTENTS.yaml
index 09301193..5b7c1e65 100644
--- a/sample-configurations/CONTENTS.yaml
+++ b/sample-configurations/CONTENTS.yaml
@@ -67,4 +67,9 @@ framework_configs:
- accelerated-peft
- attention-and-distributed-packing
- fused-ops-and-kernels
- filename: accelerated-peft-autogptq-foak-padding-free-sample-configuration.yaml
\ No newline at end of file
+ filename: accelerated-peft-autogptq-foak-padding-free-sample-configuration.yaml
+
+ - shortname: moe-megablocks
+ plugins:
+ - accelerated-moe
+ filename: moe-megablocks-sample-configuration.yaml
diff --git a/sample-configurations/moe-megablocks-sample-configuration.yaml b/sample-configurations/moe-megablocks-sample-configuration.yaml
new file mode 100644
index 00000000..b1ee14f9
--- /dev/null
+++ b/sample-configurations/moe-megablocks-sample-configuration.yaml
@@ -0,0 +1,45 @@
+# FMS Acceleration Plugin Configuration.
+#
+# Each stanza incorporates various configurations for
+# different fine-tuning / training tasks.
+plugins:
+ training:
+
+ # mixture-of-experts configurations
+ moe:
+
+ # expert-parallel for MoE
+ megablocks:
+
+ # The name of the mixture-of-experts class
+ moe_component_class: MixtralSparseMoeBlock
+
+ # The module name of the router in moe_component_class above
+ moe_gate_module_name: gate
+
+ # The module name of the experts in moe_component_class above
+ moe_experts_module_name: experts
+
+ # the mlp version
+ # - for models with only up and down projs, use "v1"
+ # - for models with up, down and gate projs, use "v2"
+ moe_mlp_impl: v2
+
+ # if True, then we shard experts across data parallel dimension
+ # - only feasible if world_size divides the number of experts
+ shard_along_dp: true
+
+ # to be specified only if shard_along_dp == False. This controls the level
+ # of sharding, i.e., how many experts are placed on each device
+ # - the number of experts per device will be num_experts / ep_size
+ # - we disallow ep_size=1 since that would mean no sharding
+ # - NOTE: ep_size=1 is not the same as shard_along_dp=True; allowing it would
+ # be contradictory since ep_size=1 implies no expert parallelism.
+ # ep_size: 2
+
+ # the MoE dropless implementation. Currently we only support "dropless_sparse", but
+ # in the future we may support others
+ moe_implementation: dropless_sparse
+
+ # for load_balancing_loss
+ load_balancing_loss: false
diff --git a/scripts/benchmarks/accelerate.yaml b/scripts/benchmarks/accelerate.yaml
index f70d74fa..f3908470 100644
--- a/scripts/benchmarks/accelerate.yaml
+++ b/scripts/benchmarks/accelerate.yaml
@@ -30,7 +30,7 @@ fsdp_config:
# 3 is NO_SHARD, effectively disabling FSDP
# 4, 5 are HYBRID_ modes for multi-node training only.
- fsdp_state_dict_type: FULL_STATE_DICT # set to FULL_STATE_DICT (1), SHARDED_STATE_DICT (3)
+ fsdp_state_dict_type: SHARDED_STATE_DICT # set to FULL_STATE_DICT (1), SHARDED_STATE_DICT (3)
# 2 is LOCAL_STATE_DICT where parameters are still flattened
# 3 is efficient, but requires know-how to use the shared checkpoint.
diff --git a/scripts/benchmarks/accelerator-config.json b/scripts/benchmarks/accelerator-config.json
new file mode 100644
index 00000000..7f736f97
--- /dev/null
+++ b/scripts/benchmarks/accelerator-config.json
@@ -0,0 +1,5 @@
+{
+ "gradient_accumulation_kwargs": {
+ "sync_each_batch": true
+ }
+}
diff --git a/scripts/benchmarks/benchmark.py b/scripts/benchmarks/benchmark.py
index d21b2fbe..19487652 100644
--- a/scripts/benchmarks/benchmark.py
+++ b/scripts/benchmarks/benchmark.py
@@ -356,14 +356,28 @@ class ScenarioMatrix:
def __init__(self, scenario: Dict, acceleration_config_map: Dict = None) -> None:
assert "arguments" in scenario.keys(), "Missing `arguments` key in `scenario`"
+
+ # "slow" is a special key that indicates this scenario
+ # takes resources to run
+ # - "slow" scenarios are not run if not specified by a filter
+ self.slow = False
+
for key, val in scenario.items():
if key == "framework_config":
# if acceleration_config_map is None, then do not do mapping
if acceleration_config_map:
+
+ # - we allow k to be None to indicate we do not wish to
+ # set a config for that matrix entry. However, we do not
+ # check for multiple None's, so be careful.
val = [
- acceleration_config_map[k]
+ (
+ acceleration_config_map[k]
+ if k is not None
+ else None
+ )
for k in val
- if k in acceleration_config_map
+ if k in acceleration_config_map or k is None
]
setattr(self, key, val)
@@ -679,7 +693,18 @@ def prepare_arguments(args, benchmark_dataset: BenchmarkDataset):
if args.run_only_scenarios and _scn_name not in args.run_only_scenarios:
print(f"Skipping scenario '{_scn_name}'")
continue
+
+ # build scenario matrix
scenario = ScenarioMatrix(scenario_config, acceleration_config_map)
+
+ if not args.run_only_scenarios and scenario.slow:
+ # unfiltered runs omit all scenarios marked "slow"
+ print(f"Skipping slow scenario '{_scn_name}' because run_only_scenarios=None.")
+ continue
+
scenario_matrices, scenario_constants = (
scenario.get_scenario_matrices_and_defaults()
)
diff --git a/scripts/benchmarks/refs/a100_80gb_mb.csv b/scripts/benchmarks/refs/a100_80gb_mb.csv
new file mode 100644
index 00000000..b5bdc6c5
--- /dev/null
+++ b/scripts/benchmarks/refs/a100_80gb_mb.csv
@@ -0,0 +1,3 @@
+framework_config,mem_nvidia_mem_reserved,mem_peak_torch_mem_alloc_in_bytes,mem_torch_mem_alloc_in_bytes,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second
+none,65598.5,58936741888,47259259904,bfloat16,0.859970542192459,4170.0391,3.07,0.024,80.575
+moe-megablocks,52284.0,48874301952,35987686400,bfloat16,0.8570401281118393,1404.3938,9.114,0.071,239.249
diff --git a/scripts/benchmarks/refs/requirements_mb.txt b/scripts/benchmarks/refs/requirements_mb.txt
new file mode 100644
index 00000000..679b20f8
--- /dev/null
+++ b/scripts/benchmarks/refs/requirements_mb.txt
@@ -0,0 +1,88 @@
+accelerate==0.33.0
+aiohappyeyeballs==2.4.0
+aiohttp==3.10.5
+aiosignal==1.3.1
+async-timeout==4.0.3
+attrs==24.2.0
+bitsandbytes==0.43.3
+certifi==2024.7.4
+charset-normalizer==3.3.2
+contourpy==1.2.1
+cycler==0.12.1
+datasets==2.21.0
+dill==0.3.8
+docstring_parser==0.16
+einops==0.8.0
+filelock==3.15.4
+fire==0.6.0
+flash-attn==2.6.3
+-e git+https://github.com/foundation-model-stack/fms-acceleration.git@9cf70081f3bfc00e84331102e7d13b333f17ee26#egg=fms_acceleration&subdirectory=plugins/framework
+-e git+https://github.com/foundation-model-stack/fms-acceleration.git@9cf70081f3bfc00e84331102e7d13b333f17ee26#egg=fms_acceleration_foak&subdirectory=plugins/fused-ops-and-kernels
+-e git+https://github.com/foundation-model-stack/fms-acceleration.git@9cf70081f3bfc00e84331102e7d13b333f17ee26#egg=fms_acceleration_moe&subdirectory=plugins/accelerated-moe
+-e git+https://github.com/foundation-model-stack/fms-acceleration.git@9cf70081f3bfc00e84331102e7d13b333f17ee26#egg=fms_acceleration_peft&subdirectory=plugins/accelerated-peft
+fms-hf-tuning @ git+https://github.com/foundation-model-stack/fms-hf-tuning.git@daca5510ab76cc8ecf0283fd31fc220697a75040
+fonttools==4.53.1
+frozenlist==1.4.1
+fsspec==2024.6.1
+huggingface-hub==0.24.6
+idna==3.7
+Jinja2==3.1.4
+kiwisolver==1.4.5
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+matplotlib==3.9.2
+mdurl==0.1.2
+megablocks @ git+https://github.com/databricks/megablocks.git@bce5d7b2aaf5038bc93b36f76c2baf51c2939bd2
+mpmath==1.3.0
+multidict==6.0.5
+multiprocess==0.70.16
+networkx==3.3
+numpy==1.26.4
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==8.9.2.26
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-nccl-cu12==2.20.5
+nvidia-nvjitlink-cu12==12.6.20
+nvidia-nvtx-cu12==12.1.105
+packaging==24.1
+pandas==2.2.2
+peft==0.12.0
+pillow==10.4.0
+protobuf==5.27.3
+psutil==6.0.0
+pyarrow==17.0.0
+Pygments==2.18.0
+pyparsing==3.1.2
+python-dateutil==2.9.0.post0
+pytz==2024.1
+PyYAML==6.0.2
+regex==2024.7.24
+requests==2.32.3
+rich==13.7.1
+safetensors==0.4.4
+sentencepiece==0.2.0
+shtab==1.7.1
+simpleeval==0.9.13
+six==1.16.0
+stanford-stk @ git+https://git@github.com/stanford-futuredata/stk.git@a1ddf98466730b88a2988860a9d8000fd1833301
+sympy==1.13.2
+termcolor==2.4.0
+threadpoolctl==3.5.0
+tokenizers==0.19.1
+torch==2.3.1
+tqdm==4.66.5
+transformers==4.44.2
+triton==2.3.1
+trl==0.9.6
+typing_extensions==4.12.2
+tyro==0.8.8
+tzdata==2024.1
+urllib3==2.2.2
+xxhash==3.5.0
+yarl==1.9.4
diff --git a/scripts/benchmarks/scenarios.yaml b/scripts/benchmarks/scenarios.yaml
index 2eb22872..0e8b6954 100644
--- a/scripts/benchmarks/scenarios.yaml
+++ b/scripts/benchmarks/scenarios.yaml
@@ -94,20 +94,18 @@ scenarios:
- 'mistralai/Mixtral-8x7B-Instruct-v0.1'
- 'NousResearch/Llama-2-70b-hf'
- - name: accelerated-peft-gptq
+ - name: accelerated-moe-megablocks
framework_config:
- - accelerated-peft-autogptq
- - accelerated-peft-autogptq-foak
+ - # without acceleration
+ - moe-megablocks
+ slow: True
arguments:
- learning_rate: 2e-4
- fp16: True
- torch_dtype: float16
- peft_method: lora
- r: 16
- lora_alpha: 16
- lora_dropout: 0.1
- target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
+ learning_rate: 5e-5
+ torch_dtype: bfloat16
+ accelerator_config: scripts/benchmarks/accelerator-config.json
+ gradient_accumulation_steps: 16
+ logging_steps: 1
+ packing: False
+ adam_epsilon: 1e-8
model_name_or_path:
- - 'TheBloke/Mistral-7B-v0.1-GPTQ'
- - 'TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ'
- - 'TheBloke/Llama-2-70B-GPTQ'
+ - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
diff --git a/scripts/generate_sample_configurations.py b/scripts/generate_sample_configurations.py
index c72c62eb..2a740342 100644
--- a/scripts/generate_sample_configurations.py
+++ b/scripts/generate_sample_configurations.py
@@ -147,6 +147,7 @@ def read_configuration(path: str) -> Dict:
KEY_BNB_NF4_FOAK = "bnb-nf4-foak"
KEY_AADP_PADDING_FREE = "aadp-padding-free"
KEY_AADP_MULTIPACK = "aadp-multipack"
+KEY_MEGABLOCKS = "moe-megablocks"
CONFIGURATIONS = {
KEY_AUTO_GPTQ: "plugins/accelerated-peft/configs/autogptq.yaml",
@@ -171,6 +172,7 @@ def read_configuration(path: str) -> Dict:
),
KEY_AADP_PADDING_FREE: "plugins/attention-and-distributed-packing/configs/padding_free.yaml",
KEY_AADP_MULTIPACK: "plugins/attention-and-distributed-packing/configs/multipack.yaml",
+ KEY_MEGABLOCKS: "plugins/accelerated-moe/configs/megablocks.yaml",
}
# list of (tag, combi) tuples
@@ -190,6 +192,7 @@ def read_configuration(path: str) -> Dict:
("accelerated-peft-autogptq-foak-padding-free", (KEY_AADP_PADDING_FREE,KEY_AUTO_GPTQ, KEY_AUTO_GPTQ_FOAK)),
("accelerated-peft-bnb-nf4-foak-padding-free", (KEY_AADP_PADDING_FREE,KEY_BNB_NF4, KEY_BNB_NF4_FOAK)),
("aadp-padding-free-multipack", (KEY_AADP_PADDING_FREE, KEY_AADP_MULTIPACK)),
+ ("moe-megablocks", (KEY_MEGABLOCKS,)),
]
diff --git a/tox.ini b/tox.ini
index 52f9bdb3..cad75c25 100644
--- a/tox.ini
+++ b/tox.ini
@@ -39,6 +39,7 @@ commands =
python -m fms_acceleration.cli install -e {toxinidir}/plugins/accelerated-peft
python -m fms_acceleration.cli install -e {toxinidir}/plugins/fused-ops-and-kernels
python -m fms_acceleration.cli install -e {toxinidir}/plugins/attention_and_distributed_packing
+ python -m fms_acceleration.cli install -e {toxinidir}/plugins/accelerated-moe
# run the benchmark script
bash scripts/run_benchmarks.sh {posargs:"1 2" "4 8" benchmark_outputs}