Skip to content

Commit 09bd1b9

Browse files
authored
Verify sequence parallel for seq_cls (#7240)
1 parent b47f451 commit 09bd1b9

File tree

1 file changed

+0
-9
lines changed

1 file changed

+0
-9
lines changed

swift/trainers/mixin.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -712,15 +712,6 @@ def _get_hook_target_model(task_type_: str) -> nn.Module:
712712
return get_lm_head_model(self.model, model_meta=self.model.model_meta)
713713
return get_llm_model(self.model, model_meta=self.model.model_meta)
714714

715-
# Temporary guardrail:
716-
# RP>1 implies ring-attention/zigzag workflow which requires careful task-specific pooling/restore.
717-
# We have not fully validated seq_cls/reranker/embedding under RP>1 yet, so fail fast.
718-
if sp_enabled and task_type in {'seq_cls', 'reranker', 'embedding', 'generative_reranker'}:
719-
from swift.trainers.sequence_parallel import sequence_parallel
720-
rp_world_size = getattr(sequence_parallel, 'rp_world_size', None)
721-
if isinstance(rp_world_size, int) and rp_world_size > 1:
722-
raise NotImplementedError(f'task_type={task_type} with ring-attention is not supported yet. ')
723-
724715
# --- seq_cls / reranker / generative_reranker unified pipeline ---
725716
if task_type in {'seq_cls', 'reranker', 'generative_reranker', 'embedding'}:
726717
llm_model = _get_hook_target_model(task_type)

0 commit comments

Comments
 (0)