     TERowParallelGroupedLinear, TERowParallelLinear)
 from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
 from megatron.core.parallel_state import get_expert_tensor_parallel_world_size, get_tensor_model_parallel_world_size
+from megatron.core.tensor_parallel import gather_from_sequence_parallel_region, scatter_to_sequence_parallel_region
 from megatron.core.transformer.mlp import apply_swiglu_sharded_factory
 from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.moe.router import TopKRouter
@@ -58,6 +59,7 @@ def __init__(
         self.fan_in_fan_out = fan_in_fan_out
         self._active_adapter = adapter_name
         self.is_expert = getattr(base_layer, 'is_expert', False)
+        self.sequence_parallel = getattr(base_layer, 'sequence_parallel', False)
         if self.is_expert:
             self.tp_size = get_expert_tensor_parallel_world_size()
         else:
@@ -189,6 +191,8 @@ def update_layer(self, adapter_name, r, *, lora_alpha, lora_dropout, init_lora_w
             lora.ub_overlap_ag_dgrad = False
             lora.ub_overlap_ag_fprop = False
             lora.ub_overlap_rs_dgrad = False
+        lora_a.sequence_parallel = False
+        lora_b.sequence_parallel = False
         self.lora_A[adapter_name] = lora_a
         self.lora_B[adapter_name] = lora_b
         if hasattr(self, 'lora_bias'):
@@ -287,6 +291,8 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
         else:
             raise ValueError(f'Unsupported base layer type: {type(self.base_layer)}')
         if not isinstance(self.base_layer, TopKRouter) and not self.disable_adapters and not self.merged:
+            if self.sequence_parallel and self.base_layer.parallel_mode == 'column':
+                x = gather_from_sequence_parallel_region(x)
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.lora_A.keys():
                     continue
@@ -306,7 +312,8 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
                 if isinstance(lora_result, tuple):
                     lora_result = lora_result[0]
                 lora_result = lora_result * scaling
-
+                if self.sequence_parallel and self.base_layer.parallel_mode == 'row':
+                    lora_result = scatter_to_sequence_parallel_region(lora_result)
                 result = result + lora_result

         result = result.to(previous_dtype)
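
Note on the new gather/scatter calls: under sequence parallelism a column-parallel base layer receives activations sharded along the sequence dimension, so the LoRA branch gathers the full sequence before the low-rank matmuls, while a row-parallel base layer emits a sequence-sharded output, so the LoRA result must be scattered back before the residual add. The sketch below is a hypothetical single-process illustration of that shape bookkeeping only; it does not call the Megatron collectives from the patch (gather_from_sequence_parallel_region / scatter_to_sequence_parallel_region require an initialized tensor-parallel group), and tp_size, rank, and the tensor sizes are made up.

    import torch

    # Assumed toy setup: 2-way sequence parallelism simulated on a single process.
    tp_size, rank = 2, 0
    seq, batch, hidden, r = 8, 1, 16, 4

    full_x = torch.randn(seq, batch, hidden)          # full-sequence activations
    x_shard = full_x.chunk(tp_size, dim=0)[rank]      # what one rank sees with sequence parallelism on

    lora_A = torch.nn.Linear(hidden, r, bias=False)   # low-rank down-projection
    lora_B = torch.nn.Linear(r, hidden, bias=False)   # low-rank up-projection

    # Column-parallel case: gather the sequence shards before the LoRA matmuls
    # (stands in for gather_from_sequence_parallel_region(x)).
    x_gathered = torch.cat(full_x.chunk(tp_size, dim=0), dim=0)
    lora_result = lora_B(lora_A(x_gathered))

    # Row-parallel case: keep only this rank's sequence chunk so the LoRA result
    # lines up with the sequence-sharded base-layer output
    # (stands in for scatter_to_sequence_parallel_region(lora_result)).
    lora_shard = lora_result.chunk(tp_size, dim=0)[rank]

    print(x_shard.shape, lora_shard.shape)            # both torch.Size([4, 1, 16])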