@@ -28,7 +28,7 @@ class CosmosDBPipelineStorage(PipelineStorage):
2828 _cosmosdb_account_url : str | None
2929 _connection_string : str | None
3030 _database_name : str
31- _current_container : str | None
31+ container_name : str | None
3232 _encoding : str
3333
3434 def __init__ (
@@ -58,7 +58,7 @@ def __init__(
5858 self ._database_name = database_name
5959 self ._connection_string = connection_string
6060 self ._cosmosdb_account_url = cosmosdb_account_url
61- self ._current_container = current_container
61+ self .container_name = current_container
6262 self ._cosmosdb_account_name = (
6363 cosmosdb_account_url .split ("//" )[1 ].split ("." )[0 ]
6464 if cosmosdb_account_url
@@ -74,7 +74,7 @@ def __init__(
7474 self ._database_name ,
7575 )
7676 self ._create_database ()
77- if self ._current_container :
77+ if self .container_name :
7878 self ._create_container ()
7979
8080 def _create_database (self ) -> None :
@@ -117,7 +117,7 @@ def find(
117117
118118 log .info (
119119 "search container %s for documents matching %s" ,
120- self ._current_container ,
120+ self .container_name ,
121121 file_pattern .pattern ,
122122 )
123123
@@ -130,7 +130,7 @@ def item_filter(item: dict[str, Any]) -> bool:
130130
131131 try :
132132 container_client = self ._database_client .get_container_client (
133- str (self ._current_container )
133+ str (self .container_name )
134134 )
135135 query = "SELECT * FROM c WHERE CONTAINS(c.id, @pattern)"
136136 parameters : list [dict [str , Any ]] = [
@@ -173,9 +173,9 @@ async def get(
173173 ) -> Any :
174174 """Get a file in the database for the given key."""
175175 try :
176- if self ._current_container :
176+ if self .container_name :
177177 container_client = self ._database_client .get_container_client (
178- self ._current_container
178+ self .container_name
179179 )
180180 item = container_client .read_item (item = key , partition_key = key )
181181 item_body = item .get ("body" )
@@ -195,9 +195,9 @@ async def get(
195195 async def set (self , key : str , value : Any , encoding : str | None = None ) -> None :
196196 """Set a file in the database for the given key."""
197197 try :
198- if self ._current_container :
198+ if self .container_name :
199199 container_client = self ._database_client .get_container_client (
200- self ._current_container
200+ self .container_name
201201 )
202202 if isinstance (value , bytes ):
203203 value_df = pd .read_parquet (BytesIO (value ))
@@ -217,29 +217,29 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
217217
async def has(self, key: str) -> bool:
    """Check whether an item whose id equals ``key`` exists in the current container.

    Args:
        key: The item id to look for.

    Returns:
        True if the container holds an item with that id; False if it does
        not, or if no container name is configured (in which case the
        service is not contacted at all).
    """
    if not self.container_name:
        return False
    container_client = self._database_client.get_container_client(
        self.container_name
    )
    # Ask the service for the matching id only, instead of streaming every
    # item in the container back to the client just to test membership.
    matches = container_client.query_items(
        query="SELECT c.id FROM c WHERE c.id = @id",
        parameters=[{"name": "@id", "value": key}],
        enable_cross_partition_query=True,
    )
    # Re-check the id locally so the result is an exact string comparison,
    # as it was with the previous list-membership test.
    return any(item["id"] == key for item in matches)
227227
async def delete(self, key: str) -> None:
    """Remove the item whose id is ``key`` from the current container.

    Does nothing when no container name is configured.
    """
    if not self.container_name:
        return
    # The item id is also passed as the partition key value.
    self._database_client.get_container_client(self.container_name).delete_item(
        item=key, partition_key=key
    )
235235
236236 # Function currently deletes all items within the current container, then deletes the container itself.
237237 # TODO: Decide the granularity of deletion (e.g. delete all items within the current container, delete the current container, delete the current database)
238238 async def clear (self ) -> None :
239239 """Clear the cosmosdb storage."""
240- if self ._current_container :
240+ if self .container_name :
241241 container_client = self ._database_client .get_container_client (
242- self ._current_container
242+ self .container_name
243243 )
244244 for item in container_client .read_all_items ():
245245 item_id = item ["id" ]
@@ -257,24 +257,24 @@ def child(self, name: str | None) -> PipelineStorage:
257257
def _create_container(self) -> None:
    """Ensure a container named ``self.container_name`` exists in the database.

    Nothing is created when no container name is configured.
    """
    if not self.container_name:
        return
    # Containers are hash-partitioned on the item id.
    self._database_client.create_container_if_not_exists(
        id=self.container_name,
        partition_key=PartitionKey(path="/id", kind="Hash"),
    )
266266
def _delete_container(self) -> None:
    """Delete the container named by ``self.container_name`` if it exists.

    Does nothing when no container name is set, or when the database has no
    container with that name.
    """
    # Test the locally known name first so that an unset name never triggers
    # the container listing performed by _container_exists().
    if self.container_name and self._container_exists():
        self._database_client.delete_container(self.container_name)
271271
def _container_exists(self) -> bool:
    """Return True if the database lists a container whose id equals ``self.container_name``."""
    # any() stops at the first match rather than collecting every container
    # id into a list before testing membership.
    return any(
        container["id"] == self.container_name
        for container in self._database_client.list_containers()
    )
278278
279279
280280# TODO remove this helper function and have the factory instantiate the class directly
0 commit comments