
Commit bfb8a8f

run formatter

Signed-off-by: Anh Uong <[email protected]>

1 parent 168f170 commit bfb8a8f

7 files changed: +92 -78 lines changed

plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/multipack_sampler.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
-taken from https://github.com/imoneoi/multipack_sampler with some modifications
+taken from https://github.com/imoneoi/multipack_sampler with some modifications
 taken from https://github.com/instructlab/training/blob/main/src/instructlab/training/multipack_sampler.py
 """


plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py

Lines changed: 19 additions & 17 deletions
@@ -40,8 +40,8 @@
     KEY_QKV,
     build_lora_fused_ops,
     get_hidden_activation_fn_key,
-    trigger_fused_ops,
     get_transformers_version,
+    trigger_fused_ops,
 )


@@ -127,22 +127,24 @@ def get_mp_rules(base_type: str, config: PretrainedConfig = None):
             ),
         ),
         *[
-            ModelPatcherRule(
-                rule_id="granite-custom-loss",
-                trigger=ModelPatcherTrigger(
-                    check=replace_custom_loss_when_triggered(
-                        GraniteForCausalLM, custom_loss_type="granite-custom-loss"
-                    )
-                ),
-            )
-            if get_transformers_version() >= "4.46" else
-            ModelPatcherRule(
-                rule_id="granite-cross-ent",
-                import_and_maybe_reload=(
-                    "torch.nn.CrossEntropyLoss",
-                    FastCrossEntropyLoss,
-                    "transformers.models.granite.modeling_granite",
-                ),
+            (
+                ModelPatcherRule(
+                    rule_id="granite-custom-loss",
+                    trigger=ModelPatcherTrigger(
+                        check=replace_custom_loss_when_triggered(
+                            GraniteForCausalLM, custom_loss_type="granite-custom-loss"
+                        )
+                    ),
+                )
+                if get_transformers_version() >= "4.46"
+                else ModelPatcherRule(
+                    rule_id="granite-cross-ent",
+                    import_and_maybe_reload=(
+                        "torch.nn.CrossEntropyLoss",
+                        FastCrossEntropyLoss,
+                        "transformers.models.granite.modeling_granite",
+                    ),
+                )
             )
         ],
         ModelPatcherRule(
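
This hunk is formatting only: black wraps the conditional expression that picks one patch rule for the installed transformers version in parentheses and moves the if/else onto their own lines. A minimal sketch of the same pattern, with Rule and transformers_version as illustrative stand-ins for the plugin's ModelPatcherRule and get_transformers_version:

# Illustrative stand-ins; the real code uses ModelPatcherRule and
# get_transformers_version from fms_acceleration_foak.
from dataclasses import dataclass


@dataclass
class Rule:
    rule_id: str


def transformers_version() -> str:
    # assume an installed version string such as "4.46.1"
    return "4.46.1"


rules = [
    Rule(rule_id="granite-rope"),
    # *[...] keeps the list flat while choosing exactly one of two rules;
    # the wrapping parentheses are black's layout for a multi-line
    # conditional expression, as in the diff above
    *[
        (
            Rule(rule_id="granite-custom-loss")
            if transformers_version() >= "4.46"
            else Rule(rule_id="granite-cross-ent")
        )
    ],
]

The gate compares raw version strings, exactly as the original code does; if ordering across arbitrary versions mattered, packaging.version.parse would be the safer comparison.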

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py

Lines changed: 19 additions & 17 deletions
@@ -46,8 +46,8 @@
     KEY_QKV,
     build_lora_fused_ops,
     get_hidden_activation_fn_key,
-    trigger_fused_ops,
     get_transformers_version,
+    trigger_fused_ops,
 )


@@ -127,22 +127,24 @@ def get_mp_rules(base_type: str, config: PretrainedConfig = None):
             forward=lce_forward,
         ),
         *[
-            ModelPatcherRule(
-                rule_id="llama-custom-loss",
-                trigger=ModelPatcherTrigger(
-                    check=replace_custom_loss_when_triggered(
-                        LlamaForCausalLM, custom_loss_type="llama-custom-loss"
-                    )
-                ),
-            )
-            if get_transformers_version() >= "4.46" else
-            ModelPatcherRule(
-                rule_id="llama-cross-ent",
-                import_and_maybe_reload=(
-                    "torch.nn.CrossEntropyLoss",
-                    FastCrossEntropyLoss,
-                    "transformers.models.llama.modeling_llama",
-                ),
+            (
+                ModelPatcherRule(
+                    rule_id="llama-custom-loss",
+                    trigger=ModelPatcherTrigger(
+                        check=replace_custom_loss_when_triggered(
+                            LlamaForCausalLM, custom_loss_type="llama-custom-loss"
+                        )
+                    ),
+                )
+                if get_transformers_version() >= "4.46"
+                else ModelPatcherRule(
+                    rule_id="llama-cross-ent",
+                    import_and_maybe_reload=(
+                        "torch.nn.CrossEntropyLoss",
+                        FastCrossEntropyLoss,
+                        "transformers.models.llama.modeling_llama",
+                    ),
+                )
             )
         ],
         # TODO: have a generic version of this rule

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py

Lines changed: 19 additions & 17 deletions
@@ -46,8 +46,8 @@
     KEY_QKV,
     build_lora_fused_ops,
     get_hidden_activation_fn_key,
-    trigger_fused_ops,
     get_transformers_version,
+    trigger_fused_ops,
 )


@@ -119,22 +119,24 @@ def get_mp_rules(base_type: str, config: PretrainedConfig = None):
             ),
         ),
         *[
-            ModelPatcherRule(
-                rule_id="mistral-custom-loss",
-                trigger=ModelPatcherTrigger(
-                    check=replace_custom_loss_when_triggered(
-                        MistralForCausalLM, custom_loss_type="mistral-custom-loss"
-                    )
-                ),
-            )
-            if get_transformers_version() >= "4.46" else
-            ModelPatcherRule(
-                rule_id="mistral-cross-ent",
-                import_and_maybe_reload=(
-                    "torch.nn.CrossEntropyLoss",
-                    FastCrossEntropyLoss,
-                    "transformers.models.mistral.modeling_mistral",
-                ),
+            (
+                ModelPatcherRule(
+                    rule_id="mistral-custom-loss",
+                    trigger=ModelPatcherTrigger(
+                        check=replace_custom_loss_when_triggered(
+                            MistralForCausalLM, custom_loss_type="mistral-custom-loss"
+                        )
+                    ),
+                )
+                if get_transformers_version() >= "4.46"
+                else ModelPatcherRule(
+                    rule_id="mistral-cross-ent",
+                    import_and_maybe_reload=(
+                        "torch.nn.CrossEntropyLoss",
+                        FastCrossEntropyLoss,
+                        "transformers.models.mistral.modeling_mistral",
+                    ),
+                )
             )
         ],
         ModelPatcherRule(

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py

Lines changed: 26 additions & 18 deletions
@@ -24,8 +24,8 @@
 )
 from transformers.models.mixtral.modeling_mixtral import (
     MixtralAttention,
-    MixtralRMSNorm,
     MixtralForCausalLM,
+    MixtralRMSNorm,
 )

 # Local
@@ -35,7 +35,13 @@
 )
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
-from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops, get_transformers_version
+from .utils import (
+    KEY_O,
+    KEY_QKV,
+    build_lora_fused_ops,
+    get_transformers_version,
+    trigger_fused_ops,
+)


 def get_mp_rules(base_type):
@@ -90,22 +96,24 @@ def get_mp_rules(base_type):
             ),
         ),
         *[
-            ModelPatcherRule(
-                rule_id="mixtral-custom-loss",
-                trigger=ModelPatcherTrigger(
-                    check=replace_custom_loss_when_triggered(
-                        MixtralForCausalLM, custom_loss_type="mixtral-custom-loss"
-                    )
-                ),
-            )
-            if get_transformers_version() >= "4.46" else
-            ModelPatcherRule(
-                rule_id="mixtral-cross-ent",
-                import_and_maybe_reload=(
-                    "torch.nn.CrossEntropyLoss",
-                    FastCrossEntropyLoss,
-                    "transformers.models.mixtral.modeling_mixtral",
-                ),
+            (
+                ModelPatcherRule(
+                    rule_id="mixtral-custom-loss",
+                    trigger=ModelPatcherTrigger(
+                        check=replace_custom_loss_when_triggered(
+                            MixtralForCausalLM, custom_loss_type="mixtral-custom-loss"
+                        )
+                    ),
+                )
+                if get_transformers_version() >= "4.46"
+                else ModelPatcherRule(
+                    rule_id="mixtral-cross-ent",
+                    import_and_maybe_reload=(
+                        "torch.nn.CrossEntropyLoss",
+                        FastCrossEntropyLoss,
+                        "transformers.models.mixtral.modeling_mixtral",
+                    ),
+                )
             )
         ],
         ModelPatcherRule(

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py

Lines changed: 5 additions & 2 deletions
@@ -216,6 +216,9 @@ def get_hidden_activation_fn_key(config: PretrainedConfig):
         f"architecture {config.architectures}."
     )

+
 def get_transformers_version():
-    _, _transformers_version = _is_package_available("transformers", return_version=True)
-    return _transformers_version
+    _, _transformers_version = _is_package_available(
+        "transformers", return_version=True
+    )
+    return _transformers_version
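
The reflowed helper wraps transformers' private _is_package_available, which with return_version=True reports whether the package is installed together with its version string. A rough standard-library sketch of the same behavior; importlib.metadata and the "N/A" sentinel here are stand-ins and assumptions, not what the plugin actually calls:

# Stand-in sketch using only the standard library; the plugin itself calls
# transformers' private _is_package_available("transformers", return_version=True).
from importlib.metadata import PackageNotFoundError, version


def get_transformers_version() -> str:
    try:
        # version string such as "4.46.1", used by the model rules for gating
        return version("transformers")
    except PackageNotFoundError:
        # assumption: return a "not installed" sentinel rather than raising
        return "N/A"

The model files above compare the returned string against "4.46" to choose between the custom-loss rule and the CrossEntropyLoss patch.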

scripts/benchmarks/benchmark.py

Lines changed: 3 additions & 6 deletions
@@ -171,9 +171,7 @@ def __init__(
     ) -> None:

         self.dataset_split = datasets.load_dataset(
-            dataset_name,
-            split=dataset_split,
-            **additional_dataset_kwargs
+            dataset_name, split=dataset_split, **additional_dataset_kwargs
         )

         self.kwargs = {
@@ -206,9 +204,8 @@ def prepare_dataset(
         )
         response_template = self.response_template

-        if (
-            self.kwargs['tokenize']
-            or (not self.kwargs['tokenize'] and self.kwargs['chat_template'])
+        if self.kwargs["tokenize"] or (
+            not self.kwargs["tokenize"] and self.kwargs["chat_template"]
         ):
             tokenizer = AutoTokenizer.from_pretrained(model_name)
             # for now, if pad_token_id is None, will just do a replacement
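
Only the layout of this gate changed, but note that tokenize or (not tokenize and chat_template) is logically equivalent to the simpler tokenize or chat_template. A quick exhaustive check, independent of the benchmark code:

# Exhaustive check that the gate above equals `tokenize or chat_template`.
from itertools import product

for tokenize, chat_template in product([False, True], repeat=2):
    gate = tokenize or (not tokenize and chat_template)
    assert gate == (tokenize or chat_template), (tokenize, chat_template)
print("gate <=> tokenize or chat_template")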
