
Commit 5e1de0f

[megatron] Fix SP & LoRA (#5704)
Parent: 3562266

File tree

2 files changed: +11, -4

  swift/llm/model/utils.py
  swift/megatron/tuners/lora.py


swift/llm/model/utils.py

Lines changed: 3 additions & 3 deletions
@@ -324,10 +324,10 @@ def git_clone_github(github_url: str,
         local_repo_name = github_url.rsplit('/', 1)[1]
     github_url = f'{github_url}.git'
     local_repo_path = os.path.join(git_cache_dir, local_repo_name)
-    with safe_ddp_context(None, use_barrier=True):
-        if not is_local_master():
-            return local_repo_path
+    with safe_ddp_context('git_clone', use_barrier=True):
         repo_existed = os.path.exists(local_repo_path)
+        if not is_local_master() and repo_existed:
+            return local_repo_path
         if repo_existed:
             command = ['git', '-C', local_repo_path, 'fetch']
             subprocess_run(command)
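
For context, the change tightens the early-return condition: a non-master rank now skips the git work only when the repository is already present on disk, and the safe_ddp_context is keyed by 'git_clone' instead of None. Below is a minimal, self-contained sketch of that control flow; is_local_master and safe_ddp_context are simplified stand-ins for swift's helpers (the real context manager coordinates distributed ranks, which the stub does not).

import os
import subprocess
from contextlib import contextmanager

# Simplified stand-ins for swift's helpers, assumed here for illustration only.
def is_local_master() -> bool:
    return int(os.environ.get('LOCAL_RANK', '0')) == 0

@contextmanager
def safe_ddp_context(hash_id, use_barrier=True):
    # The real helper synchronizes ranks; this stub only preserves the call shape.
    yield

def git_clone_cached(github_url: str, git_cache_dir: str) -> str:
    local_repo_name = github_url.rstrip('/').rsplit('/', 1)[1]
    local_repo_path = os.path.join(git_cache_dir, local_repo_name)
    with safe_ddp_context('git_clone', use_barrier=True):
        repo_existed = os.path.exists(local_repo_path)
        # After the fix, a non-master rank returns early only if the repo already exists.
        if not is_local_master() and repo_existed:
            return local_repo_path
        if repo_existed:
            subprocess.run(['git', '-C', local_repo_path, 'fetch'], check=True)
        else:
            subprocess.run(['git', 'clone', f'{github_url}.git', local_repo_path], check=True)
    return local_repo_path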

swift/megatron/tuners/lora.py

Lines changed: 8 additions & 1 deletion
@@ -15,6 +15,7 @@
                                          TERowParallelGroupedLinear, TERowParallelLinear)
 from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
 from megatron.core.parallel_state import get_expert_tensor_parallel_world_size, get_tensor_model_parallel_world_size
+from megatron.core.tensor_parallel import gather_from_sequence_parallel_region, scatter_to_sequence_parallel_region
 from megatron.core.transformer.mlp import apply_swiglu_sharded_factory
 from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.moe.router import TopKRouter
@@ -58,6 +59,7 @@ def __init__(
         self.fan_in_fan_out = fan_in_fan_out
         self._active_adapter = adapter_name
         self.is_expert = getattr(base_layer, 'is_expert', False)
+        self.sequence_parallel = getattr(base_layer, 'sequence_parallel', False)
         if self.is_expert:
             self.tp_size = get_expert_tensor_parallel_world_size()
         else:
@@ -189,6 +191,8 @@ def update_layer(self, adapter_name, r, *, lora_alpha, lora_dropout, init_lora_w
             lora.ub_overlap_ag_dgrad = False
             lora.ub_overlap_ag_fprop = False
             lora.ub_overlap_rs_dgrad = False
+        lora_a.sequence_parallel = False
+        lora_b.sequence_parallel = False
         self.lora_A[adapter_name] = lora_a
         self.lora_B[adapter_name] = lora_b
         if hasattr(self, 'lora_bias'):
@@ -287,6 +291,8 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
         else:
             raise ValueError(f'Unsupported base layer type: {type(self.base_layer)}')
         if not isinstance(self.base_layer, TopKRouter) and not self.disable_adapters and not self.merged:
+            if self.sequence_parallel and self.base_layer.parallel_mode == 'column':
+                x = gather_from_sequence_parallel_region(x)
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.lora_A.keys():
                     continue
@@ -306,7 +312,8 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
                 if isinstance(lora_result, tuple):
                     lora_result = lora_result[0]
                 lora_result = lora_result * scaling
-
+                if self.sequence_parallel and self.base_layer.parallel_mode == 'row':
+                    lora_result = scatter_to_sequence_parallel_region(lora_result)
                 result = result + lora_result

         result = result.to(previous_dtype)
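
To see why the gather and scatter land where they do: with Megatron sequence parallelism, a column-parallel layer's input arrives sharded along the sequence dimension (the base layer gathers it internally), while a row-parallel layer's output is reduce-scattered back to sequence shards. The LoRA branch runs outside the base layer, so it must gather x itself in the column case and scatter its result in the row case before adding it to the base output; presumably this is also why the commit sets sequence_parallel = False on lora_a and lora_b, so the communication happens exactly once in the wrapper. The single-process sketch below only illustrates the shape bookkeeping; gather_sp, scatter_sp, the tensor sizes, and tp_size are illustrative stand-ins, not the Megatron communication ops.

import torch

tp_size, rank = 2, 0
seq_len, batch, hidden, r = 8, 1, 16, 4

def gather_sp(x):
    # Stand-in for gather_from_sequence_parallel_region: [seq/tp, b, h] -> [seq, b, h].
    return x.repeat(tp_size, 1, 1)

def scatter_sp(x):
    # Stand-in for scatter_to_sequence_parallel_region: [seq, b, h] -> [seq/tp, b, h].
    return x.chunk(tp_size, dim=0)[rank]

lora_A = torch.nn.Linear(hidden, r, bias=False)
lora_B = torch.nn.Linear(r, hidden, bias=False)

# Column-parallel base layer with SP: x is a sequence shard, but the base output
# covers the full sequence, so the LoRA branch gathers x before lora_A.
x_shard = torch.randn(seq_len // tp_size, batch, hidden)
base_out = torch.randn(seq_len, batch, hidden)               # full-sequence base output
lora_out = lora_B(lora_A(gather_sp(x_shard)))                # [seq, b, hidden]
column_result = base_out + lora_out

# Row-parallel base layer with SP: the base output is a sequence shard, so the
# LoRA result (computed on the full sequence) is scattered before the add.
base_out_shard = torch.randn(seq_len // tp_size, batch, hidden)
lora_full = lora_B(lora_A(torch.randn(seq_len, batch, hidden)))
row_result = base_out_shard + scatter_sp(lora_full)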
