99from lightning_sdk import Teamspace
1010from lightning_sdk .lightning_cloud .rest_client import GridRestClient
1111from lightning_sdk .utils .resolve import _resolve_teamspace
12+
1213from litmodels import download_model , upload_model
1314from litmodels .integrations .duplicate import duplicate_hf_model
1415from litmodels .integrations .mixins import PickleRegistryMixin , PyTorchRegistryMixin
1516from litmodels .io .cloud import _list_available_teamspaces
16-
1717from tests .integrations import (
1818 _SKIP_IF_LIGHTNING_BELLOW_2_5_1 ,
1919 _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1 ,
@@ -44,7 +44,7 @@ def _cleanup_model(teamspace: Teamspace, model_name: str, expected_num_versions:
4444 client .models_store_delete_model (project_id = model .project_id , model_id = model .id )
4545
4646
47- @pytest .mark .cloud ()
47+ @pytest .mark .cloud
4848@pytest .mark .parametrize (
4949 "in_studio" ,
5050 [False , pytest .param (True , marks = pytest .mark .skipif (platform .system () != "Linux" , reason = "Studio is just Linux" ))],
@@ -100,7 +100,7 @@ def test_upload_download_model(in_studio, monkeypatch, tmp_path):
100100 pytest .param (True , marks = pytest .mark .skipif (platform .system () == "Windows" , reason = "studio is not Windows" )),
101101 ],
102102)
103- @pytest .mark .cloud ()
103+ @pytest .mark .cloud
104104def test_lightning_default_checkpointing (importing , in_studio , monkeypatch , tmp_path ):
105105 if in_studio :
106106 # mock env variables as it would run in studio
@@ -140,7 +140,7 @@ def test_lightning_default_checkpointing(importing, in_studio, monkeypatch, tmp_
140140 pytest .param ("pytorch_lightning" , marks = _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1 ),
141141 ],
142142)
143- @pytest .mark .cloud ()
143+ @pytest .mark .cloud
144144# todo: mock env variables as it would run in studio
145145def test_lightning_plain_resume (trainer_method , registry , importing , tmp_path ):
146146 if importing == "lightning" :
@@ -194,7 +194,7 @@ def test_lightning_plain_resume(trainer_method, registry, importing, tmp_path):
194194 pytest .param ("pytorch_lightning" , marks = _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1 ),
195195 ],
196196)
197- @pytest .mark .cloud ()
197+ @pytest .mark .cloud
198198def test_lightning_checkpoint_ddp (importing , tmp_path ):
199199 if importing == "lightning" :
200200 from lightning import Trainer
@@ -229,7 +229,7 @@ def __init__(self, value):
229229 self .value = value
230230
231231
232- @pytest .mark .cloud ()
232+ @pytest .mark .cloud
233233def test_pickle_mixin_push_and_pull ():
234234 # model name with random hash
235235 teamspace , org_team , model_name = _prepare_variables ("pickle_mixin" )
@@ -263,7 +263,7 @@ def forward(self, x):
263263 return self .fc (x )
264264
265265
266- @pytest .mark .cloud ()
266+ @pytest .mark .cloud
267267def test_pytorch_mixin_push_and_pull ():
268268 # model name with random hash
269269 teamspace , org_team , model_name = _prepare_variables ("torch_mixin" )
@@ -289,7 +289,7 @@ def test_pytorch_mixin_push_and_pull():
289289 _cleanup_model (teamspace , model_name , expected_num_versions = 1 )
290290
291291
292- @pytest .mark .cloud ()
292+ @pytest .mark .cloud
293293def test_duplicate_real_hf_model (tmp_path ):
294294 """Verify that the HF model can be duplicated to the teamspace"""
295295
@@ -309,7 +309,7 @@ def test_duplicate_real_hf_model(tmp_path):
309309 client .models_store_delete_model (project_id = teamspace .id , model_id = model .id )
310310
311311
312- @pytest .mark .cloud ()
312+ @pytest .mark .cloud
313313def test_list_available_teamspaces ():
314314 teams = _list_available_teamspaces ()
315315 assert len (teams ) > 0
0 commit comments