1- import pathlib
2- import contextlib
31import json
4- from struct import Struct
5- import struct
2+ import operator
63
74from math import floor
85from datetime import datetime
9- from typing_extensions import Generator , Optional
6+ from typing_extensions import Optional
107
118import tiledb
12- from tiledb .metadata import Metadata
139import numpy as np
1410import pandas as pd
1511
@@ -25,7 +21,6 @@ def ts_overlap(first, second):
2521 return False
2622 return True
2723
28-
2924class Storage :
3025 """Handles storage of shattered data in a TileDB Database."""
3126
@@ -35,14 +30,16 @@ def __init__(self, config: StorageConfig):
3530 f"Given database directory '{ config .tdb_dir } ' does not exist"
3631 )
3732
38- self .config = config
39- self .reader : tiledb .SparseArray = None
40- self .writer : tiledb .SparseArray = None
33+ self .config : StorageConfig = config
34+ self ._reader : tiledb .SparseArray = None
4135
4236 def __enter__ (self ):
4337 return self
4438
4539 def __exit__ (self , exc_type , exc_value , exc_tb ):
40+ if self ._reader is not None :
41+ self ._reader .close ()
42+ self ._reader = None
4643 return
4744
4845 @staticmethod
@@ -87,25 +84,25 @@ def create(config: StorageConfig, ctx: tiledb.Ctx = None):
8784 name = 'X' ,
8885 domain = (0 , xi ),
8986 dtype = np .int32 ,
90- filters = tiledb .FilterList ([ tiledb .ZstdFilter () ]),
87+ filters = tiledb .FilterList ([tiledb .ZstdFilter ()]),
9188 )
9289 dim_col = tiledb .Dim (
9390 name = 'Y' ,
9491 domain = (0 , yi ),
9592 dtype = np .int32 ,
96- filters = tiledb .FilterList ([ tiledb .ZstdFilter () ]),
93+ filters = tiledb .FilterList ([tiledb .ZstdFilter ()]),
9794 )
9895 domain = tiledb .Domain (dim_row , dim_col )
9996
10097 count_att = tiledb .Attr (
10198 name = 'count' ,
10299 dtype = np .int32 ,
103- filters = tiledb .FilterList ([ tiledb .ZstdFilter () ]),
100+ filters = tiledb .FilterList ([tiledb .ZstdFilter ()]),
104101 )
105102 proc_att = tiledb .Attr (
106103 name = 'shatter_process_num' ,
107104 dtype = np .uint64 ,
108- filters = tiledb .FilterList ([ tiledb .ZstdFilter () ]),
105+ filters = tiledb .FilterList ([tiledb .ZstdFilter ()]),
109106 )
110107 dim_atts = [attr .schema () for attr in config .attrs ]
111108
@@ -152,11 +149,10 @@ def create(config: StorageConfig, ctx: tiledb.Ctx = None):
152149 schema .check ()
153150
154151 tiledb .SparseArray .create (config .tdb_dir , schema )
155- writer = tiledb .SparseArray (config .tdb_dir , 'w' )
156- writer .meta ['config' ] = str (config )
152+ with tiledb .SparseArray (config .tdb_dir , 'w' ) as writer :
153+ writer .meta ['config' ] = str (config )
157154
158155 s = Storage (config )
159- s .writer = writer
160156 s .save_config ()
161157
162158 return s
@@ -182,8 +178,7 @@ def from_db(tdb_dir: str, ctx: tiledb.Ctx = None):
182178 storage = Storage (config )
183179
184180 # set the metadata for storage object so we don't have to query again
185- storage ._meta = metadata
186- storage .reader = reader
181+ storage ._reader = reader
187182 storage .save_config ()
188183
189184 return storage
@@ -196,8 +191,8 @@ def save_config(self) -> None:
196191 # key later
197192 with self .open ('w' ) as w :
198193 w .meta ['config' ] = str (self .config )
199- with self .open ( 'r' ) as r :
200- self ._meta = r . meta
194+ if self ._reader is not None :
195+ self ._reader . reopen ()
201196
202197 def get_config (self ) -> StorageConfig :
203198 """
@@ -236,20 +231,13 @@ def get_metadata(self, key: str) -> str:
236231 :return: Metadata value found in storage.
237232 """
238233 # if meta hasn't been set up, do so
239- if self ._meta is None :
240- with self .open ('r' ) as tdb :
241- self ._meta = tdb .meta
242- # this should be latest metadata, handle a keyerror higher up
243- return self ._meta [key ]
244-
245- # if it was already set up and the key doesn't exist,
246- # reload metadata and check there
247- if key not in self ._meta .keys ():
248- with self .open ('r' ) as tdb :
249- self ._meta = tdb .meta
250-
251- return self ._meta ['key' ]
252-
234+ reader = self .open ('r' )
235+ try :
236+ return reader .meta [key ]
237+ except KeyError :
238+ reader .reopen ()
239+ self ._reader = reader
240+ return reader .meta [key ]
253241
254242 def save_metadata (self , key : str , data : str ) -> None :
255243 """
@@ -259,12 +247,11 @@ def save_metadata(self, key: str, data: str) -> None:
259247 :param data: Data to save to metadata.
260248 """
261249 # if writer isn't set up, do it now
262- if self ._writer is None :
263- self ._writer = self .open ('w' )
264-
265250 # propagate the key-value to both tiledb and the local copy
266- self ._meta [key ] = data
267- self ._writer .meta [key ] = data
251+ with self .open ('w' ) as w :
252+ w .meta [key ] = data
253+ if self ._reader is not None :
254+ self ._reader .reopen ()
268255
269256 def get_tdb_context (self ):
270257 cfg = tiledb .Config ()
@@ -299,9 +286,7 @@ def get_derived_names(self) -> list[str]:
299286 if not m .attributes or a .name in [ma .name for ma in m .attributes ]
300287 ]
301288
302- def open (
303- self , mode : str = 'r' , timestamp = None
304- ) -> Generator [tiledb .SparseArray , None , None ]:
289+ def open (self , mode : str = 'r' , timestamp = None ) -> tiledb .SparseArray :
305290 """
306291 Open stream for TileDB database in given mode and at given timestamp.
307292
@@ -318,22 +303,18 @@ def open(
318303 # other threads present
319304 ctx = self .get_tdb_context ()
320305
321- if tiledb .object_type (self .config .tdb_dir ) == 'array' :
322- if mode in ['w' , 'r' , 'd' , 'm' ]:
323- tdb = tiledb .open (
324- self .config .tdb_dir , mode , timestamp = timestamp , ctx = ctx
325- )
326- else :
327- raise Exception (f"Given open mode '{ mode } ' is not valid" )
328- elif pathlib .Path (self .config .tdb_dir ).exists ():
329- raise Exception (
330- f'Path { self .config .tdb_dir } already exists and is not'
331- ' initialized for TileDB access.'
306+ # the non-timestamped reader is stored as a member variable to
307+ # avoid opening and closing too many io objects.
308+ if timestamp is not None or mode != 'r' :
309+ return tiledb .open (
310+ self .config .tdb_dir , mode , timestamp = timestamp , ctx = ctx
332311 )
333- else :
334- raise Exception (f'Path { self .config .tdb_dir } does not exist' )
312+ else : # no timestamp and mode is 'r'
313+ if self ._reader is None or not self ._reader .isopen :
314+ self ._reader = tiledb .open (self .config .tdb_dir , 'r' )
335315
336- return tdb
316+ self ._reader .reopen ()
317+ return self ._reader
337318
338319 def write (self , data_in : pd .DataFrame , timestamp ):
339320 data_in = data_in .rename (columns = {'xi' : 'X' , 'yi' : 'Y' })
@@ -375,17 +356,17 @@ def reserve_time_slot(self) -> int:
375356 shatter process.
376357
377358 :param config: Shatter config will be written as metadata to reserve
378- time slot.
359+ time slot.
379360
380361 :return: Time slot.
381362 """
382- latest = self .get_metadata ('config' )
383-
384- time = latest ['next_time_slot' ]
385- latest ['next_time_slot' ] = time + 1
386- self .save_metadata ('config' , json .dumps (latest ))
363+ # make sure we're dealing with the latest config
364+ cfg = self .get_config ()
365+ self .config = cfg
366+ time = self .config .next_time_slot
367+ self .config .next_time_slot = time + 1
368+ self .save_config ()
387369
388- self .config .next_time_slot = latest ['next_time_slot' ]
389370 return time
390371
391372 def get_history (
@@ -472,14 +453,16 @@ def delete(self, time_slot: int) -> ShatterConfig:
472453 """
473454
474455 self .config .log .debug (f'Deleting time slot { time_slot } ...' )
475- with self .open ('r' ) as r :
476- sh_cfg = ShatterConfig .from_string (r .meta [f'shatter_{ time_slot } ' ])
477- sh_cfg .mbr = ()
478- sh_cfg .finished = False
456+
457+ r = self .open ('r' )
458+ sh_cfg = ShatterConfig .from_string (r .meta [f'shatter_{ time_slot } ' ])
459+ sh_cfg .mbr = ()
460+ sh_cfg .finished = False
479461
480462 self .config .log .debug ('Deleting fragments...' )
481- with self .open ('d' ) as d :
482- d .query (cond = f'shatter_process_num=={ time_slot } ' ).submit ()
463+ d = self .open ('d' )
464+ d .query (cond = f'shatter_process_num=={ time_slot } ' ).submit ()
465+ d .close ()
483466
484467 self .config .log .debug ('Rewriting config.' )
485468 with self .open ('w' ) as w :
0 commit comments