11"""Internal (private) Utilities Module."""
22
3+ import copy
34import logging
45import math
56import os
67import random
7- from typing import Any , Dict , Generator , List , Optional , Tuple
8+ from typing import Any , Dict , Generator , List , Optional , Tuple , Union
89
910import boto3 # type: ignore
1011import botocore .config # type: ignore
1718_logger : logging .Logger = logging .getLogger (__name__ )
1819
1920
20- def ensure_session (session : Optional [boto3 .Session ] = None ) -> boto3 .Session :
21+ def ensure_session (session : Optional [Union [ boto3 .Session , Dict [ str , Optional [ str ]]] ] = None ) -> boto3 .Session :
2122 """Ensure that a valid boto3.Session will be returned."""
23+ if isinstance (session , dict ): # Primitives received
24+ return boto3_from_primitives (primitives = session )
2225 if session is not None :
2326 return session
2427 # Ensure the boto3's default session is used so that its parameters can be
@@ -28,6 +31,30 @@ def ensure_session(session: Optional[boto3.Session] = None) -> boto3.Session:
2831 return boto3 .Session () # pragma: no cover
2932
3033
34+ def boto3_to_primitives (boto3_session : Optional [boto3 .Session ] = None ) -> Dict [str , Optional [str ]]:
35+ """Convert Boto3 Session to Python primitives."""
36+ _boto3_session : boto3 .Session = ensure_session (session = boto3_session )
37+ credentials = _boto3_session .get_credentials ()
38+ return {
39+ "aws_access_key_id" : getattr (credentials , "access_key" , None ),
40+ "aws_secret_access_key" : getattr (credentials , "secret_key" , None ),
41+ "aws_session_token" : getattr (credentials , "token" , None ),
42+ "region_name" : _boto3_session .region_name ,
43+ "profile_name" : _boto3_session .profile_name ,
44+ }
45+
46+
47+ def boto3_from_primitives (primitives : Dict [str , Optional [str ]] = None ) -> boto3 .Session :
48+ """Convert Python primitives to Boto3 Session."""
49+ if primitives is None :
50+ return boto3 .DEFAULT_SESSION # pragma: no cover
51+ _primitives : Dict [str , Optional [str ]] = copy .deepcopy (primitives )
52+ profile_name : Optional [str ] = _primitives .get ("profile_name" , None )
53+ _primitives ["profile_name" ] = None if profile_name in (None , "default" ) else profile_name
54+ args : Dict [str , str ] = {k : v for k , v in _primitives .items () if v is not None }
55+ return boto3 .Session (** args )
56+
57+
3158def client (service_name : str , session : Optional [boto3 .Session ] = None ) -> boto3 .client :
3259 """Create a valid boto3.client."""
3360 return ensure_session (session = session ).client (
@@ -63,6 +90,8 @@ def parse_path(path: str) -> Tuple[str, str]:
6390 >>> bucket, key = parse_path('s3://bucket/key')
6491
6592 """
93+ if path .startswith ("s3://" ) is False :
94+ raise exceptions .InvalidArgumentValue (f"'{ path } ' is not a valid path. It MUST start with 's3://'" )
6695 parts = path .replace ("s3://" , "" ).split ("/" , 1 )
6796 bucket : str = parts [0 ]
6897 key : str = ""
@@ -139,7 +168,8 @@ def chunkify(lst: List[Any], num_chunks: int = 1, max_length: Optional[int] = No
139168
140169
141170def get_fs (
142- session : Optional [boto3 .Session ] = None , s3_additional_kwargs : Optional [Dict [str , str ]] = None
171+ session : Optional [Union [boto3 .Session , Dict [str , Optional [str ]]]] = None ,
172+ s3_additional_kwargs : Optional [Dict [str , str ]] = None ,
143173) -> s3fs .S3FileSystem :
144174 """Build a S3FileSystem from a given boto3 session."""
145175 fs : s3fs .S3FileSystem = s3fs .S3FileSystem (
0 commit comments