Skip to content

Commit 9db10fe

Browse files
committed
[compat] Expand test suite to full transformers v5 (#3615)
* Expand test suite to full transformers v5 * Specify _endpoint=None in tests as required by newer huggingface_hub * Ensure test_push_to_hub works for older and newer hf_hub
1 parent a1ed1ef commit 9db10fe

File tree

3 files changed

+26
-30
lines changed

3 files changed

+26
-30
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ jobs:
7575
# The --break-system-packages flag is used to allow pip to upgrade or install packages
7676
# even in environments where system-managed packages might otherwise block such operations.
7777
python -m pip install --upgrade pip --break-system-packages
78-
python -m pip install '.[train, dev]' 'transformers>=5.0.0rc0'
78+
python -m pip install '.[train, dev]' 'transformers>=5.0.0'
7979
8080
- name: Install model2vec
8181
run: python -m pip install model2vec

tests/cross_encoder/test_cross_encoder.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -275,17 +275,19 @@ def mock_create_repo(self, repo_id, **kwargs):
275275
def mock_upload_folder(self, **kwargs):
276276
nonlocal mock_upload_folder_kwargs
277277
mock_upload_folder_kwargs = kwargs
278-
if kwargs.get("revision") is None:
279-
revision = "123456"
280-
else:
281-
revision = "678901"
282-
return CommitInfo(
283-
commit_url=f"https://huggingface.co/{kwargs.get('repo_id')}/commit/{revision}",
284-
commit_message="commit_message",
285-
commit_description="commit_description",
286-
oid="oid",
287-
pr_url=f"https://huggingface.co/{kwargs.get('repo_id')}/discussions/123",
288-
)
278+
commit_hash = "123456" if kwargs.get("revision") is None else "678901"
279+
commit_info_kwargs = {
280+
"commit_url": f"https://huggingface.co/{kwargs.get('repo_id')}/commit/{commit_hash}",
281+
"commit_message": "commit_message",
282+
"commit_description": "commit_description",
283+
"oid": "oid",
284+
"pr_url": f"https://huggingface.co/{kwargs.get('repo_id')}/discussions/123",
285+
}
286+
try:
287+
return CommitInfo(**commit_info_kwargs)
288+
except TypeError:
289+
# Required as of https://github.com/huggingface/huggingface_hub/pull/3679
290+
return CommitInfo(**commit_info_kwargs, _endpoint=None)
289291

290292
def mock_create_branch(self, repo_id, branch, revision=None, **kwargs):
291293
return None

tests/test_sentence_transformer.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -120,20 +120,18 @@ def mock_create_repo(self, repo_id, **kwargs):
120120
def mock_upload_folder(self, **kwargs):
121121
nonlocal mock_upload_folder_kwargs
122122
mock_upload_folder_kwargs = kwargs
123-
if kwargs.get("revision") is None:
124-
return CommitInfo(
125-
commit_url=f"https://huggingface.co/{kwargs.get('repo_id')}/commit/123456",
126-
commit_message="commit_message",
127-
commit_description="commit_description",
128-
oid="oid",
129-
)
130-
else:
131-
return CommitInfo(
132-
commit_url=f"https://huggingface.co/{kwargs.get('repo_id')}/commit/678901",
133-
commit_message="commit_message",
134-
commit_description="commit_description",
135-
oid="oid",
136-
)
123+
commit_hash = "123456" if kwargs.get("revision") is None else "678901"
124+
commit_info_kwargs = {
125+
"commit_url": f"https://huggingface.co/{kwargs.get('repo_id')}/commit/{commit_hash}",
126+
"commit_message": "commit_message",
127+
"commit_description": "commit_description",
128+
"oid": "oid",
129+
}
130+
try:
131+
return CommitInfo(**commit_info_kwargs)
132+
except TypeError:
133+
# Required as of https://github.com/huggingface/huggingface_hub/pull/3679
134+
return CommitInfo(**commit_info_kwargs, _endpoint=None)
137135

138136
def mock_create_branch(self, repo_id, branch, revision=None, **kwargs):
139137
return None
@@ -473,10 +471,6 @@ def test_prompt_length_calculation(
473471
assert model._prompt_length_mapping == {("Prompt: ", "query"): only_prompt_length}
474472

475473

476-
@pytest.mark.skipif(
477-
parse(transformers_version) == Version("5.0.0rc01"),
478-
reason="Transformers 5.0.0rc01 has a bug with saving models modified with model.half().",
479-
)
480474
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test float16 support.")
481475
def test_load_with_torch_dtype(stsb_bert_tiny_model: SentenceTransformer) -> None:
482476
model = stsb_bert_tiny_model

0 commit comments

Comments
 (0)