Skip to content

Commit 0dfc64d

Browse files
committed
test
1 parent 2ce6344 commit 0dfc64d

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

tests/integrations/test_cloud.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
from typing import Optional
66

77
import pytest
8+
import torch
89
from lightning_sdk import Teamspace
910
from lightning_sdk.lightning_cloud.rest_client import GridRestClient
1011
from lightning_sdk.utils.resolve import _resolve_teamspace
1112
from litmodels import download_model, upload_model
1213
from litmodels.integrations.mixins import PickleRegistryMixin
14+
from litmodels.integrations.mixins import PyTorchRegistryMixin
1315

1416
from tests.integrations import _SKIP_IF_LIGHTNING_BELLOW_2_5_1, _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1
1517

@@ -239,3 +241,36 @@ def test_pickle_mixin_push_and_pull():
239241

240242
# CLEANING
241243
_cleanup_model(teamspace, model_name, expected_num_versions=1)
244+
245+
246+
class DummyTorchModel(torch.nn.Module, PyTorchRegistryMixin):
247+
def __init__(self, input_size=784):
248+
super().__init__()
249+
self.fc = torch.nn.Linear(input_size, 10)
250+
251+
def forward(self, x):
252+
x = x.view(x.size(0), -1)
253+
return self.fc(x)
254+
255+
256+
@pytest.mark.cloud()
257+
def test_pytorch_mixin_push_and_pull():
258+
# model name with random hash
259+
teamspace, org_team, model_name = _prepare_variables("torch_mixin")
260+
model_registry = f"{org_team}/{model_name}"
261+
262+
# Create an instance, push the model and record its forward output.
263+
dummy = DummyTorchModel(784)
264+
dummy.eval()
265+
input_tensor = torch.randn(1, 784)
266+
output_before = dummy(input_tensor)
267+
268+
dummy.push_to_registry(model_registry)
269+
270+
loaded_dummy = DummyTorchModel.pull_from_registry(model_registry)
271+
loaded_dummy.eval()
272+
output_after = loaded_dummy(input_tensor)
273+
274+
assert isinstance(loaded_dummy, DummyTorchModel)
275+
# Compare the outputs as a verification.
276+
assert torch.allclose(output_before, output_after), "Loaded model output differs from original."

tests/integrations/test_mixins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def forward(self, x):
4343

4444
@mock.patch("litmodels.integrations.mixins.upload_model")
4545
@mock.patch("litmodels.integrations.mixins.download_model")
46-
def test_pytorch_pull_updated(mock_download_model, mock_upload_model, tmp_path):
46+
def test_pytorch_push_and_pull(mock_download_model, mock_upload_model, tmp_path):
4747
# Create an instance, push the model and record its forward output.
4848
dummy = DummyTorchModel(784)
4949
dummy.eval()

0 commit comments

Comments
 (0)