|
8 | 8 | from unittest.mock import Mock |
9 | 9 |
|
10 | 10 | import pytest |
| 11 | +from huggingface_hub import CommitOperationAdd, CommitOperationDelete |
11 | 12 |
|
12 | 13 | from kernels.cli import upload_kernels |
| 14 | +from kernels.cli.upload import BUILD_COMMIT_BATCH_SIZE |
13 | 15 | from kernels.utils import _get_hf_api |
14 | 16 |
|
15 | 17 | REPO_ID = "valid_org/kernels-upload-test" |
@@ -172,25 +174,62 @@ def test_kernel_upload_deletes_as_expected(): |
172 | 174 | _get_hf_api().delete_repo(repo_id=REPO_ID) |
173 | 175 |
|
174 | 176 |
|
175 | | -def test_large_kernel_upload_uses_kernel_root_path(monkeypatch, tmp_path): |
| 177 | +def test_large_kernel_upload_uses_create_commit_batches(monkeypatch, tmp_path): |
176 | 178 | kernel_root = tmp_path / "kernel" |
177 | 179 | build_variant = kernel_root / "build" / "torch-cpu" |
178 | 180 | build_variant.mkdir(parents=True, exist_ok=True) |
179 | 181 | (build_variant / "metadata.json").write_text("{}") |
180 | | - for i in range(1001): |
| 182 | + file_count = BUILD_COMMIT_BATCH_SIZE * 2 |
| 183 | + for i in range(file_count): |
181 | 184 | (build_variant / f"file_{i}.py").touch() |
182 | 185 |
|
183 | 186 | api = Mock() |
184 | 187 | api.create_repo.return_value = SimpleNamespace(repo_id=REPO_ID) |
| 188 | + api.list_repo_files.return_value = [ |
| 189 | + "README.md", |
| 190 | + "build/torch-cpu/file_0.py", |
| 191 | + "build/torch-cpu/stale.py", |
| 192 | + "build/torch-cuda/keep.py", |
| 193 | + ] |
185 | 194 | monkeypatch.setattr("kernels.cli.upload._get_hf_api", lambda: api) |
186 | 195 |
|
187 | 196 | upload_kernels(UploadArgs(kernel_root, REPO_ID, False, "main")) |
188 | 197 |
|
189 | | - api.upload_large_folder.assert_called_once() |
190 | | - kwargs = api.upload_large_folder.call_args.kwargs |
191 | | - assert kwargs["repo_id"] == REPO_ID |
192 | | - assert kwargs["folder_path"] == kernel_root.resolve() |
193 | | - assert kwargs["revision"] == "main" |
194 | | - assert kwargs["repo_type"] == "model" |
195 | | - assert kwargs["allow_patterns"] == ["build/torch*"] |
| 198 | + # 2 full batches of adds, plus metadata and 1 stale-file delete. |
| 199 | + assert api.create_commit.call_count == 3 |
| 200 | + batch_sizes = [ |
| 201 | + len(call.kwargs["operations"]) for call in api.create_commit.call_args_list |
| 202 | + ] |
| 203 | + assert batch_sizes == [ |
| 204 | + BUILD_COMMIT_BATCH_SIZE, |
| 205 | + BUILD_COMMIT_BATCH_SIZE, |
| 206 | + 2, |
| 207 | + ] |
| 208 | + commit_messages = [ |
| 209 | + call.kwargs["commit_message"] for call in api.create_commit.call_args_list |
| 210 | + ] |
| 211 | + assert commit_messages == [ |
| 212 | + "Build uploaded using `kernels` (batch 1/3).", |
| 213 | + "Build uploaded using `kernels` (batch 2/3).", |
| 214 | + "Build uploaded using `kernels` (batch 3/3).", |
| 215 | + ] |
| 216 | + |
| 217 | + # Stale repo files should be deleted. |
| 218 | + operations = [ |
| 219 | + operation |
| 220 | + for call in api.create_commit.call_args_list |
| 221 | + for operation in call.kwargs["operations"] |
| 222 | + ] |
| 223 | + delete_paths = { |
| 224 | + op.path_in_repo for op in operations if isinstance(op, CommitOperationDelete) |
| 225 | + } |
| 226 | + assert delete_paths == {"build/torch-cpu/stale.py"} |
| 227 | + |
| 228 | + add_paths = { |
| 229 | + op.path_in_repo for op in operations if isinstance(op, CommitOperationAdd) |
| 230 | + } |
| 231 | + assert len(add_paths) == file_count + 1 |
| 232 | + assert "build/torch-cpu/metadata.json" in add_paths |
| 233 | + assert "build/torch-cpu/file_0.py" in add_paths |
| 234 | + assert "build/torch-cpu/file_399.py" in add_paths |
196 | 235 | api.upload_folder.assert_not_called() |
0 commit comments