|
5 | 5 | from typing import Optional |
6 | 6 |
|
7 | 7 | import pytest |
| 8 | +import torch |
8 | 9 | from lightning_sdk import Teamspace |
9 | 10 | from lightning_sdk.lightning_cloud.rest_client import GridRestClient |
10 | 11 | from lightning_sdk.utils.resolve import _resolve_teamspace |
11 | 12 | from litmodels import download_model, upload_model |
12 | 13 | from litmodels.integrations.mixins import PickleRegistryMixin |
| 14 | +from litmodels.integrations.mixins import PyTorchRegistryMixin |
13 | 15 |
|
14 | 16 | from tests.integrations import _SKIP_IF_LIGHTNING_BELLOW_2_5_1, _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1 |
15 | 17 |
|
@@ -239,3 +241,36 @@ def test_pickle_mixin_push_and_pull(): |
239 | 241 |
|
240 | 242 | # CLEANING |
241 | 243 | _cleanup_model(teamspace, model_name, expected_num_versions=1) |
| 244 | + |
| 245 | + |
| 246 | +class DummyTorchModel(torch.nn.Module, PyTorchRegistryMixin): |
| 247 | + def __init__(self, input_size=784): |
| 248 | + super().__init__() |
| 249 | + self.fc = torch.nn.Linear(input_size, 10) |
| 250 | + |
| 251 | + def forward(self, x): |
| 252 | + x = x.view(x.size(0), -1) |
| 253 | + return self.fc(x) |
| 254 | + |
| 255 | + |
| 256 | +@pytest.mark.cloud() |
| 257 | +def test_pytorch_mixin_push_and_pull(): |
| 258 | + # model name with random hash |
| 259 | + teamspace, org_team, model_name = _prepare_variables("torch_mixin") |
| 260 | + model_registry = f"{org_team}/{model_name}" |
| 261 | + |
| 262 | + # Create an instance, push the model and record its forward output. |
| 263 | + dummy = DummyTorchModel(784) |
| 264 | + dummy.eval() |
| 265 | + input_tensor = torch.randn(1, 784) |
| 266 | + output_before = dummy(input_tensor) |
| 267 | + |
| 268 | + dummy.push_to_registry(model_registry) |
| 269 | + |
| 270 | + loaded_dummy = DummyTorchModel.pull_from_registry(model_registry) |
| 271 | + loaded_dummy.eval() |
| 272 | + output_after = loaded_dummy(input_tensor) |
| 273 | + |
| 274 | + assert isinstance(loaded_dummy, DummyTorchModel) |
| 275 | + # Compare the outputs as a verification. |
| 276 | + assert torch.allclose(output_before, output_after), "Loaded model output differs from original." |
0 commit comments