diff --git a/plugins/accelerated-moe/.pylintrc b/plugins/accelerated-moe/.pylintrc index 32b5dc66..4ddccea1 100644 --- a/plugins/accelerated-moe/.pylintrc +++ b/plugins/accelerated-moe/.pylintrc @@ -52,7 +52,7 @@ ignore=CVS,protobufs # ignore-list. The regex matches against paths and can be in Posix or Windows # format. Because '\\' represents the directory delimiter on Windows systems, # it can't be used as an escape character. -ignore-paths=.*megablocks +ignore-paths=.*megablocks,.*khd # Files or directories matching the regular expression patterns are skipped. # The regex matches against base names, not paths. The default value ignores diff --git a/plugins/accelerated-moe/README.md b/plugins/accelerated-moe/README.md index 45ae209d..d05cdc82 100644 --- a/plugins/accelerated-moe/README.md +++ b/plugins/accelerated-moe/README.md @@ -48,7 +48,6 @@ Run the below in the top-level directory of this repo: ``` tox -e run-benches \ - -x testenv:run-benches.deps+="-r plugins/accelerated-moe/requirements-khd.txt" \ -x testenv:run-benches.setenv+="MEMORY_LOGGING=nvidia" \ -- \ "1 2 4" 128 benchmark_outputs scenarios-moe.yaml accelerated-moe-full @@ -77,12 +76,7 @@ bash scripts/run_benchmarks.sh \ ### Triton Kernel Dependencies -Currently we do not copy the `scattermoe` kernels into this respository, to this is an additional manual install: - -``` -# this will install the kernel-hyperdrive fork with the scattermoe triton kernels -pip install -r requirements-khd.txt -``` +Triton Kernels are copied into [scattermoe_utils](./src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/kernels) and were copied from [kernel hyperdrive](https://github.com/fabianlim/kernel-hyperdrive) which is a fork of [cute kernels](https://github.com/mayank31398/cute-kernels) ### Known Issues diff --git a/plugins/accelerated-moe/requirements-khd.txt b/plugins/accelerated-moe/requirements-khd.txt deleted file mode 100644 index 497bf78e..00000000 --- a/plugins/accelerated-moe/requirements-khd.txt +++ /dev/null @@ -1,2 +0,0 @@ -# fork of https://github.com/mayank31398/kernel-hyperdrive/ -kernel-hyperdrive @ git+https://github.com/fabianlim/kernel-hyperdrive.git \ No newline at end of file diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py index 3a4a3c27..1bb33871 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py @@ -32,14 +32,6 @@ # pylint: disable=too-many-instance-attributes class ScatterMoEAccelerationPlugin(AccelerationPlugin): - # NOTE: we cannot do - # - require_packages = {"khd"} - # this is because the khd fork is not properly packaged as a PyPI project, and so - # - "importlib.util.find_spec('khd')" returns, but - # - "importlib.metadata.version('kernel-hyperdrive')" does not return - # if we decide to extract the kernels, then we do not need to anymore, - # https://github.com/foundation-model-stack/fms-acceleration/issues/105 - restricted_model_archs = [ "GraniteMoeForCausalLM", "MixtralForCausalLM", diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py index d26139e5..e16943b1 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py @@ -26,21 +26,13 @@ import torch 
import torch.nn.functional as F -try: - # Third Party - from khd.kernels.scattermoe.triton_implementation.ops import ( - padded_block_indices, - scattered_experts, - ) -except ImportError as e: - raise ImportError( - "kernel-hyperdrive PyPI package not found. Install it: " - "pip install -r plugins/accelerated-moe/requirements-khd.txt" - ) from e - # Local from .scattermoe_constants import SCATTERMOE_SPEC_HAS_GATE from .scattermoe_utils import all_to_all_gather_inputs, scatter_with_routing_weights +from .scattermoe_utils.khd.kernels.ops import ( + padded_block_indices, + scattered_experts, +) # helper function to fetch the local tensor if its a dtensor diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/khd/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/khd/__init__.py new file mode 100644 index 00000000..2b4a99c6 --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/khd/__init__.py @@ -0,0 +1,17 @@ +# Copyright The FMS HF Tuning Authors +# Copyright 2024 Databricks +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Local +from .custom_op import torch_custom_op diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/khd/custom_op.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/khd/custom_op.py new file mode 100644 index 00000000..77efe812 --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/khd/custom_op.py @@ -0,0 +1,47 @@ +# Standard +from typing import Any, Callable, Iterable + +# Third Party +import torch + +try: + # Third Party + from torch.library import custom_op + + _IS_CUSTOM_OP_IN_PYTORCH = True +except: + _IS_CUSTOM_OP_IN_PYTORCH = False + + +class _IdentityOp: + def __init__(self, fn: Callable) -> None: + self.fn = fn + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self.fn(*args, **kwargs) + + def register_fake(self, fn: Callable) -> Callable: + return fn + + +def torch_custom_op( + name: str, + fn: Callable | None = None, + /, + *, + mutates_args: str | Iterable[str], + device_types: torch.device = None, + schema: str | None = None, +) -> Callable | _IdentityOp: + if _IS_CUSTOM_OP_IN_PYTORCH: + op = custom_op( + name, + fn, + mutates_args=mutates_args, + device_types=device_types, + schema=schema, + ) + else: + op = _IdentityOp if fn is None else _IdentityOp(fn) + + return op diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/khd/kernels/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/khd/kernels/__init__.py new file mode 100644 index 00000000..c129198a --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/khd/kernels/__init__.py @@ -0,0 +1,22 @@ +# Copyright The FMS HF Tuning Authors +# Copyright 2024 Databricks +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Local +from .kernels import ( + group_triton_kernel, + groupXtY_triton_kernel, + scatter2scatter_lora_triton_kernel, + scatter2scatter_triton_kernel, +) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/khd/kernels/kernels.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/khd/kernels/kernels.py new file mode 100644 index 00000000..1e713f21 --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/khd/kernels/kernels.py @@ -0,0 +1,454 @@ +# Kernels from Cute Kernels https://github.com/mayank31398/cute-kernels + +# Third Party +import triton +import triton.language as tl + +BLOCK_M = 128 + + +@triton.autotune( + configs=[triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4)], + key=["N", "K"], +) +@triton.jit +def scatter2scatter_triton_kernel( + X_ptr, + stride_xm, + stride_xk, + W_ptr, + stride_we, + stride_wk, + stride_wn, + Y_ptr, + stride_ym, + stride_yn, + grouped_idx_ptr, + expert_idxs_ptr, + block_start_idx_ptr, + FAN_OUT, + M, + K: tl.constexpr, + N: tl.constexpr, + E: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + allow_tf32: tl.constexpr, + x_grouped, + y_grouped, +): + pid = tl.program_id(axis=0) + + N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N) + M_block_id = pid // N_BLOCK_COUNT + N_block_id = pid % N_BLOCK_COUNT + M_range = tl.arange(0, BLOCK_M) + block_start_idx = tl.load(block_start_idx_ptr + M_block_id) + + M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M) + E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E) + E_idx = tl.min(E_idxs) + E_mask = E_idxs == E_idx + M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0) + + if x_grouped: + M_in_idx = M_block + else: + M_in_idx = M_idx // FAN_OUT + + if y_grouped: + M_out_idx = M_block + else: + M_out_idx = M_idx + + K_block = tl.arange(0, BLOCK_K) + + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + + X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + W_blk_ptrs = ( + W_ptr + + K_block[:, None] * stride_wk + + N_block[None, :] * stride_wn + + E_idx * stride_we + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + iters = tl.cdiv(K, BLOCK_K) + + no_k_mask = K % BLOCK_K == 0 + no_n_mask = N % BLOCK_N == 0 + + for K_block_id in range(0, iters): + if no_k_mask: + x = tl.load(X_blk_ptrs, mask=E_mask[:, None]) + + if no_n_mask or K_block_id < (iters - 1): + w = tl.load(W_blk_ptrs) + else: + w = tl.load(W_blk_ptrs, mask=N_mask[None, :]) + else: + K_mask = (K_block_id * BLOCK_K + K_block) < K + x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :]) + w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :]) + + X_blk_ptrs += BLOCK_K * stride_xk + W_blk_ptrs += BLOCK_K * stride_wk + acc += tl.dot(x, w, allow_tf32=allow_tf32, out_dtype=ACC_TYPE) + + Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) + tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) + + 
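+
+
+# Reference semantics of scatter2scatter_triton_kernel (illustrative sketch only;
+# `scatter2scatter_reference` below is a hypothetical helper, not part of this
+# module). With FAN_OUT == 1 and x_grouped == y_grouped == False the kernel
+# computes, for every sorted position p,
+#
+#     Y[sorted_scattered_idxs[p]] = X[sorted_scattered_idxs[p]] @ W[sorted_expert_idxs[p]]
+#
+# which in plain PyTorch would look roughly like:
+#
+#     def scatter2scatter_reference(X, W, sorted_expert_idxs, sorted_scattered_idxs):
+#         # X: (M, K), W: (E, K, N) -> Y: (M, N), one expert per token
+#         Y = torch.empty(X.size(0), W.size(-1), dtype=X.dtype, device=X.device)
+#         for e in range(W.size(0)):
+#             pos = (sorted_expert_idxs == e).nonzero(as_tuple=True)[0]
+#             tok = sorted_scattered_idxs[pos]
+#             Y[tok] = X[tok] @ W[e]
+#         return Y
+#
+# The Triton kernel instead tiles the sorted positions into BLOCK_M chunks and
+# the output dimension into BLOCK_N chunks; within a tile only rows belonging to
+# E_idx = min(E_idxs) are active (E_mask), and the K reduction is accumulated in
+# BLOCK_K steps.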
+@triton.autotune( + configs=[triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4)], + key=["N", "K"], +) +@triton.jit +def scatter2scatter_lora_triton_kernel( + X_ptr, + stride_xm, + stride_xk, + W_ptr, + stride_we, + stride_wk, + stride_wn, + A_ptr, + stride_ae, + stride_ak, + stride_ar, + B_ptr, + stride_be, + stride_br, + stride_bn, + Y_ptr, + stride_ym, + stride_yn, + grouped_idx_ptr, + expert_idxs_ptr, + block_start_idx_ptr, + FAN_OUT: tl.constexpr, + M, + K: tl.constexpr, + N: tl.constexpr, + E: tl.constexpr, + R: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + scaling, + allow_tf32: tl.constexpr, + x_grouped: tl.constexpr, + y_grouped: tl.constexpr, +): + pid = tl.program_id(axis=0) + + # The grid is assumed to be num_padded_blocks * N_BLOCK_COUNT. The input tokens are + # on the M dimension, that is blocked. The first task is to identify which expert + # is being worked on by the kernel instance. + # - block_start_idx_ptr contains offsets that allow the processing to occur in M Blocks + # - one padded block could contain multiple experts, in which case block_start_idx_ptr + # will index multiple times into a single BLOCK_M. + # - e.g., block_start_idx_ptr = [0, 128, 256, 300, 428] + # * there are 300 E_idx = 0 tokens, this requires 3 * 128 blocks to process, at + # offsets 0, 128, 256. + # * the remaining are E_idx = 1 tokens, where the first 128 block starts at 300, + # then goes on to 428, etc. + # - use start index to instantiate the M_block + # - M_block indices the X's worked on by this kernel instance + N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N) + PB_idx = pid // N_BLOCK_COUNT # padded block index + N_block_id = pid % N_BLOCK_COUNT + M_range = tl.arange(0, BLOCK_M) + block_start_idx = tl.load(block_start_idx_ptr + PB_idx) # M block starts from here + M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M) + + # Assumption: expert_idxs_ptr is a sorted list of expert ids. + # - load expert_idxs_ptr into E_idxs + # - construct the E_mask so we operate only on expert for this kernel instance (e.g., E_idx) + # - in cases where the M_block may overlap multiple experts, then tl.min(E_idxs) or + # E_idxs[0] (if it is sorted) can be used to infer the expert being worked on + E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E) + E_idx = tl.min(E_idxs) # if we do this we do not need to index the tensor + E_mask = E_idxs == E_idx + + # Assumption: grouped_idx_ptr puts X in the order as expected by expert_idxs_ptr. 
+ # - same length as expert_idxs_ptr + + # depending on grouped settings, set M_in_idx (input) and M_out_idx (output) appropriately + # - if already grouped, then M_idx is not required and use M_block + if x_grouped: + M_in_idx = M_block + else: + M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0) + M_in_idx = M_idx // FAN_OUT + + if y_grouped: + M_out_idx = M_block + else: + M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0) + M_out_idx = M_idx + + # - K_block for input dimension + K_block = tl.arange(0, BLOCK_K) + + # - N_block for output dimension + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + + # - R range for lora dimension + R_range = tl.arange(0, R) + + # X: dimensions M, K + X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + + # W: dimensions E, K, N + # A: dimensions E, K, lora_r + # B: dimensions E, lora_r, N + W_blk_ptrs = ( + W_ptr + + E_idx * stride_we + + K_block[:, None] * stride_wk + + N_block[None, :] * stride_wn + ) + A_blk_ptrs = ( + A_ptr + + E_idx * stride_ae + + K_block[:, None] * stride_ak + + R_range[None, :] * stride_ar + ) + B_blk_ptrs = ( + B_ptr + + E_idx * stride_be + + N_block[None, :] * stride_bn + + R_range[:, None] * stride_br + ) + + # b can be loaded outside because it has no dependence on input dimension K + b = tl.load(B_blk_ptrs, mask=N_mask[None, :]) + + # for masking + no_k_mask = K % BLOCK_K == 0 + no_n_mask = N % BLOCK_N == 0 + + # accumulate loop over input dimension, for iters number of times + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + iters = tl.cdiv(K, BLOCK_K) + for K_block_id in range(0, iters): + + # - load x, w, a quantities depending on NO_K_MASK or NO_N_MASK + if no_k_mask: + # - if K mask not required + x = tl.load(X_blk_ptrs, mask=E_mask[:, None]) + a = tl.load(A_blk_ptrs) + + if no_n_mask or K_block_id < (iters - 1): + # - if N mask also not reqiured + w = tl.load(W_blk_ptrs) + else: + w = tl.load(W_blk_ptrs, mask=N_mask[None, :]) + else: + # - construct K mask (NO_N_MASK has no effect here) + K_mask = (K_block_id * BLOCK_K + K_block) < K + x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :]) + w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :]) + a = tl.load(A_blk_ptrs, mask=K_mask[:, None]) + + # Y = X * (W + A*B*scaling) + # = X * W + X * A * B * scaling + # - acummulate base layer + acc += tl.dot(x, w, allow_tf32=allow_tf32, out_dtype=ACC_TYPE) + + # - accumulate adapter + # - interim = X * A * scaling + # - interm wil be of dimensions M_block by lora_r + interim = tl.dot(x, a) + interim *= scaling + acc += tl.dot(interim.to(b.dtype), b, allow_tf32=allow_tf32, out_dtype=ACC_TYPE) + + # move pointers in K + # NOTE: b has no dependence on K, so it doesnt need to move + X_blk_ptrs += BLOCK_K * stride_xk + W_blk_ptrs += BLOCK_K * stride_wk + A_blk_ptrs += BLOCK_K * stride_ak + + Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) + tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) + + +@triton.autotune( + configs=[ + # different block M and reducing stages + triton.Config( + {"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 32}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 128}, num_stages=1, num_warps=4 + ), + triton.Config( + {"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 64}, num_stages=2, num_warps=4 + ), + # keep 4 stages and keep two 64 block sizes + # - NOTE: these can get good performances for low M, but for 
large M the variation + # triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64, 'BLOCK_M': 64}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 128, 'BLOCK_M': 64}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 128, 'BLOCK_M': 64}, num_stages=4, num_warps=4), + ], + key=["N", "K"], +) +@triton.jit +def groupXtY_triton_kernel( + DY_ptr, + stride_dym, + stride_dyk, + X_ptr, + stride_xm, + stride_xn, + DW_ptr, + stride_dwe, + stride_dwk, + stride_dwn, + expert_offsets_ptr, + K: tl.constexpr, + N: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + allow_tf32: tl.constexpr, +): + pid0 = tl.program_id(axis=0) + pid1 = tl.program_id(axis=1) + num0 = tl.num_programs(0) + num1 = tl.num_programs(1) + pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4) + + K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K) + E_idx = pid0 // K_BLOCK_COUNT + K_block_id = pid0 % K_BLOCK_COUNT + N_block_id = pid1 + + if E_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32) + + end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32) + + if end_idx > start_idx: + M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M) + + K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K) + K_mask = K_block < K + K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K) + + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N) + + M_idxs = M_block + xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm + dy_blk_ptrs = ( + DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk + ) + + acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE) + iters = tl.cdiv(end_idx - start_idx, BLOCK_M) + + no_k_mask = K % BLOCK_K == 0 + no_n_mask = N % BLOCK_N == 0 + + for i in range(0, iters): + M_mask = (i * BLOCK_M + M_block) < end_idx + + if no_k_mask: + xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :]) + else: + xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :]) + + if no_n_mask: + dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None]) + else: + dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :]) + + xt_blk_ptrs += BLOCK_M * stride_xm + dy_blk_ptrs += BLOCK_M * stride_dym + acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32) + + DW_blk_ptrs = ( + DW_ptr + + E_idx * stride_dwe + + K_block[:, None] * stride_dwk + + N_block[None, :] * stride_dwn + ) + acc = acc.to(DW_blk_ptrs.dtype.element_ty) + tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :]) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=4, num_warps=4) + ], + key=["K"], +) +@triton.jit +def group_triton_kernel( + src_ptr, + stride_sn, + stride_sk, + has_coeff: tl.constexpr, + coeff_ptr, + FAN_OUT, + tgt_ptr, + stride_tn, + stride_ti, + grouped_idx_ptr, + N, + K: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + + N_block_id = pid + N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_blk < N + N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N) + N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0) + + K_blk = tl.arange(0, BLOCK_K) + src_blk_ptrs = ( + src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk + ) + tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + 
K_blk[None, :] * stride_ti + + if has_coeff: + c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None] + + iters = tl.cdiv(K, BLOCK_K) + no_k_mask = K % BLOCK_K == 0 + + for i in range(0, iters): + if no_k_mask or i < iters - 1: + block = tl.load(src_blk_ptrs, mask=N_mask[:, None]) + + if has_coeff: + block *= c + + tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None]) + else: + K_mask = (i * BLOCK_K + K_blk) < K + mask = N_mask[:, None] & K_mask[None, :] + block = tl.load(src_blk_ptrs, mask=mask) + + if has_coeff: + block *= c + + tl.store(tgt_blk_ptrs, block, mask=mask) + + src_blk_ptrs += BLOCK_K * stride_sk + tgt_blk_ptrs += BLOCK_K * stride_ti diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/khd/kernels/ops/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/khd/kernels/ops/__init__.py new file mode 100644 index 00000000..35616fb1 --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/khd/kernels/ops/__init__.py @@ -0,0 +1,310 @@ +# Copyright The FMS HF Tuning Authors +# Copyright 2024 Cute Kernels +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Third Party +import torch + +# Local +from .compileable_ops import compileable_bincount, group, group_bwd_W, scatter2scatter + +BLOCK_M = 128 +torch._dynamo.config.capture_scalar_outputs = True + + +def padded_block_indices( + sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int = BLOCK_M +): + # there is an overhead of launching a custom op so we only use the custom op when compiling + if torch.compiler.is_compiling(): + expert_counts = compileable_bincount(sorted_experts_idxs, k) + else: + expert_counts = sorted_experts_idxs.bincount(minlength=k) + + padded_block_counts = ((expert_counts - 1) // N_BLOCK_SIZE) + 1 + padded_expert_block_end = padded_block_counts.cumsum(-1) + expert_boundaries_end = expert_counts.cumsum(-1) + expert_boundaries_start = expert_boundaries_end - expert_counts + padded_expert_block_start = padded_expert_block_end - padded_block_counts + + block_idxs = torch.arange( + padded_expert_block_end[-1], + dtype=sorted_experts_idxs.dtype, + device=sorted_experts_idxs.device, + ).unsqueeze(1) + + block_mask = (block_idxs < padded_expert_block_start) | ( + block_idxs >= padded_expert_block_end + ) + expanded_block_idxs = ( + N_BLOCK_SIZE * (block_idxs - padded_expert_block_start) + + expert_boundaries_start + ) + expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1) + + return expanded_block_idxs, expert_boundaries_end + + +class _ScatteredExperts(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates=None, + grouped_in=False, + grouped_out=False, + expert_lora_A=None, + expert_lora_B=None, + lora_alp: float = 0.0, + ): + output = torch.empty( + sorted_expert_idxs.size(0), + expert_weights.size(-1), + device=x.device, + dtype=x.dtype, + ) + + scatter2scatter( + X=x, 
+ W=expert_weights, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=output, + FAN_OUT=k, + x_grouped=grouped_in, + y_grouped=grouped_out, + A=expert_lora_A, + B=expert_lora_B, + lora_alp=lora_alp, + ) + + _extra_tensors_to_save = () + if lora_alp > 0 and expert_lora_A is not None and expert_lora_B is not None: + _extra_tensors_to_save = (expert_lora_A, expert_lora_B) + + # save some extra context + ctx.lora_r = expert_lora_A.size(-1) + ctx.lora_alp = lora_alp + + if gates is None: + output_expanded = None + else: + output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1)) + output = torch.bmm(gates.unsqueeze(1), output_expanded).squeeze(1) + + ctx.save_for_backward( + x, + expert_weights, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates, + output_expanded, + *_extra_tensors_to_save, + ) + + ctx.grouped_in = grouped_in + ctx.grouped_out = grouped_out + ctx.k = k + + return output + + @staticmethod + def backward(ctx, grad_out): + ( + x, + expert_weights, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates, + output_expanded, + *_extra_saved_tensors, + ) = ctx.saved_tensors + k = ctx.k + grouped_in = ctx.grouped_in + grouped_out = ctx.grouped_out + + use_lora = False + if hasattr(ctx, "lora_r"): + lora_r = ctx.lora_r + lora_alp = ctx.lora_alp + expert_lora_A, expert_lora_B = _extra_saved_tensors + use_lora = True + + if gates is None: + d_gates = None + gates_flat = None + gate_fan = 1 + # grouped_grad_out = None + else: + # calculate gates gradient + d_gates = torch.bmm(output_expanded, grad_out.unsqueeze(2)).squeeze(-1) + gates_flat = gates.flatten() + gate_fan = gates.size(1) + # print("expanded and grouping") + # grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later + + if grouped_out: + grouped_grad_out = grad_out + else: + grouped_grad_out = torch.zeros( + (grad_out.shape[0] * gate_fan, grad_out.shape[1]), + dtype=grad_out.dtype, + device=grad_out.device, + ) + group( + A=grad_out, + sorted_expert_idxs=sorted_scattered_idxs, + out=grouped_grad_out, + coeff=gates_flat, + fan_out=gate_fan, + ) + + if grouped_in: + grouped_x = x + d_expanded_input = torch.empty( + sorted_expert_idxs.size(0), + expert_weights.size(1), + device=x.device, + dtype=x.dtype, + ) + else: + grouped_x = torch.empty( + sorted_scattered_idxs.size(0), x.size(1), dtype=x.dtype, device=x.device + ) + group( + A=x, + sorted_expert_idxs=sorted_scattered_idxs, + out=grouped_x, + fan_out=k, + ) + + d_expanded_input = grouped_x + + d_weights = torch.zeros( + expert_weights.size(0), + grouped_grad_out.size(-1), + grouped_x.size(-1), + device=grouped_grad_out.device, + dtype=grouped_grad_out.dtype, + ).permute(0, 2, 1) + + group_bwd_W( + DY=grouped_grad_out, + X=grouped_x, + expert_offsets=expert_offsets, + DW=d_weights, + E=expert_weights.size(0), + ) + + _extra_scatter_kwargs = {} + _extra_grads_to_return = (None, None) + if use_lora: + d_weights_A = ( + d_weights @ expert_lora_B.permute(0, 2, 1) * (lora_alp / lora_r) + ) + d_weights_B = ( + expert_lora_A.permute(0, 2, 1) @ d_weights * (lora_alp / lora_r) + ) + d_weights = None # zero it + + _extra_scatter_kwargs = { + "A": expert_lora_B.permute(0, 2, 1), # B^T + "B": expert_lora_A.permute(0, 2, 1), # A^T + "lora_alp": lora_alp, + } + _extra_grads_to_return = (d_weights_A, d_weights_B) + + scatter2scatter( + X=grouped_grad_out, + W=expert_weights.permute(0, 2, 
1), + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=d_expanded_input, + FAN_OUT=1, + x_grouped=True, + y_grouped=grouped_in, + **_extra_scatter_kwargs, + ) + + if k == 1: + d_input = d_expanded_input + else: + d_input = d_expanded_input.view( + x.size(0), k, d_expanded_input.size(-1) + ).sum(-2) + + # print("backward end.") + return ( + # x, expert_weights, k, + d_input, + d_weights, + None, + # sorted_expert_idxs, sorted_scattered_idxs, + None, + None, + # padded_block_idxs, expert_offsets, + None, + None, + # gates + d_gates, + None, + None, + # adapter stuff + *_extra_grads_to_return, + None, + ) + + +def scattered_experts( + inputs, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates=None, + grouped_in=False, + grouped_out=False, + expert_lora_A=None, + expert_lora_B=None, + lora_alp: float = 0.0, +): + return _ScatteredExperts.apply( + inputs, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates, + grouped_in, + grouped_out, + expert_lora_A, + expert_lora_B, + lora_alp, + ) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/khd/kernels/ops/compileable_ops.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/khd/kernels/ops/compileable_ops.py new file mode 100644 index 00000000..4243057a --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/khd/kernels/ops/compileable_ops.py @@ -0,0 +1,351 @@ +# Compileable Ops from Cute Kernels https://github.com/mayank31398/cute-kernels + +# Third Party +import torch +import triton +import triton.language as tl + +# Local +from ...custom_op import torch_custom_op +from ..kernels import ( + group_triton_kernel, + groupXtY_triton_kernel, + scatter2scatter_lora_triton_kernel, + scatter2scatter_triton_kernel, +) + +LIBRARY_NAME = "khd" +BLOCK_M = 128 +torch._dynamo.config.capture_scalar_outputs = True + + +# bincount is not compilable +@torch_custom_op(f"{LIBRARY_NAME}::bincount", mutates_args={}) +def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor: + return x.bincount(minlength=minlength) + + +@compileable_bincount.register_fake +def _(x: torch.Tensor, minlength: int) -> torch.Tensor: + return torch.empty(minlength, dtype=torch.long, device=x.device) + + +def _scatter2scatter( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + padded_block_idxs: torch.Tensor, + out: torch.Tensor, + FAN_OUT: int, + x_grouped: bool = False, + y_grouped: bool = False, + A: torch.Tensor = None, + B: torch.Tensor = None, + lora_alp: float = 0.0, +) -> None: + assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) + assert sorted_scattered_idxs.size(0) == X.size(0) * FAN_OUT + assert out.size(0) == sorted_expert_idxs.size(0) + assert out.size(1) == W.size(-1) + + grid = lambda meta: ( + padded_block_idxs.size(0) * triton.cdiv(meta["N"], meta["BLOCK_N"]), + ) + + # dispatch happens here + if A is not None and B is not None: + assert W.size(1) == A.size(1), "A has incorrect input size." + assert W.size(2) == B.size(2), "B has incorrect output size." + assert A.size(2) == B.size(1), "A and B have inconsistent inner dims." + + # This is the lora enabled version of scatter2scatter, where alongisde the weights + # W, we take in the adapters A and B. 
+ # - in lora adaption the combined weights are W + A*B*scaling + # - scaling is typically lora_alp / lora_r + scatter2scatter_lora_triton_kernel[grid]( + # X_ptr, stride_xm, stride_xk, + X, + X.stride(0), + X.stride(1), + # W_ptr, stride_we, stride_wk, stride_wn, + W, + W.stride(0), + W.stride(1), + W.stride(2), + # A_ptr, stride_ae, stride_ak, stride_ar, + A, + A.stride(0), + A.stride(1), + A.stride(2), + # B_ptr, stride_be, stride_br, stride_bn, + B, + B.stride(0), + B.stride(1), + B.stride(2), + # Y_ptr, stride_ym, stride_yn, + out, + out.stride(0), + out.stride(1), + grouped_idx_ptr=sorted_scattered_idxs, + expert_idxs_ptr=sorted_expert_idxs, + block_start_idx_ptr=padded_block_idxs, + FAN_OUT=FAN_OUT, + M=X.size(0), + K=X.size(1), + N=out.size(1), + E=W.size(0), + R=A.size(2), + BLOCK_M=BLOCK_M, + ACC_TYPE=tl.float32, + scaling=(lora_alp / A.size(2)), + allow_tf32=True, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + return + + scatter2scatter_triton_kernel[grid]( + # X_ptr, stride_xm, stride_xk, + X, + X.stride(0), + X.stride(1), + # W_ptr, stride_we, stride_wk, stride_wn, + W, + W.stride(0), + W.stride(1), + W.stride(2), + # Y_ptr, stride_ym, stride_yn, + out, + out.stride(0), + out.stride(1), + grouped_idx_ptr=sorted_scattered_idxs, + expert_idxs_ptr=sorted_expert_idxs, + block_start_idx_ptr=padded_block_idxs, + FAN_OUT=FAN_OUT, + M=X.size(0), + K=X.size(1), + N=out.size(1), + E=W.size(0), + BLOCK_M=BLOCK_M, + ACC_TYPE=tl.float32, + allow_tf32=torch.backends.cudnn.allow_tf32, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + + +# custom op is needed because of https://github.com/pytorch/pytorch/issues/136394 +@torch_custom_op(f"{LIBRARY_NAME}::scatter2scatter", mutates_args={"out"}) +def _scatter2scatter_compileable( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + padded_block_idxs: torch.Tensor, + out: torch.Tensor, + FAN_OUT: int, + x_grouped: bool = False, + y_grouped: bool = False, + A: torch.Tensor = None, + B: torch.Tensor = None, + lora_alp: float = 0.0, +) -> None: + _scatter2scatter( + X=X, + W=W, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=out, + FAN_OUT=FAN_OUT, + x_grouped=x_grouped, + y_grouped=y_grouped, + A=A, + B=B, + lora_alp=lora_alp, + ) + + +def scatter2scatter( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + padded_block_idxs: torch.Tensor, + out: torch.Tensor, + FAN_OUT: int, + x_grouped: bool = False, + y_grouped: bool = False, + A: torch.Tensor = None, + B: torch.Tensor = None, + lora_alp: float = 0.0, +) -> None: + if torch.compiler.is_compiling(): + _scatter2scatter_compileable( + X=X, + W=W, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=out, + FAN_OUT=FAN_OUT, + x_grouped=x_grouped, + y_grouped=y_grouped, + A=A, + B=B, + lora_alp=lora_alp, + ) + else: + _scatter2scatter( + X=X, + W=W, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=out, + FAN_OUT=FAN_OUT, + x_grouped=x_grouped, + y_grouped=y_grouped, + A=A, + B=B, + lora_alp=lora_alp, + ) + + +def _group_bwd_W( + DY: torch.Tensor, + X: torch.Tensor, + expert_offsets: torch.Tensor, + DW: torch.Tensor, + E: int, +) -> None: + grid = lambda meta: ( + E * triton.cdiv(meta["K"], meta["BLOCK_K"]), + 
triton.cdiv(meta["N"], meta["BLOCK_N"]), + ) + + groupXtY_triton_kernel[grid]( + # DY_ptr, stride_dym, stride_dyk, + DY, + DY.stride(0), + DY.stride(1), + # X_ptr, stride_xm, stride_xn, + X, + X.stride(0), + X.stride(1), + # DW_ptr, stride_dwe, stride_dwk, stride_dwn, + DW, + DW.stride(0), + DW.stride(1), + DW.stride(2), + # expert_offsets_ptr, + expert_offsets, + # K: tl.constexpr, N: tl.constexpr, + N=DY.size(-1), + K=X.size(-1), + # ACC_TYPE: tl.constexpr, + ACC_TYPE=tl.float32, + allow_tf32=torch.backends.cudnn.allow_tf32, + ) + + +# custom op is needed because of https://github.com/pytorch/pytorch/issues/136394 +@torch_custom_op(f"{LIBRARY_NAME}::group_bwd_W", mutates_args={"DW"}) +def _group_bwd_W_compileable( + DY: torch.Tensor, + X: torch.Tensor, + expert_offsets: torch.Tensor, + DW: torch.Tensor, + E: int, +) -> None: + _group_bwd_W(DY=DY, X=X, expert_offsets=expert_offsets, DW=DW, E=E) + + +def group_bwd_W( + DY: torch.Tensor, + X: torch.Tensor, + expert_offsets: torch.Tensor, + DW: torch.Tensor, + E: int, +) -> None: + if torch.compiler.is_compiling(): + _group_bwd_W_compileable(DY=DY, X=X, expert_offsets=expert_offsets, DW=DW, E=E) + else: + _group_bwd_W(DY=DY, X=X, expert_offsets=expert_offsets, DW=DW, E=E) + + +def _group( + A: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + out: torch.Tensor, + coeff: torch.Tensor | None = None, + fan_out: int = 1, +) -> None: + N = sorted_expert_idxs.size(0) + K = A.size(1) + assert A.size(0) * fan_out == N + + grid = lambda meta: (triton.cdiv(meta["N"], meta["BLOCK_N"]),) + + group_triton_kernel[grid]( + # A_ptr, stride_an, stride_ai, + A, + A.stride(0), + A.stride(1), + coeff is not None, + coeff, + fan_out, + # Y_ptr, stride_yn, stride_yk, + out, + out.stride(0), + out.stride(1), + # grouped_idx_ptr, + sorted_expert_idxs, + # N: tl.constexpr, K: tl.constexpr, + N, + K, + ) + + +# custom op is needed because of https://github.com/pytorch/pytorch/issues/136394 +@torch_custom_op(f"{LIBRARY_NAME}::group", mutates_args={"out"}) +def _group_compileable( + A: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + out: torch.Tensor, + coeff: torch.Tensor | None = None, + fan_out: int = 1, +) -> None: + _group( + A=A, + sorted_expert_idxs=sorted_expert_idxs, + out=out, + coeff=coeff, + fan_out=fan_out, + ) + + +def group( + A: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + out: torch.Tensor, + coeff: torch.Tensor | None = None, + fan_out: int = 1, +) -> None: + if torch.compiler.is_compiling(): + _group_compileable( + A=A, + sorted_expert_idxs=sorted_expert_idxs, + out=out, + coeff=coeff, + fan_out=fan_out, + ) + else: + _group( + A=A, + sorted_expert_idxs=sorted_expert_idxs, + out=out, + coeff=coeff, + fan_out=fan_out, + )
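+
+
+# Dispatch pattern used throughout this module (descriptive note; `foo` is a
+# placeholder for scatter2scatter / group_bwd_W / group): each Triton launcher
+# `_foo` has a `torch_custom_op`-wrapped twin `_foo_compileable`, and the public
+# wrapper picks between them at call time, e.g.
+#
+#     def foo(*args, **kwargs) -> None:
+#         if torch.compiler.is_compiling():
+#             _foo_compileable(*args, **kwargs)  # graph-safe custom op
+#         else:
+#             _foo(*args, **kwargs)              # direct launch, no custom-op overhead
+#
+# This keeps eager execution free of custom-op dispatch overhead while letting
+# torch.compile trace the mutation of the `out` / `DW` buffers. When
+# torch.library.custom_op is unavailable, torch_custom_op (see ../../custom_op.py)
+# degrades to an identity wrapper and both paths launch the kernel directly.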