
Commit 2c202ef

some fixes
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
1 parent 0aab2c6 commit 2c202ef

4 files changed: +25 additions, -181 deletions

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py

Lines changed: 0 additions & 160 deletions
@@ -30,22 +30,6 @@
 import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.llama.modeling_llama import (
-    _CONFIG_FOR_DOC,
-    LLAMA_INPUTS_DOCSTRING,
-)
-from transformers.models.mixtral.modeling_mixtral import (
-    _CONFIG_FOR_DOC,
-    MIXTRAL_INPUTS_DOCSTRING,
-)
-from transformers.modeling_outputs import (
-    MoeCausalLMOutputWithPast,
-    MoeModelOutputWithPast,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
 
 from .cross_entropy import (
     element_mul_kernel,
@@ -297,11 +281,6 @@ def forward(self, lin_weight, _input, target, bias=None):
             self.reduction,
         )
 
-# TODO: how to add diff docstrings for diff model types? what if the loss functions aren't the same across models?
-# @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -435,143 +414,4 @@ def lce_forward(
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-    )
-
-# TODO: is adding a separate copy of lce_forward() the right path or should the additional logic for Moe models be in the single lce_forward?
-@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
-# Ignore copy
-def lce_forward_mixtral(
-    self,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    labels: Optional[torch.LongTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    output_router_logits: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-    cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
-) -> Union[Tuple, MoeCausalLMOutputWithPast]:
-    r"""
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
-            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
-            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-
-    Returns:
-
-    Example:
-
-    ```python
-    >>> from transformers import AutoTokenizer, MixtralForCausalLM
-
-    >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
-    >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
-
-    >>> prompt = "Hey, are you conscious? Can you talk to me?"
-    >>> inputs = tokenizer(prompt, return_tensors="pt")
-
-    >>> # Generate
-    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
-    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
-    ```"""
-
-    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-    output_router_logits = (
-        output_router_logits if output_router_logits is not None else self.config.output_router_logits
-    )
-
-    output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-    )
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-    outputs = self.model(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        output_router_logits=output_router_logits,
-        return_dict=return_dict,
-        cache_position=cache_position,
-    )
-
-    hidden_states = outputs[0]
-
-    loss = None
-    logits = None
-
-    # patch change
-    if self.training and (labels is not None):
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-    else:
-        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
-
-        if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
-
-    # TODO: unique differing part to mixtral model forward
-    aux_loss = None
-    if output_router_logits:
-        aux_loss = load_balancing_loss_func(
-            outputs.router_logits if return_dict else outputs[-1],
-            self.num_experts,
-            self.num_experts_per_tok,
-            attention_mask,
-        )
-        # TODO: should this loss manipulation be indented in?? or should it be added to even the liger loss?
-        if labels is not None:
-            loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
-
-    if not return_dict:
-        output = (logits,) + outputs[1:]
-        if output_router_logits:
-            output = (aux_loss,) + output
-        return (loss,) + output if loss is not None else output
-
-    return MoeCausalLMOutputWithPast(
-        loss=loss,
-        aux_loss=aux_loss,
-        logits=logits,
-        past_key_values=outputs.past_key_values,
-        hidden_states=outputs.hidden_states,
-        attentions=outputs.attentions,
-        router_logits=outputs.router_logits,
     )
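
Note on the retained code path: the llama-style `lce_forward` kept in this file (and the `lce_forward_mixtral` removed above) both use the same shift-and-fuse pattern, where the fused kernel consumes the `lm_head` weight and the flattened hidden states directly instead of materializing vocabulary logits and calling `CrossEntropyLoss`. Below is a minimal sketch of that pattern; the `liger_kernel.transformers` import and the `fused_causal_lm_loss` helper name are assumptions for illustration and are not part of this commit, though the in-tree loss patched here exposes the same `(weight, input, target)` call signature.

# Sketch only: liger_kernel import assumed; the plugin ships its own
# LigerFusedLinearCrossEntropyLoss with the same call signature (see above).
import torch
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

def fused_causal_lm_loss(
    hidden_states: torch.Tensor,   # (batch, seq, hidden)
    lm_head_weight: torch.Tensor,  # (vocab, hidden)
    labels: torch.Tensor,          # (batch, seq)
) -> torch.Tensor:
    # shift so that tokens < n predict n, then flatten, as in lce_forward
    shift_hidden = hidden_states[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    shift_hidden = shift_hidden.view(-1, hidden_states.size(-1))
    shift_labels = shift_labels.view(-1)
    # the fused op projects to the vocabulary and computes cross entropy in
    # one pass, so the (batch * seq, vocab) logits tensor is never materialized
    lce = LigerFusedLinearCrossEntropyLoss()
    return lce(lm_head_weight, shift_hidden, shift_labels)

Typical call site, mirroring the training branch of `lce_forward`: `loss = fused_causal_lm_loss(outputs[0], self.lm_head.weight, labels)`.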

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py

Lines changed: 0 additions & 7 deletions
@@ -23,7 +23,6 @@
     combine_triggers,
 )
 from transformers.models.mixtral.modeling_mixtral import (
-    MixtralForCausalLM,
     MixtralAttention,
     MixtralRMSNorm,
 )
@@ -32,7 +31,6 @@
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
-from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward_mixtral
 from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
 
 
@@ -95,11 +93,6 @@ def get_mp_rules(base_type):
                 "transformers.models.mixtral.modeling_mixtral",
             ),
         ),
-        ModelPatcherRule(
-            rule_id="mixtral-fused-lce",
-            trigger=ModelPatcherTrigger(check=MixtralForCausalLM),
-            forward=lce_forward_mixtral,
-        ),
         ModelPatcherRule(
             rule_id="mixtral-rope",
             import_and_maybe_reload=(
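
For context on what is being dropped here: the `mixtral-fused-lce` entry was a forward-replacement rule, i.e. it swapped the `forward` of any module matched by its trigger class, whereas the retained `mixtral-rope` rule patches a module-level symbol through `import_and_maybe_reload`. The sketch below shows only the shape of such a rule; the `fms_acceleration.model_patcher` import path and the placeholder `example_lce_forward` are assumptions for illustration.

# Illustrative only: import path assumed; example_lce_forward stands in for a
# patched forward such as the lce_forward_mixtral removed by this commit.
from fms_acceleration.model_patcher import ModelPatcherRule, ModelPatcherTrigger
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM

def example_lce_forward(self, *args, **kwargs):
    # a real patch mirrors the model's forward signature and return type
    raise NotImplementedError

rule = ModelPatcherRule(
    rule_id="mixtral-fused-lce",                            # unique rule name
    trigger=ModelPatcherTrigger(check=MixtralForCausalLM),  # match by model class
    forward=example_lce_forward,                            # installed in place of forward
)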

scripts/benchmarks/benchmark.py

Lines changed: 1 addition & 1 deletion
@@ -723,7 +723,7 @@ def prepare_arguments(args, benchmark_dataset: BenchmarkDataset):
 
         if (
             not args.run_only_scenarios
-            and scenarios.slow
+            and scenario.slow
         ):
             # unfiltered runs omit all "slow" marked scenarios
             print(f"Skipping slow scenario '{_scn_name}' beacuse run_only_scenarios=None.")

scripts/benchmarks/scenarios-liger.yaml

Lines changed: 24 additions & 13 deletions
@@ -38,22 +38,18 @@
 scenarios:
   - name: full-finetuning
     framework_config:
-      -
       - foak-fast-kernels
       - foak-fast-kernels-liger
     arguments:
       learning_rate: 2e-5
       model_name_or_path:
-        - 'bigcode/gpt_bigcode-santacoder'
-        - 'mistralai/Mistral-7B-v0.1'
-        - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
-        - 'NousResearch/Llama-2-70b-hf'
+        # - 'mistralai/Mistral-7B-v0.1'
+        - 'meta-llama/Meta-Llama-3-8B'
       torch_dtype: bfloat16
       bf16: True
 
   - name: standard-peft
     framework_config:
-      -
       - foak-fast-kernels
       - foak-fast-kernels-liger
     arguments:
@@ -66,13 +62,29 @@ scenarios:
       lora_dropout: 0.1
       target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
       model_name_or_path:
-        - 'mistralai/Mistral-7B-v0.1'
-        - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
-        - 'NousResearch/Llama-2-70b-hf'
+        # - 'mistralai/Mistral-7B-v0.1'
+        - 'meta-llama/Meta-Llama-3-8B'
+
+  - name: accelerated-peft-bnb
+    framework_config:
+      - accelerated-peft-bnb-foak
+      - accelerated-peft-bnb-foak-liger
+    arguments:
+      bf16: True
+      learning_rate: 2e-4
+      torch_dtype: bfloat16
+      peft_method: lora
+      r: 16
+      lora_alpha: 16
+      lora_dropout: 0.1
+      per_device_train_batch_size:
+      target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
+      model_name_or_path:
+        # - 'mistralai/Mistral-7B-v0.1'
+        - 'meta-llama/Meta-Llama-3-8B'
 
   - name: accelerated-peft-gptq
     framework_config:
-      - accelerated-peft-autogptq
       - accelerated-peft-autogptq-foak
       - accelerated-peft-autogptq-foak-liger
     arguments:
@@ -85,6 +97,5 @@ scenarios:
       lora_dropout: 0.1
      target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
       model_name_or_path:
-        - 'TheBloke/Mistral-7B-v0.1-GPTQ'
-        - 'TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ'
-        - 'TheBloke/Llama-2-70B-GPTQ'
+        # - 'TheBloke/Mistral-7B-v0.1-GPTQ'
+        - 'TechxGenus/Meta-Llama-3-8B-GPTQ'
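
Because each scenario now carries a different plugin and model list, a quick way to confirm what the edited file will actually launch is to load it and print the per-scenario framework configs and active (uncommented) models. A small helper sketch follows, assuming only PyYAML and the layout shown above; it is not part of the repository.

# Hypothetical helper, not part of the repo: summarize scenarios-liger.yaml.
import yaml

with open("scripts/benchmarks/scenarios-liger.yaml") as f:
    spec = yaml.safe_load(f)

for scn in spec["scenarios"]:
    args = scn.get("arguments", {})
    models = args.get("model_name_or_path", [])
    print(scn["name"], scn.get("framework_config", []), models)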
