2 changes: 1 addition & 1 deletion plugins/accelerated-moe/.pylintrc
@@ -52,7 +52,7 @@ ignore=CVS,protobufs
# ignore-list. The regex matches against paths and can be in Posix or Windows
# format. Because '\\' represents the directory delimiter on Windows systems,
# it can't be used as an escape character.
ignore-paths=.*megablocks
ignore-paths=.*megablocks,.*khd

# Files or directories matching the regular expression patterns are skipped.
# The regex matches against base names, not paths. The default value ignores
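For context on the `ignore-paths` addition above: pylint treats each entry as a regular expression matched against file paths, so `.*khd` skips the vendored khd kernel sources. A minimal sketch of that matching behaviour, using hypothetical paths that are not taken from this PR:

```
import re

# Hedged illustration only: pylint matches each ignore-paths regex against
# file paths, so ".*khd" skips anything whose path contains "khd".
pattern = re.compile(r".*khd")

# Hypothetical example paths (not taken from this PR).
vendored = "src/fms_acceleration_moe/utils/scattermoe_utils/khd/kernels/ops.py"
plugin = "src/fms_acceleration_moe/framework_plugin_scattermoe.py"

print(bool(pattern.match(vendored)))  # True
print(bool(pattern.match(plugin)))    # False
```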
8 changes: 1 addition & 7 deletions plugins/accelerated-moe/README.md
@@ -48,7 +48,6 @@ Run the below in the top-level directory of this repo:

```
tox -e run-benches \
-x testenv:run-benches.deps+="-r plugins/accelerated-moe/requirements-khd.txt" \
-x testenv:run-benches.setenv+="MEMORY_LOGGING=nvidia" \
-- \
"1 2 4" 128 benchmark_outputs scenarios-moe.yaml accelerated-moe-full
@@ -77,12 +76,7 @@ bash scripts/run_benchmarks.sh \

### Triton Kernel Dependencies

Currently we do not copy the `scattermoe` kernels into this repository, so this is an additional manual install:

```
# this will install the kernel-hyperdrive fork with the scattermoe triton kernels
pip install -r requirements-khd.txt
```
The Triton kernels are now vendored into [scattermoe_utils](./src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/kernels); they were copied from [kernel hyperdrive](https://github.com/fabianlim/kernel-hyperdrive), which is a fork of [cute kernels](https://github.com/mayank31398/cute-kernels).

### Known Issues

2 changes: 0 additions & 2 deletions plugins/accelerated-moe/requirements-khd.txt

This file was deleted.

@@ -32,14 +32,6 @@
# pylint: disable=too-many-instance-attributes
class ScatterMoEAccelerationPlugin(AccelerationPlugin):

# NOTE: we cannot do
# - require_packages = {"khd"}
# this is because the khd fork is not properly packaged as a PyPI project, and so
# - "importlib.util.find_spec('khd')" returns, but
# - "importlib.metadata.version('kernel-hyperdrive')" does not return
# if we decide to extract the kernels, then we do not need to anymore,
# https://github.com/foundation-model-stack/fms-acceleration/issues/105

restricted_model_archs = [
"GraniteMoeForCausalLM",
"MixtralForCausalLM",
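The NOTE removed above explains why a `require_packages = {"khd"}` guard was never viable: the kernel-hyperdrive fork exposes importable `khd` modules but is not packaged as a PyPI distribution, so metadata lookups fail. A minimal sketch of that distinction, assuming the fork is installed in the environment:

```
import importlib.metadata
import importlib.util

# The modules are importable, so find_spec() locates them ...
print(importlib.util.find_spec("khd") is not None)

# ... but no installed distribution is named "kernel-hyperdrive",
# so the metadata lookup raises instead of returning a version.
try:
    importlib.metadata.version("kernel-hyperdrive")
except importlib.metadata.PackageNotFoundError:
    print("no distribution metadata for kernel-hyperdrive")
```

With the kernels vendored into this plugin, neither check is needed any longer.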
@@ -26,21 +26,13 @@
import torch
import torch.nn.functional as F

try:
# Third Party
from khd.kernels.scattermoe.triton_implementation.ops import (
padded_block_indices,
scattered_experts,
)
except ImportError as e:
raise ImportError(
"kernel-hyperdrive PyPI package not found. Install it: "
"pip install -r plugins/accelerated-moe/requirements-khd.txt"
) from e

# Local
from .scattermoe_constants import SCATTERMOE_SPEC_HAS_GATE
from .scattermoe_utils import all_to_all_gather_inputs, scatter_with_routing_weights
from .scattermoe_utils.khd.kernels.ops import (
padded_block_indices,
scattered_experts,
)


# helper function to fetch the local tensor if it's a dtensor
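With the guarded third-party import above replaced by the vendored copy, callers resolve the kernels from inside the plugin package rather than from a pip-installed fork. As a rough sketch, the equivalent absolute import would look like the following; the full package path is an assumption inferred from the relative import in the diff:

```
# Assumed absolute form of the vendored-kernel import; the diff itself uses
# the relative "from .scattermoe_utils.khd.kernels.ops import ...".
from fms_acceleration_moe.utils.scattermoe_utils.khd.kernels.ops import (
    padded_block_indices,
    scattered_experts,
)
```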
@@ -0,0 +1,17 @@
# Copyright The FMS HF Tuning Authors
# Copyright 2024 Databricks
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Local
from .custom_op import torch_custom_op
@@ -0,0 +1,47 @@
# Standard
from typing import Any, Callable, Iterable

# Third Party
import torch

try:
# Third Party
from torch.library import custom_op

_IS_CUSTOM_OP_IN_PYTORCH = True
except:
_IS_CUSTOM_OP_IN_PYTORCH = False


class _IdentityOp:
def __init__(self, fn: Callable) -> None:
self.fn = fn

def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.fn(*args, **kwargs)

def register_fake(self, fn: Callable) -> Callable:
return fn


def torch_custom_op(
name: str,
fn: Callable | None = None,
/,
*,
mutates_args: str | Iterable[str],
device_types: torch.device = None,
schema: str | None = None,
) -> Callable | _IdentityOp:
if _IS_CUSTOM_OP_IN_PYTORCH:
op = custom_op(
name,
fn,
mutates_args=mutates_args,
device_types=device_types,
schema=schema,
)
else:
op = _IdentityOp if fn is None else _IdentityOp(fn)

return op
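The `torch_custom_op` helper above delegates to `torch.library.custom_op` when the running PyTorch provides it and otherwise falls back to the `_IdentityOp` pass-through, so decorated functions keep working on older versions. A hedged usage sketch with a made-up op name and function (not part of this PR), assuming `torch_custom_op` is imported from the new module:

```
# Third Party
import torch

# Hypothetical op registered through the wrapper defined above. On newer
# PyTorch this routes through torch.library.custom_op; on older versions
# _IdentityOp simply calls the function and register_fake becomes a no-op.
@torch_custom_op("fms_moe_demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor


@scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)


out = scale(torch.ones(4), 2.0)
```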
@@ -0,0 +1,22 @@
# Copyright The FMS HF Tuning Authors
# Copyright 2024 Databricks
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Local
from .kernels import (
group_triton_kernel,
groupXtY_triton_kernel,
scatter2scatter_lora_triton_kernel,
scatter2scatter_triton_kernel,
)