diff --git a/lmdeploy/pytorch/envs.py b/lmdeploy/pytorch/envs.py index 37ef2859d7..aa510abe1c 100644 --- a/lmdeploy/pytorch/envs.py +++ b/lmdeploy/pytorch/envs.py @@ -98,6 +98,12 @@ def _patched_get_env( # logging log_file = os.getenv('LMDEPLOY_LOG_FILE', None) + # dump expert distribution + dump_expert_distribution = env_to_bool('LMDEPLOY_DUMP_EXPERT_DISTRIBUTION', False) + expert_dump_dir = os.getenv('LMDEPLOY_EXPERT_DUMP_DIR', '/tmp/lmdeploy/expert_distribution') + expert_dump_frequency = env_to_int('LMDEPLOY_EXPERT_DUMP_FREQUENCY', 5) + expert_dump_rank = env_to_int('LMDEPLOY_EXPERT_DUMP_RANK', 0) + def get_all_envs(): """Get all environment variables.""" diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py index 7389946131..7cab781d16 100644 --- a/lmdeploy/pytorch/models/deepseek_v2.py +++ b/lmdeploy/pytorch/models/deepseek_v2.py @@ -22,6 +22,7 @@ from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .utils.cudagraph import CudaGraphMixin +from .utils.expert_distribution_recorder import ExpertsDistributionRecorder # microbatch @@ -663,10 +664,12 @@ def forward(self, hidden_states: torch.Tensor): class DeepseekV2MoE(nn.Module): """Deepseek v2 MoE.""" + recorder = ExpertsDistributionRecorder() def __init__(self, config: Any, layer_idx, dtype: torch.dtype = None, device: torch.device = None): super().__init__() quantization_config = getattr(config, 'quantization_config', None) + self.layer_idx = layer_idx self.hidden_dim = config.hidden_size self.ffn_dim = config.moe_intermediate_size self.num_experts = config.n_routed_experts @@ -738,6 +741,7 @@ def forward(self, hidden_states: torch.Tensor): if self._all_reduce: dist.all_reduce(out_states) + DeepseekV2MoE.recorder.record(topk_ids, self.layer_idx, self.num_experts) return out_states diff --git a/lmdeploy/pytorch/models/qwen3_moe.py b/lmdeploy/pytorch/models/qwen3_moe.py index c0e281838e..6842975293 100644 --- a/lmdeploy/pytorch/models/qwen3_moe.py 
+++ b/lmdeploy/pytorch/models/qwen3_moe.py @@ -15,6 +15,7 @@ from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .utils.cudagraph import CudaGraphMixin +from .utils.expert_distribution_recorder import ExpertsDistributionRecorder class Qwen3MoeAttention(nn.Module): @@ -170,6 +171,7 @@ def forward(self, x): class Qwen3MoeSparseMoeBlock(nn.Module): """Moe block.""" + recorder = ExpertsDistributionRecorder() def __init__(self, config: PretrainedConfig, @@ -235,6 +237,8 @@ def forward(self, hidden_states: torch.Tensor): ) out_states = out_states.reshape(batch_size, sequence_length, -1) + + Qwen3MoeSparseMoeBlock.recorder.record(topk_ids, self.layer_idx, self.num_experts) return out_states diff --git a/lmdeploy/pytorch/models/utils/expert_distribution_recorder.py b/lmdeploy/pytorch/models/utils/expert_distribution_recorder.py new file mode 100644 index 0000000000..56823c2d09 --- /dev/null +++ b/lmdeploy/pytorch/models/utils/expert_distribution_recorder.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
# adapted from https://github.com/DeepLink-org/dlBLAS/blob/main/dlblas/layers/moe/experts_distribution_recorder.py

import json
import os
from datetime import datetime

import torch
import torch.distributed as dist

from lmdeploy.pytorch import envs
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')


class _NoOpExpertsDistributionRecorder:
    """A no-op version of the recorder that does nothing."""

    def __init__(self, *args, **kwargs):
        pass

    def record(self, *args, **kwargs):
        """Ignore all recording requests."""
        pass


class _ExpertsDistributionRecorderImpl:
    """Accumulates per-expert routed-token counts and periodically dumps them.

    Counts are keyed by ``'<layer_idx>_<num_experts>'``, accumulated locally on
    every :meth:`record` call, all-reduced across ranks, and written to a JSON
    file by a single rank at most once per wall-clock minute.
    """

    def __init__(self):
        # Directory the JSON dumps are written to (LMDEPLOY_EXPERT_DUMP_DIR).
        self.output_dir = envs.expert_dump_dir
        # Per-key number of record() calls ("steps") seen so far.
        self.dispatch_count = {}
        # Per-key locally (this rank) accumulated token counts per expert.
        self.accum_token_counts = {}
        # Per-key all-reduced counts from the most recent record() of that key.
        self.global_token_counts = {}
        # Wall-clock minute of the last dump; guards against dumping twice
        # within the same minute.
        self.last_dump_minute = -1
        # Dump on minutes divisible by this value (LMDEPLOY_EXPERT_DUMP_FREQUENCY).
        self.dump_frequency = envs.expert_dump_frequency
        # Only this rank writes dump files (LMDEPLOY_EXPERT_DUMP_RANK).
        self.dump_rank = envs.expert_dump_rank

    def map_to_sorted_2d_array(self, data_map):
        """Return the count tensors as a 2D list ordered by layer index.

        Keys have the form ``'<layer_idx>_<num_experts>'``; rows are ordered by
        the numeric layer index so row i corresponds to the i-th recorded layer.
        """
        sorted_keys = sorted(data_map.keys(), key=lambda k: int(k.split('_')[0]))
        return [data_map[key].cpu().tolist() for key in sorted_keys]

    def record(self, topk_ids, layer_index, num_experts):
        """Accumulate routing counts for one MoE layer; dump when due.

        Args:
            topk_ids (torch.Tensor): integer expert ids chosen by the router;
                any shape (it is flattened internally).
            layer_index (int): index of the MoE layer being recorded.
            num_experts (int): total number of routed experts in the layer.
        """
        key = f'{layer_index}_{num_experts}'
        self.dispatch_count[key] = self.dispatch_count.get(key, 0) + 1
        if key not in self.accum_token_counts:
            # Allocate the accumulator on the same device as the router output
            # (the original hard-coded 'cuda', which breaks when the layer
            # runs on a non-default device).
            self.accum_token_counts[key] = torch.zeros(num_experts, dtype=torch.int64, device=topk_ids.device)
        step_local_counts = torch.bincount(topk_ids.view(-1), minlength=num_experts)
        self.accum_token_counts[key] += step_local_counts
        # NOTE: this all-reduce runs on every forward of every recorded layer,
        # and must run on *all* ranks (collectives require every rank to
        # participate even though only `dump_rank` writes the file). Enabling
        # the recorder therefore adds communication overhead per step.
        global_token_counts_tmp = self.accum_token_counts[key].clone()
        if dist.is_initialized():
            dist.all_reduce(global_token_counts_tmp, op=dist.ReduceOp.SUM)
        self.global_token_counts[key] = global_token_counts_tmp
        rank = dist.get_rank() if dist.is_initialized() else 0
        now = datetime.now()
        # Dump at most once per wall-clock minute, only on minutes divisible
        # by `dump_frequency`, and only from the configured rank.
        if (rank == self.dump_rank and now.minute % self.dump_frequency == 0
                and now.minute != self.last_dump_minute):
            self.last_dump_minute = now.minute
            global_list = self.map_to_sorted_2d_array(self.global_token_counts)
            # NOTE(review): `step` is the dispatch count of the *current* key
            # only, while the dump contains every recorded layer (other layers
            # may be one step behind) — confirm this naming is intended.
            step = self.dispatch_count[key]
            os.makedirs(self.output_dir, exist_ok=True)
            token_counts_file_name = f'rank{rank}_step{step}_experts_counts.json'
            filepath = os.path.join(self.output_dir, token_counts_file_name)
            with open(filepath, 'w') as f:
                json.dump(global_list, f, indent=2)
            logger.info(f'[EPLB] Experts distribution dumped to {filepath}')


class ExpertsDistributionRecorder:
    """Factory class that returns a real or no-op recorder."""

    def __new__(cls, *args, **kwargs):
        # Returning an instance of a *different* class means this class's
        # __init__ is never invoked; each impl initializes itself normally.
        if envs.dump_expert_distribution:
            logger.info('Expert distribution recorder is enabled.')
            return _ExpertsDistributionRecorderImpl(*args, **kwargs)
        logger.info('Expert distribution recorder is disabled.')
        return _NoOpExpertsDistributionRecorder(*args, **kwargs)