
Commit de9a4f1

fix(deps): Copy KHD imports into scattermoe_utils (#127)
* fix: move necessary khd functions into scattermoe_utils
* remove requirements-khd, update README
* fix: remove ext dep from benchmarks
* fix: move files to khd folder
* fmt
* fix lint
* rm copyright
* reference cute kernels
* cute kernels reference

---------

Signed-off-by: Will Johnson <[email protected]>
1 parent 791bdd9 commit de9a4f1

File tree

11 files changed: +1207 additions, -30 deletions

plugins/accelerated-moe/.pylintrc

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ ignore=CVS,protobufs
 # ignore-list. The regex matches against paths and can be in Posix or Windows
 # format. Because '\\' represents the directory delimiter on Windows systems,
 # it can't be used as an escape character.
-ignore-paths=.*megablocks
+ignore-paths=.*megablocks,.*khd
 
 # Files or directories matching the regular expression patterns are skipped.
 # The regex matches against base names, not paths. The default value ignores
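The added pattern skips anything under a `khd` directory, just as the existing pattern already does for `megablocks`. A small sketch of how the comma-separated `ignore-paths` patterns behave (pylint matches each regex against the file path; the example paths are illustrative):

```
import re

# the two comma-separated patterns from the updated ignore-paths value
patterns = [re.compile(r".*megablocks"), re.compile(r".*khd")]

for path in [
    "src/fms_acceleration_moe/utils/scattermoe_utils/khd/kernels/ops.py",
    "src/fms_acceleration_moe/utils/scattermoe.py",
]:
    ignored = any(p.match(path) for p in patterns)
    print(path, "->", "skipped by pylint" if ignored else "linted")
```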

plugins/accelerated-moe/README.md

Lines changed: 1 addition & 7 deletions
@@ -48,7 +48,6 @@ Run the below in the top-level directory of this repo:
 
 ```
 tox -e run-benches \
-    -x testenv:run-benches.deps+="-r plugins/accelerated-moe/requirements-khd.txt" \
     -x testenv:run-benches.setenv+="MEMORY_LOGGING=nvidia" \
     -- \
     "1 2 4" 128 benchmark_outputs scenarios-moe.yaml accelerated-moe-full
@@ -77,12 +76,7 @@ bash scripts/run_benchmarks.sh \
 
 ### Triton Kernel Dependencies
 
-Currently we do not copy the `scattermoe` kernels into this respository, to this is an additional manual install:
-
-```
-# this will install the kernel-hyperdrive fork with the scattermoe triton kernels
-pip install -r requirements-khd.txt
-```
+Triton Kernels are copied into [scattermoe_utils](./src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/kernels) and were copied from [kernel hyperdrive](https://github.com/fabianlim/kernel-hyperdrive) which is a fork of [cute kernels](https://github.com/mayank31398/cute-kernels)
 
 ### Known Issues
 
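With the kernels vendored, they resolve from within the plugin package itself and no longer need the kernel-hyperdrive install. A minimal import check (sketch; the absolute module path is inferred from the plugin's relative imports further down this commit, not stated in the README):

```
# Sketch: with the plugin installed, the vendored scattermoe triton kernels
# import without the kernel-hyperdrive package (module path inferred, not
# part of this commit's README text).
from fms_acceleration_moe.utils.scattermoe_utils.khd.kernels.ops import (
    padded_block_indices,
    scattered_experts,
)

print(padded_block_indices, scattered_experts)
```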

plugins/accelerated-moe/requirements-khd.txt

Lines changed: 0 additions & 2 deletions
This file was deleted.

plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py

Lines changed: 0 additions & 8 deletions
@@ -32,14 +32,6 @@
 # pylint: disable=too-many-instance-attributes
 class ScatterMoEAccelerationPlugin(AccelerationPlugin):
 
-    # NOTE: we cannot do
-    # - require_packages = {"khd"}
-    # this is because the khd fork is not properly packaged as a PyPI project, and so
-    # - "importlib.util.find_spec('khd')" returns, but
-    # - "importlib.metadata.version('kernel-hyperdrive')" does not return
-    # if we decide to extract the kernels, then we do not need to anymore,
-    # https://github.com/foundation-model-stack/fms-acceleration/issues/105
-
     restricted_model_archs = [
         "GraniteMoeForCausalLM",
         "MixtralForCausalLM",

plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py

Lines changed: 4 additions & 12 deletions
@@ -26,21 +26,13 @@
 import torch
 import torch.nn.functional as F
 
-try:
-    # Third Party
-    from khd.kernels.scattermoe.triton_implementation.ops import (
-        padded_block_indices,
-        scattered_experts,
-    )
-except ImportError as e:
-    raise ImportError(
-        "kernel-hyperdrive PyPI package not found. Install it: "
-        "pip install -r plugins/accelerated-moe/requirements-khd.txt"
-    ) from e
-
 # Local
 from .scattermoe_constants import SCATTERMOE_SPEC_HAS_GATE
 from .scattermoe_utils import all_to_all_gather_inputs, scatter_with_routing_weights
+from .scattermoe_utils.khd.kernels.ops import (
+    padded_block_indices,
+    scattered_experts,
+)
 
 
 # helper function to fetch the local tensor if its a dtensor
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+# Copyright The FMS HF Tuning Authors
+# Copyright 2024 Databricks
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Local
+from .custom_op import torch_custom_op
Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+# Standard
+from typing import Any, Callable, Iterable
+
+# Third Party
+import torch
+
+try:
+    # Third Party
+    from torch.library import custom_op
+
+    _IS_CUSTOM_OP_IN_PYTORCH = True
+except:
+    _IS_CUSTOM_OP_IN_PYTORCH = False
+
+
+class _IdentityOp:
+    def __init__(self, fn: Callable) -> None:
+        self.fn = fn
+
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        return self.fn(*args, **kwargs)
+
+    def register_fake(self, fn: Callable) -> Callable:
+        return fn
+
+
+def torch_custom_op(
+    name: str,
+    fn: Callable | None = None,
+    /,
+    *,
+    mutates_args: str | Iterable[str],
+    device_types: torch.device = None,
+    schema: str | None = None,
+) -> Callable | _IdentityOp:
+    if _IS_CUSTOM_OP_IN_PYTORCH:
+        op = custom_op(
+            name,
+            fn,
+            mutates_args=mutates_args,
+            device_types=device_types,
+            schema=schema,
+        )
+    else:
+        op = _IdentityOp if fn is None else _IdentityOp(fn)
+
+    return op
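A hypothetical usage sketch of the wrapper above (the op name, decorated function, and import location are assumptions, not part of this commit): on PyTorch builds that provide `torch.library.custom_op` the decorator registers a real custom op, otherwise it degrades to the `_IdentityOp` passthrough.

```
# Third Party
import torch

# assumed re-export location, mirroring the __init__.py added earlier
from fms_acceleration_moe.utils.scattermoe_utils.khd import torch_custom_op


# hypothetical op: registered via torch.library.custom_op when available,
# otherwise called through the _IdentityOp fallback
@torch_custom_op("khd::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor


out = scale_rows(torch.ones(4, 4), 2.0)  # tensor of 2.0s on either code path
```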
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+# Copyright The FMS HF Tuning Authors
+# Copyright 2024 Databricks
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Local
+from .kernels import (
+    group_triton_kernel,
+    groupXtY_triton_kernel,
+    scatter2scatter_lora_triton_kernel,
+    scatter2scatter_triton_kernel,
+)
