(file path not shown)

@@ -23,8 +23,7 @@
 import torch

 # Local
-from .models.utils import filter_mp_rules
-from .utils import lora_adapters_switch_ddp_from_fsdp
+from .utils import filter_mp_rules, lora_adapters_switch_ddp_from_fsdp


 # consider rewriting register_foak_model_patch_rules into something
(file path not shown)

@@ -30,12 +30,12 @@
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
+from ..utils import filter_mp_rules
 from .utils import (
     KEY_MLP,
     KEY_O,
     KEY_QKV,
     build_lora_fused_ops,
-    filter_mp_rules,
     get_hidden_activation_fn_key,
     trigger_fused_ops,
 )
(file path not shown)

@@ -36,12 +36,12 @@
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
+from ..utils import filter_mp_rules
 from .utils import (
     KEY_MLP,
     KEY_O,
     KEY_QKV,
     build_lora_fused_ops,
-    filter_mp_rules,
     get_hidden_activation_fn_key,
     trigger_fused_ops,
 )
(file path not shown)

@@ -36,12 +36,12 @@
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
+from ..utils import filter_mp_rules
 from .utils import (
     KEY_MLP,
     KEY_O,
     KEY_QKV,
     build_lora_fused_ops,
-    filter_mp_rules,
     get_hidden_activation_fn_key,
     trigger_fused_ops,
 )
(file path not shown)

@@ -1,10 +1,10 @@
 # Standard
 from functools import partial
-from typing import Callable, List, Set, Type
+from typing import Callable, List, Type
 import os

 # Third Party
-from fms_acceleration.model_patcher import ModelPatcherRule, ModelPatcherTrigger
+from fms_acceleration.model_patcher import ModelPatcherTrigger
 from transformers import PretrainedConfig
 import torch

@@ -203,22 +203,6 @@ def trigger_fused_ops(
     return isinstance(module, attn_cls) and all(_is_loralayer(x) for x in _mods)


-# helper function to filter rules
-def filter_mp_rules(
-    rules: List[ModelPatcherRule],
-    filter_endswith: Set[str],
-    drop: bool = False,
-):
-    if drop:
-        # this means if any of the filter terms appear, we drop
-        return [
-            r for r in rules if not any(r.rule_id.endswith(x) for x in filter_endswith)
-        ]
-
-    # this means if any if the filter terms appear, we keep
-    return [r for r in rules if any(r.rule_id.endswith(x) for x in filter_endswith)]
-
-
 # helper function to get the hidden activation function str
 def get_hidden_activation_fn_key(config: PretrainedConfig):
     for key in KEY_HIDDEN_ACTIVATIONS:
20 changes: 20 additions & 0 deletions plugins/fused-ops-and-kernels/src/fms_acceleration_foak/utils.py
@@ -12,8 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+# Standard
+from typing import List, Set
+
 # Third Party
 from accelerate.utils import set_module_tensor_to_device
+from fms_acceleration.model_patcher import ModelPatcherRule
 from transformers.modeling_utils import is_fsdp_enabled
 import torch
 import torch.distributed as dist
@@ -74,3 +78,19 @@ def _all_reduce_hook(grad):
     # - this has to be done after all weight replacement happens
     A.weight.register_hook(_all_reduce_hook)
     B.weight.register_hook(_all_reduce_hook)
+
+
+# helper function to filter rules
+def filter_mp_rules(
+    rules: List[ModelPatcherRule],
+    filter_endswith: Set[str],
+    drop: bool = False,
+):
+    if drop:
+        # this means if any of the filter terms appear, we drop
+        return [
+            r for r in rules if not any(r.rule_id.endswith(x) for x in filter_endswith)
+        ]
+
+    # this means if any of the filter terms appear, we keep
+    return [r for r in rules if any(r.rule_id.endswith(x) for x in filter_endswith)]
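
Not part of the diff itself — a minimal usage sketch of the relocated helper. Since `filter_mp_rules` only reads each rule's `rule_id`, the sketch below substitutes a hypothetical stand-in class for `ModelPatcherRule` (whose constructor is not shown in this diff), and the rule ids are made up for illustration.

```python
# Illustrative only: a duck-typed stand-in is enough because filter_mp_rules
# touches nothing but the rule_id attribute of each rule.
from dataclasses import dataclass

from fms_acceleration_foak.utils import filter_mp_rules


@dataclass
class FakeRule:  # hypothetical stand-in for ModelPatcherRule
    rule_id: str


rules = [FakeRule("model-qkv"), FakeRule("model-mlp"), FakeRule("model-cross-ent")]

# default (drop=False): keep only rules whose rule_id ends with a filter term
kept = filter_mp_rules(rules, filter_endswith={"qkv", "mlp"})
assert [r.rule_id for r in kept] == ["model-qkv", "model-mlp"]

# drop=True: remove rules whose rule_id ends with a filter term
rest = filter_mp_rules(rules, filter_endswith={"qkv", "mlp"}, drop=True)
assert [r.rule_id for r in rest] == ["model-cross-ent"]
```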
2 changes: 1 addition & 1 deletion plugins/fused-ops-and-kernels/tests/test_model_utils.py
@@ -2,7 +2,7 @@
 from fms_acceleration.model_patcher import ModelPatcherRule

 # First Party
-from fms_acceleration_foak.models.utils import filter_mp_rules
+from fms_acceleration_foak.utils import filter_mp_rules


 def test_filter_mp_rules():