
Commit b6f4e20

[BUG] Fix UlyssesSPAttentionHF.register_with_transformers() crash with PEFT models (#7737)
**Description**

This PR fixes a crash in `UlyssesSPAttentionHF.register_with_transformers()` when a PEFT-wrapped model (e.g., `PeftModel`) is passed as the `model_name_or_path` argument.

**The Issue**

The function previously used an overly strict `isinstance(model_name_or_path, PreTrainedModel)` check. Since PEFT models do not subclass `PreTrainedModel` (though they forward to one), the check would fail. The logic then fell through to the `else` block, treating the model object as a string path and calling `AutoConfig.from_pretrained(model_name_or_path)`, which immediately raised a `TypeError` or `OSError`.

**Changes**

* Updated the logic to use duck typing: if the input object has a `.config` attribute, we treat it as a model and access the configuration directly.
* Hugging Face string paths (Hub IDs or local paths) continue to be handled by the fallback to `AutoConfig`.

**Validation**

Verified that:

1. PEFT-wrapped models now register successfully without crashing.
2. Standard `PreTrainedModel` objects still register correctly.
3. String paths still trigger `AutoConfig.from_pretrained` as expected.

Fixes #7729

Signed-off-by: Rakshit-gen <sisodiarakshit456@gmail.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
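The old-vs-new check described above boils down to class-hierarchy strictness versus duck typing. A minimal standalone sketch (all class and function names here are hypothetical illustrations, not actual DeepSpeed, transformers, or PEFT code):

```python
class Base:
    """Stands in for transformers.PreTrainedModel."""

    def __init__(self):
        self.config = {"hidden_size": 64}


class Wrapper:
    """Stands in for a PEFT-style wrapper: it does not inherit from
    Base, but it exposes the inner model's .config attribute."""

    def __init__(self, inner):
        self.config = inner.config


def resolve_config(model_name_or_path):
    """Hypothetical sketch of the fixed dispatch: duck-type on .config,
    otherwise treat the argument as a Hub ID / local path string."""
    if hasattr(model_name_or_path, "config"):
        # model-like object (including wrappers): use its config directly
        return model_name_or_path.config
    # string path branch: the real code calls AutoConfig.from_pretrained here
    return f"AutoConfig.from_pretrained({model_name_or_path!r})"


real = Base()
wrapped = Wrapper(real)

# The old isinstance check only accepts the true subclass:
print(isinstance(real, Base), isinstance(wrapped, Base))  # True False

# The duck-typed check accepts both, so the wrapper no longer falls
# through to the string-path branch:
print(resolve_config(wrapped))  # {'hidden_size': 64}
print(resolve_config("org/tiny-model"))
```

With the old check, `resolve_config(wrapped)` would have taken the string-path branch and passed a model object to a loader expecting a path, which is exactly the crash this PR fixes.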
1 parent 377a0d1 commit b6f4e20

File tree

2 files changed: +49 −2 lines


deepspeed/runtime/sequence_parallel/ulysses_sp.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -389,8 +389,8 @@ def register_with_transformers(
         mpu.initialize_sequence_parallel(sequence_parallel_size=sequence_parallel_size)

         from transformers import PreTrainedModel
-        if isinstance(model_name_or_path, PreTrainedModel):
-            # we already have the model
+        if hasattr(model_name_or_path, "config") or isinstance(model_name_or_path, PreTrainedModel):
+            # we already have the model (or a PEFT wrapper with config attribute)
             hf_model_config = model_name_or_path.config
         else:
             # if we don't have the model yet at this stage
```

tests/unit/ulysses_alst/test_ulysses_sp_hf.py

Lines changed: 47 additions & 0 deletions

```diff
@@ -185,3 +185,50 @@ def collate_fn(batch):
             torch_assert_close(grad_a, grad_b, rtol=1.6e-02, atol=1e-03)
         else:
             torch_assert_close(grad_a, grad_b)
+
+
+class TestUlyssesSPHFPEFT(DistributedTest):
+    world_size = 2
+
+    def test_ulysses_sp_hf_with_peft_model(self):
+        """Test that UlyssesSPAttentionHF.register_with_transformers works with PEFT models.
+
+        PEFT models don't inherit from transformers.PreTrainedModel but have a config attribute.
+        This test verifies the duck-typing check for the config attribute works correctly.
+        """
+        model_name_or_path = 'hf-internal-testing/tiny-random-LlamaForCausalLM'
+        seq_length = 64
+        sequence_parallel_size = self.world_size
+        micro_batch_size = 1
+
+        # Create a mock PEFT model object that has config but doesn't inherit from PreTrainedModel
+        from transformers import AutoConfig
+        hf_config = AutoConfig.from_pretrained(model_name_or_path)
+
+        class MockPEFTModel:
+            """Mock PEFT model that simulates PeftModel behavior"""
+
+            def __init__(self, config):
+                self.config = config
+
+        mock_peft_model = MockPEFTModel(hf_config)
+
+        # Test that register_with_transformers works with PEFT-like model object
+        # This should not crash and should use the config attribute via duck-typing
+        mpu = UlyssesSPAttentionHF.register_with_transformers(
+            model_name_or_path=mock_peft_model,
+            core_attn_implementation="sdpa",
+            sequence_parallel_size=sequence_parallel_size,
+            micro_batch_size=micro_batch_size,
+            seq_length=seq_length,
+            seq_length_is_variable=True,
+        )
+
+        # Verify mpu is created successfully
+        assert mpu is not None
+
+        # Verify that the sequence parallel groups are initialized
+        sp_group = groups._get_sequence_parallel_group()
+        assert sp_group is not None
+        sp_world_size = groups._get_sequence_parallel_world_size()
+        assert sp_world_size == sequence_parallel_size
```
