1111from lightning_app .utilities .component import _set_flow_context
1212
1313
14- class SyncWorkLITDriveA (LightningWork ):
14+ class SyncWorkA (LightningWork ):
1515 def __init__ (self , tmpdir ):
1616 super ().__init__ ()
1717 self .tmpdir = tmpdir
@@ -25,35 +25,35 @@ def run(self, drive: Drive):
2525 os .remove (f"{ self .tmpdir } /a.txt" )
2626
2727
28- class SyncWorkLITDriveB (LightningWork ):
28+ class SyncWorkB (LightningWork ):
2929 def run (self , drive : Drive ):
3030 assert not os .path .exists ("a.txt" )
3131 drive .get ("a.txt" )
3232 assert os .path .exists ("a.txt" )
3333
3434
35- class SyncFlowLITDrives (LightningFlow ):
35+ class SyncFlow (LightningFlow ):
3636 def __init__ (self , tmpdir ):
3737 super ().__init__ ()
3838 self .log_dir = Drive ("lit://log_dir" )
39- self .work_a = SyncWorkLITDriveA (str (tmpdir ))
40- self .work_b = SyncWorkLITDriveB ()
39+ self .work_a = SyncWorkA (str (tmpdir ))
40+ self .work_b = SyncWorkB ()
4141
4242 def run (self ):
4343 self .work_a .run (self .log_dir )
4444 self .work_b .run (self .log_dir )
4545 self ._exit ()
4646
4747
48- def test_synchronization_lit_drive (tmpdir ):
48+ def test_synchronization_drive (tmpdir ):
4949 if os .path .exists ("a.txt" ):
5050 os .remove ("a.txt" )
51- app = LightningApp (SyncFlowLITDrives (tmpdir ))
51+ app = LightningApp (SyncFlow (tmpdir ))
5252 MultiProcessRuntime (app , start_server = False ).dispatch ()
5353 os .remove ("a.txt" )
5454
5555
56- class LITDriveWork (LightningWork ):
56+ class Work (LightningWork ):
5757 def __init__ (self ):
5858 super ().__init__ (parallel = True )
5959 self .drive = None
@@ -75,7 +75,7 @@ def run(self, *args, **kwargs):
7575 self .counter += 1
7676
7777
78- class LITDriveWork2 (LightningWork ):
78+ class Work2 (LightningWork ):
7979 def __init__ (self ):
8080 super ().__init__ (parallel = True )
8181
@@ -86,11 +86,11 @@ def run(self, drive: Drive, **kwargs):
8686 assert drive .list ("." , component_name = self .name ) == []
8787
8888
89- class LITDriveFlow (LightningFlow ):
89+ class Flow (LightningFlow ):
9090 def __init__ (self ):
9191 super ().__init__ ()
92- self .work = LITDriveWork ()
93- self .work2 = LITDriveWork2 ()
92+ self .work = Work ()
93+ self .work2 = Work2 ()
9494
9595 def run (self ):
9696 self .work .run ("0" )
@@ -102,15 +102,15 @@ def run(self):
102102 self ._exit ()
103103
104104
105- def test_lit_drive_transferring_files ():
106- app = LightningApp (LITDriveFlow ())
105+ def test_drive_transferring_files ():
106+ app = LightningApp (Flow ())
107107 MultiProcessRuntime (app , start_server = False ).dispatch ()
108108 os .remove ("a.txt" )
109109
110110
111- def test_lit_drive ():
112- with pytest .raises (Exception , match = "Unknown protocol for the drive 'id' argument " ):
113- Drive ("invalid_drive_id " )
111+ def test_drive ():
112+ with pytest .raises (Exception , match = "The Drive id needs to start with one of the following protocols " ):
113+ Drive ("this_drive_id " )
114114
115115 with pytest .raises (
116116 Exception , match = "The id should be unique to identify your drive. Found `this_drive_id/something_else`."
@@ -213,56 +213,19 @@ def test_lit_drive():
213213 os .remove ("a.txt" )
214214
215215
216- def test_s3_drives ():
217- drive = Drive ("s3://foo/" , allow_duplicates = True )
218- drive .component_name = "root.work"
219-
220- with pytest .raises (
221- Exception , match = "S3 based drives cannot currently add files via this API. Did you mean to use `lit://` drives?"
222- ):
223- drive .put ("a.txt" )
224- with pytest .raises (
225- Exception ,
226- match = "S3 based drives cannot currently list files via this API. Did you mean to use `lit://` drives?" ,
227- ):
228- drive .list ("a.txt" )
229- with pytest .raises (
230- Exception , match = "S3 based drives cannot currently get files via this API. Did you mean to use `lit://` drives?"
231- ):
232- drive .get ("a.txt" )
233- with pytest .raises (
234- Exception ,
235- match = "S3 based drives cannot currently delete files via this API. Did you mean to use `lit://` drives?" ,
236- ):
237- drive .delete ("a.txt" )
216+ def test_maybe_create_drive ():
238217
239- _set_flow_context ()
240- with pytest .raises (Exception , match = "The flow isn't allowed to put files into a Drive." ):
241- drive .put ("a.txt" )
242- with pytest .raises (Exception , match = "The flow isn't allowed to list files from a Drive." ):
243- drive .list ("a.txt" )
244- with pytest .raises (Exception , match = "The flow isn't allowed to get files from a Drive." ):
245- drive .get ("a.txt" )
246-
247-
248- def test_create_s3_drive_without_trailing_slash_fails ():
249- with pytest .raises (ValueError , match = "S3 drives must end in a trailing slash" ):
250- Drive ("s3://foo" )
251-
252-
253- @pytest .mark .parametrize ("drive_id" , ["lit://drive" , "s3://drive/" ])
254- def test_maybe_create_drive (drive_id ):
255- drive = Drive (drive_id , allow_duplicates = False )
218+ drive = Drive ("lit://drive_3" , allow_duplicates = False )
256219 drive .component_name = "root.work1"
257220 new_drive = _maybe_create_drive (drive .component_name , drive .to_dict ())
258221 assert new_drive .protocol == drive .protocol
259222 assert new_drive .id == drive .id
260223 assert new_drive .component_name == drive .component_name
261224
262225
263- @ pytest . mark . parametrize ( "drive_id" , [ "lit://drive" , "s3://drive/" ])
264- def test_drive_deepcopy ( drive_id ):
265- drive = Drive (drive_id , allow_duplicates = True )
226+ def test_drive_deepcopy ():
227+
228+ drive = Drive ("lit://drive" , allow_duplicates = True )
266229 drive .component_name = "root.work1"
267230 new_drive = deepcopy (drive )
268231 assert new_drive .id == drive .id
0 commit comments