Skip to content

Commit 76e0cd5

Browse files
authored
fix moe quant (#3478)
1 parent b188440 commit 76e0cd5

File tree

2 files changed: +4 -2 lines changed

swift/llm/export/quant.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,8 @@ def get_block_name_to_quantize(self, model: nn.Module) -> Optional[str]:
 
         module_lists = []
         for n, m in model.named_modules():
-            if isinstance(m, nn.ModuleList) and len(m) >= 10:
+            if (isinstance(m, (nn.ModuleList, nn.Sequential)) and len(m) >= 10
+                    and 'mlp' not in m[0].__class__.__name__.lower()):  # fix moe
                 module_lists.append((n, m))
         if module_lists:
             module_list = max(module_lists, key=lambda x: len(x[1]))

swift/llm/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ def find_module_list(model) -> Optional[nn.ModuleList]:
     for m in model.modules():
         if hasattr(m, 'gradient_checkpointing') or m.__class__.__name__ == 'CheckpointWrapper':
             return
-        if isinstance(m, (nn.ModuleList, nn.Sequential)) and len(m) >= 10:
+        if (isinstance(m, (nn.ModuleList, nn.Sequential)) and len(m) >= 10
+                and 'mlp' not in m[0].__class__.__name__.lower()):  # fix moe
             module_lists.append(m)
     if module_lists:
         return max(module_lists, key=lambda x: len(x))

0 commit comments

Comments
 (0)