
Commit 8bb871b

fix: deepspeed with context parallel (#3220)
1 parent 87565ec

File tree

1 file changed: +1 -3 lines changed

src/axolotl/monkeypatch/transformers/trainer_context_parallel.py

Lines changed: 1 addition & 3 deletions
@@ -13,9 +13,7 @@
 LOG = get_logger(__name__)
 
 GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
-PATCHED_GUARD = (
-    'if model.config._attn_implementation not in ("sdpa", "flash_attention_2"):'
-)
+PATCHED_GUARD = 'if (attn_impl := (getattr(model.config, "_attn_implementation", None) or getattr(model.model.config, "_attn_implementation", None))) and attn_impl not in ("sdpa", "flash_attention_2"):'
 
 
 def patch_prepare_context_parallel_inputs() -> None:
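For readers following the change: the patched guard is a line of source text that axolotl substitutes into transformers' trainer. The new version resolves `_attn_implementation` from `model.config` and, if that lookup yields nothing, from the wrapped inner model's `model.model.config` (the case that matters when an engine such as DeepSpeed wraps the model), and it only trips when an implementation is found that is neither "sdpa" nor "flash_attention_2". A minimal sketch of that logic follows; the function name check_attn_implementation and the ValueError it raises are illustrative assumptions, only the guard expression itself comes from the diff.

# Minimal sketch of the logic the patched guard string encodes. The function
# name and the ValueError are hypothetical (the real guard is source text
# injected into transformers); only the expression mirrors the diff above.
def check_attn_implementation(model) -> None:
    # Prefer the outer object's config; fall back to the wrapped inner model
    # (model.model), e.g. when an engine such as DeepSpeed wraps the model.
    attn_impl = getattr(model.config, "_attn_implementation", None) or getattr(
        model.model.config, "_attn_implementation", None
    )
    # Only reject when an implementation was found and it is unsupported.
    if attn_impl and attn_impl not in ("sdpa", "flash_attention_2"):
        raise ValueError(
            f"Context parallelism expects sdpa or flash_attention_2, got {attn_impl}"
        )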
