Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ You can use the `MOE_PEFT_EXECUTOR_TYPE` environment variable to force MoE-PEFT
| ✓ | [LLaMA 3.x](https://huggingface.co/meta-llama) | 3B/8B/70B |
| ✓ | [Yi 1/1.5](https://huggingface.co/01-ai) | 6B/9B/34B |
| ✓ | [TinyLLaMA](https://huggingface.co/TinyLlama) | 1.1B |
| ✓ | [Qwen 1.5/2.x](https://huggingface.co/Qwen) | 0.5B ~ 72B |
| ✓ | [Qwen 1.5/2/3](https://huggingface.co/Qwen) | 0.5B ~ 72B |
| ✓ | [Gemma](https://huggingface.co/google) | 2B/7B |
| ✓ | [Gemma 2](https://huggingface.co/google) | 9B/27B |
| ✓ | [Mistral](https://huggingface.co/mistralai) | 7B |
Expand Down Expand Up @@ -152,7 +152,7 @@ python launch.py run --base_model TinyLlama/TinyLlama_v1.1
python inference.py \
--base_model TinyLlama/TinyLlama_v1.1 \
--template alpaca \
--lora_weights ./casual_0
--lora_weights ./causal_0
```

For further detailed usage information, please refer to the `help` command:
Expand Down
6 changes: 3 additions & 3 deletions launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,10 @@ def gen_config(
elif task_name not in moe_peft.tasks.task_dict:
try:
load_dataset(task_name)
except:
except Exception:
raise RuntimeError(f"Task name '{task_name}' not exist.")
lora_config["name"] = f"casual_{index}"
lora_config["task_name"] = "casual"
lora_config["name"] = f"causal_{index}"
lora_config["task_name"] = "causal"
lora_config["data"] = task_name
lora_config["prompt"] = "alpaca"
else:
Expand Down
24 changes: 24 additions & 0 deletions moe_peft/adapters/mixlora/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class MixLoraConfig(LoraConfig):
num_experts_: int = None
act_fn_: Optional[Union[str, torch.nn.Module]] = None
# mixtral config
router_dyn_loss_coef_: float = None
entropy_index_: float = None
entropy_type_: str = None
entropy_eps_: float = None
top_k_: int = None
# dynamic config
top_p_: float = None
Expand Down Expand Up @@ -55,6 +59,20 @@ def check(self) -> "MixLoraConfig":
isinstance(self.act_fn_, str) and self.act_fn_ in ACT2FN
)
if self.routing_strategy_ == "mixlora":
assert (
isinstance(self.router_dyn_loss_coef_, float)
and self.router_dyn_loss_coef_ >= 0
)
assert (
isinstance(self.entropy_index_, float)
and self.entropy_index_ > 0
and self.entropy_index_ <= 2.0
)
assert isinstance(self.entropy_type_, str) and self.entropy_type_ in [
"tsallis",
"renyi",
]
assert isinstance(self.entropy_eps_, float) and self.entropy_eps_ > 0
assert isinstance(self.top_k_, int) and self.top_k_ > 0
elif self.routing_strategy_ == "mixlora-dynamic":
assert (
Expand Down Expand Up @@ -90,6 +108,12 @@ def from_config(config: Dict[str, any]) -> "MixLoraConfig":
# left blank to automatically use the original act_fn of FFN
lora_config.act_fn_ = config.get("act_fn", None)
if lora_config.routing_strategy_ == "mixlora":
lora_config.router_dyn_loss_coef_ = config.get(
"router_dyn_loss_coef", 0.01
) # for training
lora_config.entropy_index_ = config.get("entropy_index", 1.4)
lora_config.entropy_type_ = config.get("entropy_type", "tsallis")
lora_config.entropy_eps_ = config.get("entropy_eps", 1e-5)
lora_config.router_init_range_ = config.get("router_init_range", 0.02)
lora_config.jitter_noise_ = config.get("jitter_noise", 0.0)
lora_config.top_k_ = config.get("top_k", 2)
Expand Down
77 changes: 61 additions & 16 deletions moe_peft/adapters/mixlora/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@
import torch.nn.functional as F
from transformers.activations import ACT2FN

from moe_peft.common import LLMFeedForward, LLMModelInput, LLMMoeBlock, slice_tensor
from moe_peft.common import (
LLMFeedForward,
LLMModelInput,
LLMMoeBlock,
renyi_entropy,
slice_tensor,
tsallis_entropy,
)

from .config import MixLoraConfig

Expand All @@ -31,15 +38,35 @@ def _mixlora_compatible_forward(

def _mixtral_load_balancing_loss_func(
gate_logits: torch.Tensor,
num_experts: int,
top_k: int,
adapter_config: MixLoraConfig,
attention_mask: Optional[torch.Tensor] = None,
) -> float:
routing_weights = torch.nn.functional.softmax(gate_logits, dim=-1)

_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
# Entropy Loss
if adapter_config.entropy_type_ == "tsallis":
router_entropy = tsallis_entropy(
p=routing_weights,
q=adapter_config.entropy_index_,
eps=adapter_config.entropy_eps_,
)
elif adapter_config.entropy_type_ == "renyi":
router_entropy = renyi_entropy(
p=routing_weights,
a=adapter_config.entropy_index_,
eps=adapter_config.entropy_eps_,
)
else:
raise NotImplementedError()

entropy_loss = router_entropy.mean()

# Load Balance Loss
_, selected_experts = torch.topk(routing_weights, adapter_config.top_k_, dim=-1)

expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
expert_mask = torch.nn.functional.one_hot(
selected_experts, adapter_config.num_experts_
)

if attention_mask is None:
# Compute the percentage of tokens routed to each experts
Expand All @@ -55,9 +82,15 @@ def _mixtral_load_balancing_loss_func(
expert_attention_mask = (
attention_mask[None, :, :, None, None]
.expand(
(num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
(
num_hidden_layers,
batch_size,
sequence_length,
adapter_config.top_k_,
adapter_config.num_experts_,
)
)
.reshape(-1, top_k, num_experts)
.reshape(-1, adapter_config.top_k_, adapter_config.num_experts_)
.to(routing_weights.device)
)

Expand All @@ -69,8 +102,15 @@ def _mixtral_load_balancing_loss_func(
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
router_per_expert_attention_mask = (
attention_mask[None, :, :, None]
.expand((num_hidden_layers, batch_size, sequence_length, num_experts))
.reshape(-1, num_experts)
.expand(
(
num_hidden_layers,
batch_size,
sequence_length,
adapter_config.num_experts_,
)
)
.reshape(-1, adapter_config.num_experts_)
.to(routing_weights.device)
)

Expand All @@ -79,20 +119,25 @@ def _mixtral_load_balancing_loss_func(
routing_weights * router_per_expert_attention_mask, dim=0
) / torch.sum(router_per_expert_attention_mask, dim=0)

overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
return overall_loss * num_experts
load_balance_loss = (
torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
* adapter_config.num_experts_
)

return (
adapter_config.router_dyn_loss_coef_ * entropy_loss
+ adapter_config.router_aux_loss_coef_ * load_balance_loss
)


class MixtralRouterLoss(torch.nn.Module):
def __init__(self, config: MixLoraConfig) -> None:
super().__init__()
self.aux_loss_coef = config.router_aux_loss_coef_
self.experts = config.num_experts_
self.topk = config.top_k_
self.adapter_config = config

def forward(self, gate_logits, attention_mask) -> torch.Tensor:
return self.aux_loss_coef * _mixtral_load_balancing_loss_func(
gate_logits, self.experts, self.topk, attention_mask
return _mixtral_load_balancing_loss_func(
gate_logits, self.adapter_config, attention_mask
)


Expand Down
14 changes: 13 additions & 1 deletion moe_peft/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
LLMOutput,
)
from .attention import (
ATTENTION_FUNCTIONS,
eager_attention_forward,
flash_attention_forward,
prepare_4d_causal_attention_mask,
Expand Down Expand Up @@ -47,13 +48,21 @@
from .lora_linear import Linear, Lora, get_range_tensor

# MoEs
from .moe_utils import collect_plugin_router_logtis, slice_tensor, unpack_router_logits
from .moe_utils import (
collect_plugin_router_logtis,
slice_tensor,
unpack_router_logits,
tsallis_entropy,
shannon_entropy,
renyi_entropy,
)
from .rope import ROPE_INIT_FUNCTIONS

__all__ = [
"prepare_4d_causal_attention_mask",
"eager_attention_forward",
"flash_attention_forward",
"ATTENTION_FUNCTIONS",
"LLMCache",
"DynamicCache",
"HybridCache",
Expand All @@ -65,6 +74,9 @@
"CheckpointRecomputeFunction",
"CHECKPOINT_CLASSES",
"FeedForward",
"tsallis_entropy",
"shannon_entropy",
"renyi_entropy",
"slice_tensor",
"unpack_router_logits",
"collect_plugin_router_logtis",
Expand Down
61 changes: 53 additions & 8 deletions moe_peft/common/abstracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,29 +26,74 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
)

def get_max_length(self) -> Optional[int]:
return self.get_max_cache_shape()

def get_max_cache_shape(self) -> Optional[int]:
raise NotImplementedError(
"Make sure to implement `get_max_length` in a subclass."
"Make sure to implement `get_max_cache_shape` in a subclass."
)

def get_usable_length(
self, new_seq_length: int, layer_idx: Optional[int] = 0
) -> int:
max_length = self.get_max_length()
max_length = self.get_max_cache_shape()
previous_seq_length = self.get_seq_length(layer_idx)
if max_length is not None and previous_seq_length + new_seq_length > max_length:
return max_length - new_seq_length
return previous_seq_length

def reorder_cache(self, beam_idx: torch.LongTensor):
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
0, beam_idx.to(device)
# Skip empty lists (used by DynamicCache for skipped layers)
key_cache_item = self.key_cache[layer_idx]
is_empty_list = (
isinstance(key_cache_item, list) and len(key_cache_item) == 0
)
if not is_empty_list:
device = key_cache_item.device
self.key_cache[layer_idx] = key_cache_item.index_select(
0, beam_idx.to(device)
)

value_cache_item = self.value_cache[layer_idx]
is_empty_list = (
isinstance(value_cache_item, list) and len(value_cache_item) == 0
)
if not is_empty_list:
device = value_cache_item.device
self.value_cache[layer_idx] = value_cache_item.index_select(
0, beam_idx.to(device)
)

def select_batch(self, batch_idx: torch.LongTensor):
"""
Select a subset of the batch dimension across all cache layers.
This mirrors `reorder_cache` behavior but is intended for batch filtering
(e.g., top-k selection in generation). Layers with empty lists are skipped.

Note: For static or sliding-window caches, this replaces the cached tensor
reference for the layer with an indexed view, consistent with `reorder_cache`.
"""
for layer_idx in range(len(self.key_cache)):
key_cache_item = self.key_cache[layer_idx]
is_empty_list = (
isinstance(key_cache_item, list) and len(key_cache_item) == 0
)
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
0, beam_idx.to(device)
if not is_empty_list:
device = key_cache_item.device
self.key_cache[layer_idx] = key_cache_item.index_select(
0, batch_idx.to(device)
)

value_cache_item = self.value_cache[layer_idx]
is_empty_list = (
isinstance(value_cache_item, list) and len(value_cache_item) == 0
)
if not is_empty_list:
device = value_cache_item.device
self.value_cache[layer_idx] = value_cache_item.index_select(
0, batch_idx.to(device)
)


class LLMAttention(metaclass=ABCMeta):
Expand Down
Loading