8
8
from logging import getLogger
9
9
from pathlib import Path
10
10
from typing import Dict , List , Optional , Set
11
+ import shutil
11
12
12
13
from anyio import Path as AsyncPath , create_task_group
13
14
import sqlalchemy
@@ -98,7 +99,7 @@ def __init__(
98
99
self ,
99
100
* ,
100
101
_sql_engine : sqlalchemy .engine .Engine ,
101
- _sources_by_id : Dict [str , ProtocolSource ],
102
+ _sources_by_id : dict [str , ProtocolSource | _BadProtocolSource ],
102
103
) -> None :
103
104
"""Do not call directly.
104
105
@@ -117,8 +118,7 @@ def create_empty(
117
118
Params:
118
119
sql_engine: A reference to the database that this ProtocolStore should
119
120
use as its backing storage.
120
- This is expected to already have the proper tables set up;
121
- see `add_tables_to_db()`.
121
+ This is expected to already have the proper tables set up.
122
122
This should have no protocol data currently stored.
123
123
If there is data, use `rehydrate()` instead.
124
124
"""
@@ -141,8 +141,7 @@ async def rehydrate(
141
141
Params:
142
142
sql_engine: A reference to the database that this ProtocolStore should
143
143
use as its backing storage.
144
- This is expected to already have the proper tables set up;
145
- see `add_tables_to_db()`.
144
+ This is expected to already have the proper tables set up.
146
145
protocols_directory: Where to look for protocol files while rehydrating.
147
146
This is expected to have one subdirectory per protocol,
148
147
named after its protocol ID.
@@ -157,7 +156,7 @@ async def rehydrate(
157
156
158
157
sources_by_id = await _compute_protocol_sources (
159
158
expected_protocol_ids = expected_ids ,
160
- protocols_directory = AsyncPath ( protocols_directory ) ,
159
+ protocols_directory = protocols_directory ,
161
160
protocol_reader = protocol_reader ,
162
161
)
163
162
@@ -171,16 +170,18 @@ def insert(self, resource: ProtocolResource) -> None:
171
170
172
171
The resource must have a unique ID.
173
172
"""
174
- self ._sql_insert (
175
- resource = _DBProtocolResource (
176
- protocol_id = resource .protocol_id ,
177
- created_at = resource .created_at ,
178
- protocol_key = resource .protocol_key ,
179
- protocol_kind = _http_protocol_kind_to_sql (resource .protocol_kind ),
173
+ try :
174
+ self ._sql_insert (
175
+ resource = _DBProtocolResource (
176
+ protocol_id = resource .protocol_id ,
177
+ created_at = resource .created_at ,
178
+ protocol_key = resource .protocol_key ,
179
+ protocol_kind = _http_protocol_kind_to_sql (resource .protocol_kind ),
180
+ )
180
181
)
181
- )
182
- self . _sources_by_id [ resource . protocol_id ] = resource . source
183
- self ._clear_caches ()
182
+ self . _sources_by_id [ resource . protocol_id ] = resource . source
183
+ finally :
184
+ self ._clear_caches ()
184
185
185
186
@lru_cache (maxsize = _CACHE_ENTRIES )
186
187
def get (self , protocol_id : str ) -> ProtocolResource :
@@ -190,30 +191,48 @@ def get(self, protocol_id: str) -> ProtocolResource:
190
191
ProtocolNotFoundError
191
192
"""
192
193
sql_resource = self ._sql_get (protocol_id = protocol_id )
193
- return ProtocolResource (
194
- protocol_id = sql_resource .protocol_id ,
195
- created_at = sql_resource .created_at ,
196
- protocol_key = sql_resource .protocol_key ,
197
- protocol_kind = _sql_protocol_kind_to_http (sql_resource .protocol_kind ),
198
- source = self ._sources_by_id [sql_resource .protocol_id ],
199
- )
194
+ protocol_source = self ._sources_by_id [sql_resource .protocol_id ]
195
+ match protocol_source :
196
+ case ProtocolSource () as protocol_source :
197
+ return ProtocolResource (
198
+ protocol_id = sql_resource .protocol_id ,
199
+ created_at = sql_resource .created_at ,
200
+ protocol_key = sql_resource .protocol_key ,
201
+ protocol_kind = _sql_protocol_kind_to_http (
202
+ sql_resource .protocol_kind
203
+ ),
204
+ source = protocol_source ,
205
+ )
206
+ case _BadProtocolSource (reason = reason ):
207
+ raise reason
200
208
201
209
@lru_cache (maxsize = _CACHE_ENTRIES )
202
210
def get_all (self ) -> List [ProtocolResource ]:
203
211
"""Get all protocols currently saved in this store.
204
212
205
213
Results are ordered from first-added to last-added.
214
+
215
+ If there was an error processing a protocol, it's excluded from the returned
216
+ list. This can happen, for example, if a software downgrade left the robot with
217
+ protocol files that are too new for the software that it's running now.
206
218
"""
207
219
all_sql_resources = self ._sql_get_all ()
220
+ all_sql_resources_and_protocol_sources = (
221
+ (r , self ._sources_by_id [r .protocol_id ]) for r in all_sql_resources
222
+ )
208
223
return [
209
224
ProtocolResource (
210
- protocol_id = r .protocol_id ,
211
- created_at = r .created_at ,
212
- protocol_key = r .protocol_key ,
213
- protocol_kind = _sql_protocol_kind_to_http (r .protocol_kind ),
214
- source = self . _sources_by_id [ r . protocol_id ] ,
225
+ protocol_id = sql_resource .protocol_id ,
226
+ created_at = sql_resource .created_at ,
227
+ protocol_key = sql_resource .protocol_key ,
228
+ protocol_kind = _sql_protocol_kind_to_http (sql_resource .protocol_kind ),
229
+ source = protocol_source ,
215
230
)
216
- for r in all_sql_resources
231
+ for (
232
+ sql_resource ,
233
+ protocol_source ,
234
+ ) in all_sql_resources_and_protocol_sources
235
+ if not isinstance (protocol_source , _BadProtocolSource )
217
236
]
218
237
219
238
@lru_cache (maxsize = _CACHE_ENTRIES )
@@ -258,17 +277,20 @@ def remove(self, protocol_id: str) -> None:
258
277
ProtocolUsedByRunError: the protocol could not be deleted because
259
278
there is a run currently referencing the protocol.
260
279
"""
261
- self ._sql_remove (protocol_id = protocol_id )
262
-
263
- deleted_source = self ._sources_by_id .pop (protocol_id )
264
- protocol_dir = deleted_source .directory
265
-
266
- for source_file in deleted_source .files :
267
- source_file .path .unlink ()
268
- if protocol_dir :
269
- protocol_dir .rmdir ()
270
-
271
- self ._clear_caches ()
280
+ try :
281
+ self ._sql_remove (protocol_id = protocol_id )
282
+
283
+ deleted_source = self ._sources_by_id .pop (protocol_id )
284
+ match deleted_source :
285
+ case ProtocolSource (directory = directory , files = files ):
286
+ for source_file in files :
287
+ source_file .path .unlink ()
288
+ if directory :
289
+ directory .rmdir ()
290
+ case _BadProtocolSource (directory = directory ):
291
+ shutil .rmtree (directory , ignore_errors = True )
292
+ finally :
293
+ self ._clear_caches ()
272
294
273
295
# Note that this is NOT cached like the other getters because we would need
274
296
# to invalidate the cache whenever the runs table changes, which is not something
@@ -448,18 +470,11 @@ def _clear_caches(self) -> None:
448
470
self .has .cache_clear ()
449
471
450
472
451
- # TODO(mm, 2022-04-18):
452
- # Restructure to degrade gracefully in the face of ProtocolReader failures.
453
- #
454
- # * ProtocolStore.get_all() should omit protocols for which it failed to compute
455
- # a ProtocolSource.
456
- # * ProtocolStore.get(id) should continue to raise an exception if it failed to compute
457
- # that protocol's ProtocolSource.
458
473
async def _compute_protocol_sources (
459
474
expected_protocol_ids : Set [str ],
460
- protocols_directory : AsyncPath ,
475
+ protocols_directory : Path ,
461
476
protocol_reader : ProtocolReader ,
462
- ) -> Dict [str , ProtocolSource ]:
477
+ ) -> dict [str , ProtocolSource | _BadProtocolSource ]:
463
478
"""Compute `ProtocolSource` objects from protocol source files.
464
479
465
480
We don't store these `ProtocolSource` objects in the SQL database because
@@ -475,19 +490,19 @@ async def _compute_protocol_sources(
475
490
protocol_reader: An interface to use to compute `ProtocolSource`s.
476
491
477
492
Returns:
478
- A map from protocol ID to computed `ProtocolSource`.
493
+ A map from protocol ID to computed `ProtocolSource`, or a `_BadProtocolSource` if
494
+ there was a problem processing that particular protocol.
479
495
480
496
Raises:
481
497
Exception: This is not expected to raise anything,
482
498
but it might if a software update makes ProtocolReader reject files
483
499
that it formerly accepted.
484
500
"""
485
- sources_by_id : Dict [str , ProtocolSource ] = {}
501
+ sources_by_id : dict [str , ProtocolSource | _BadProtocolSource ] = {}
486
502
487
- directory_members = [m async for m in protocols_directory .iterdir ()]
503
+ directory_members = [m async for m in AsyncPath ( protocols_directory ) .iterdir ()]
488
504
directory_member_names = set (m .name for m in directory_members )
489
505
extra_members = directory_member_names - expected_protocol_ids
490
- missing_members = expected_protocol_ids - directory_member_names
491
506
492
507
if extra_members :
493
508
# Extra members may be left over from prior interrupted writes
@@ -498,38 +513,48 @@ async def _compute_protocol_sources(
498
513
f" Ignoring them."
499
514
)
500
515
501
- if missing_members :
502
- raise SubdirectoryMissingError (
503
- f"Missing subdirectories for protocols: { missing_members } "
504
- )
505
-
506
516
async def compute_source (
507
- protocol_id : str , protocol_subdirectory : AsyncPath
517
+ protocol_subdirectory : Path ,
518
+ ) -> ProtocolSource | _BadProtocolSource :
519
+ try :
520
+ # Given that the expected protocol subdirectory exists,
521
+ # we trust that the files in it are correct.
522
+ # No extra files, and no files missing.
523
+ #
524
+ # This is a safe assumption as long as:
525
+ # * Nobody has tampered with the file storage.
526
+ # * We don't try to compute the source of any protocol whose insertion
527
+ # failed halfway through and left files behind.
528
+ protocol_files = [
529
+ Path (f ) async for f in AsyncPath (protocol_subdirectory ).iterdir ()
530
+ ]
531
+ protocol_source = await protocol_reader .read_saved (
532
+ files = protocol_files ,
533
+ directory = Path (protocol_subdirectory ),
534
+ files_are_prevalidated = True ,
535
+ python_parse_mode = PythonParseMode .ALLOW_LEGACY_METADATA_AND_REQUIREMENTS ,
536
+ )
537
+ return protocol_source
538
+ except Exception as exception :
539
+ # e.g. if a software downgrade left the robot with some protocol files that
540
+ # are too new for the software version that it's running now.
541
+ _log .exception (f"Error reading protocol in { protocol_subdirectory } ." )
542
+ return _BadProtocolSource (directory = protocol_subdirectory , reason = exception )
543
+
544
+ async def compute_source_and_store_in_result_dict (
545
+ protocol_id : str , protocol_subdirectory : Path
508
546
) -> None :
509
- # Given that the expected protocol subdirectory exists,
510
- # we trust that the files in it are correct.
511
- # No extra files, and no files missing.
512
- #
513
- # This is a safe assumption as long as:
514
- # * Nobody has tampered with file the storage.
515
- # * We don't try to compute the source of any protocol whose insertion
516
- # failed halfway through and left files behind.
517
- protocol_files = [Path (f ) async for f in protocol_subdirectory .iterdir ()]
518
- protocol_source = await protocol_reader .read_saved (
519
- files = protocol_files ,
520
- directory = Path (protocol_subdirectory ),
521
- files_are_prevalidated = True ,
522
- python_parse_mode = PythonParseMode .ALLOW_LEGACY_METADATA_AND_REQUIREMENTS ,
523
- )
524
- sources_by_id [protocol_id ] = protocol_source
547
+ result = await compute_source (protocol_subdirectory )
548
+ sources_by_id [protocol_id ] = result
525
549
526
550
async with create_task_group () as task_group :
527
- # Use a TaskGroup instead of asyncio.gather() so,
528
- # if any task raises an unexpected exception,
529
- # it cancels every other task and raises an exception to signal the bug.
530
551
for protocol_id in expected_protocol_ids :
531
552
protocol_subdirectory = protocols_directory / protocol_id
532
- task_group .start_soon (compute_source , protocol_id , protocol_subdirectory )
553
+ task_group .start_soon (
554
+ compute_source_and_store_in_result_dict ,
555
+ protocol_id ,
556
+ protocol_subdirectory ,
557
+ )
533
558
534
559
for id in expected_protocol_ids :
535
560
assert id in sources_by_id
@@ -547,6 +572,14 @@ class _DBProtocolResource:
547
572
protocol_kind : ProtocolKindSQLEnum
548
573
549
574
575
+ @dataclass (frozen = True )
576
+ class _BadProtocolSource :
577
+ """Information about files that we failed to process into a ProtocolSource."""
578
+
579
+ directory : Path
580
+ reason : Exception
581
+
582
+
550
583
def _convert_sql_row_to_dataclass (
551
584
sql_row : sqlalchemy .engine .Row ,
552
585
) -> _DBProtocolResource :
0 commit comments