Commit bbd08be

formatted foak

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>

1 parent: d5546d7

File tree

6 files changed: +38 additions, -26 deletions


plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py

Lines changed: 16 additions & 3 deletions

@@ -24,6 +24,7 @@
 import torch
 import torch.distributed as dist
 
+
 # consider moving this somewhere else later
 def lora_adapters_switch_ddp_from_fsdp(modules, fsdp_plugin):
     """
@@ -58,9 +59,20 @@ def _all_reduce_hook(grad):
         if not B.weight.is_cuda:
             set_module_tensor_to_device(B, "weight", "cuda")
 
+
 def register_foak_model_patch_rules(base_type):
-    from fms_acceleration.model_patcher import ModelPatcher # pylint: disable=import-outside-toplevel
-    from .models import llama, mistral, mixtral # pylint: disable=import-outside-toplevel
+    # Third Party
+    from fms_acceleration.model_patcher import ( # pylint: disable=import-outside-toplevel
+        ModelPatcher,
+    )
+
+    # Local
+    from .models import ( # pylint: disable=import-outside-toplevel
+        llama,
+        mistral,
+        mixtral,
+    )
+
     rules = [
         *llama.get_mp_rules(base_type),
         *mistral.get_mp_rules(base_type),
@@ -69,6 +81,7 @@ def register_foak_model_patch_rules(base_type):
     for _rule in rules:
         ModelPatcher.register(_rule)
 
+
 class FastQuantizedPeftAccelerationPlugin(AccelerationPlugin):
 
     # NOTE: may remove this when we have generic model rules
@@ -122,7 +135,7 @@ def augmentation(
         ), "need to run in fp16 mixed precision or load model in fp16"
 
         # wrapper function to register foak patches
-        register_foak_model_patch_rules(base_type = self._base_layer)
+        register_foak_model_patch_rules(base_type=self._base_layer)
         return model, modifiable_args
 
     def get_callbacks_and_ready_for_train(
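
The one functional-looking touch outside the import blocks is the keyword-argument respacing in `augmentation`: PEP 8 wants no spaces around `=` for keyword arguments, and Black enforces exactly this. A minimal sketch reproducing the fix through Black's Python API (an assumption for illustration; the commit may just as well have come from running the `black` CLI):

import black

# The input string is lifted from the hunk above; as a bare expression
# statement, `self` parses as an ordinary name, so format_str accepts it.
before = "register_foak_model_patch_rules(base_type = self._base_layer)\n"
after = black.format_str(before, mode=black.Mode())
print(after, end="")  # register_foak_model_patch_rules(base_type=self._base_layer)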

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py

Lines changed: 7 additions & 6 deletions

@@ -16,24 +16,25 @@
 from functools import partial
 
 # Third Party
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaMLP,
-    LlamaRMSNorm,
-)
 from fms_acceleration.model_patcher import (
     ModelPatcherRule,
     ModelPatcherTrigger,
     combine_functions,
     combine_triggers,
 )
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaMLP,
+    LlamaRMSNorm,
+)
 
 # Local
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
 from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
 
+
 def get_mp_rules(base_type: str):
     """
     Function to access all patch rules in this module.
@@ -125,5 +126,5 @@ def get_mp_rules(base_type: str):
                fast_rope_embedding,
                None,
            ),
-        )
+        ),
    ]
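
The single-character change at the tail of this file (repeated in mistral.py and mixtral.py below) gives the last `ModelPatcherRule(...)` entry a trailing comma, the "magic trailing comma" style that Black keeps on multi-line literals. The payoff is diff hygiene: once every entry is comma-terminated, appending or reordering rules touches only the lines of the rule itself. A self-contained sketch of the shape, with tuples standing in for the real rule objects:

# Hypothetical stand-in for the list built in get_mp_rules: every entry
# ends with a comma, so adding a fourth rule is a one-line future diff.
rules = [
    ("llama-rms", "fast_rms_layernorm"),
    ("llama-qkv", "fused_op_qkv"),
    ("llama-rope", "fast_rope_embedding"),  # the comma this commit adds
]
print(len(rules))  # 3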

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py

Lines changed: 7 additions & 7 deletions

@@ -16,25 +16,25 @@
 from functools import partial
 
 # Third Party
-from transformers.models.mistral.modeling_mistral import (
-    MistralAttention,
-    MistralMLP,
-    MistralRMSNorm,
-)
 from fms_acceleration.model_patcher import (
     ModelPatcherRule,
     ModelPatcherTrigger,
     combine_functions,
     combine_triggers,
 )
-
+from transformers.models.mistral.modeling_mistral import (
+    MistralAttention,
+    MistralMLP,
+    MistralRMSNorm,
+)
 
 # Local
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
 from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
 
+
 def get_mp_rules(base_type):
     """
     Function to access all patch rules in this module.
@@ -117,5 +117,5 @@ def get_mp_rules(base_type):
                fast_rope_embedding,
                None,
            ),
-        )
+        ),
    ]

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py

Lines changed: 6 additions & 6 deletions

@@ -16,24 +16,24 @@
 from functools import partial
 
 # Third Party
-from transformers.models.mixtral.modeling_mixtral import (
-    MixtralAttention,
-    MixtralRMSNorm,
-)
 from fms_acceleration.model_patcher import (
     ModelPatcherRule,
     ModelPatcherTrigger,
     combine_functions,
     combine_triggers,
 )
+from transformers.models.mixtral.modeling_mixtral import (
+    MixtralAttention,
+    MixtralRMSNorm,
+)
 
 # Local
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
-
 from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
 
+
 def get_mp_rules(base_type):
     """
     Function to access all patch rules in this module.
@@ -100,5 +100,5 @@ def get_mp_rules(base_type):
                fast_rope_embedding,
                None,
            ),
-        )
+        ),
    ]

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py

Lines changed: 1 addition & 1 deletion

@@ -4,6 +4,7 @@
 import os
 
 # Third Party
+from fms_acceleration.model_patcher import ModelPatcherTrigger
 import torch
 
 # Local
@@ -16,7 +17,6 @@
 from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_mlp as fused_op_mlp_gptq
 from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_o_v2 as fused_op_o_gptq
 from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_qkv as fused_op_qkv_gptq
-from fms_acceleration.model_patcher import ModelPatcherTrigger
 
 KEY_QKV = "qkv"
 KEY_O = "o"
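
The import regrouping in this file is the clearest view of the isort conventions used throughout the commit: comment-headed sections (`# Third Party`, `# Local`) with plain `import` statements sorted together with `from` imports inside each section, which is why `from fms_acceleration...` now sits above `import torch`. A hedged sketch via isort's Python API; the options shown are an assumption about the project's configuration, which this commit does not include:

import isort

source = (
    "import torch\n"
    "from fms_acceleration.model_patcher import ModelPatcherTrigger\n"
)

# import_heading_thirdparty adds the "# Third Party" comment over that
# section; force_sort_within_sections interleaves `import x` and
# `from x import y` lines by module name instead of keeping them apart.
print(
    isort.code(
        source,
        import_heading_thirdparty="Third Party",
        import_heading_localfolder="Local",
        force_sort_within_sections=True,
    )
)
# Expected output:
# # Third Party
# from fms_acceleration.model_patcher import ModelPatcherTrigger
# import torch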

plugins/fused-ops-and-kernels/tests/test_fused_ops.py

Lines changed: 1 addition & 3 deletions

@@ -3,16 +3,14 @@
 from itertools import product
 
 # Third Party
+from fms_acceleration.model_patcher import patch_model
 from peft import LoraConfig
 from transformers import AutoConfig
 from transformers.models.llama.modeling_llama import LlamaAttention
 from transformers.utils.import_utils import _is_package_available
 import pytest # pylint: disable=import-error
 import torch
 
-# First Party
-from fms_acceleration.model_patcher import patch_model
-
 BNB = "bitsandbytes"
 GPTQ = "auto_gptq"
 