99import abc
1010import os
1111import errno
12+ import logging
1213try :
1314 from pathlib import Path # Built-in in Python 3
1415except :
2122 ABC = abc .ABCMeta ("ABC" , (object ,), {"__slots__" : ()}) # type: ignore
2223
2324
25+ logger = logging .getLogger (__name__ )
26+
27+
2428def _mkdir_p (path ):
2529 """Creates a directory, and any necessary parents.
2630
@@ -41,6 +45,20 @@ def _mkdir_p(path):
4145 raise
4246
4347
48+ # We do not aim to wrap every os-specific exception.
49+ # Here we define only the most common one,
50+ # otherwise caller would need to catch os-specific persistence exceptions.
51+ class PersistenceNotFound (IOError ): # Use IOError rather than OSError as base,
52+ # because historically an IOError was bubbled up and expected.
53+ # https://github.com/AzureAD/microsoft-authentication-extensions-for-python/blob/0.2.2/msal_extensions/token_cache.py#L38
54+ # Now we want to maintain backward compatibility even when using Python 2.x
55+ # It makes no difference in Python 3.3+ where IOError is an alias of OSError.
56+ def __init__ (
57+ self ,
58+ err_no = errno .ENOENT , message = "Persistence not found" , location = None ):
59+ super (PersistenceNotFound , self ).__init__ (err_no , message , location )
60+
61+
4462class BasePersistence (ABC ):
4563 """An abstract persistence defining the common interface of this family"""
4664
@@ -55,12 +73,18 @@ def save(self, content):
5573 @abc .abstractmethod
5674 def load (self ):
5775 # type: () -> str
58- """Load content from this persistence"""
76+ """Load content from this persistence.
77+
78+ Could raise PersistenceNotFound if no save() was called before.
79+ """
5980 raise NotImplementedError
6081
6182 @abc .abstractmethod
6283 def time_last_modified (self ):
63- """Get the last time when this persistence has been modified"""
84+ """Get the last time when this persistence has been modified.
85+
86+ Could raise PersistenceNotFound if no save() was called before.
87+ """
6488 raise NotImplementedError
6589
6690 @abc .abstractmethod
@@ -87,11 +111,32 @@ def save(self, content):
87111 def load (self ):
88112 # type: () -> str
89113 """Load content from this persistence"""
90- with open (self ._location , 'r' ) as handle :
91- return handle .read ()
114+ try :
115+ with open (self ._location , 'r' ) as handle :
116+ return handle .read ()
117+ except EnvironmentError as exp : # EnvironmentError in Py 2.7 works across platform
118+ if exp .errno == errno .ENOENT :
119+ raise PersistenceNotFound (
120+ message = (
121+ "Persistence not initialized. "
122+ "You can recover by calling a save() first." ),
123+ location = self ._location ,
124+ )
125+ raise
126+
92127
93128 def time_last_modified (self ):
94- return os .path .getmtime (self ._location )
129+ try :
130+ return os .path .getmtime (self ._location )
131+ except EnvironmentError as exp : # EnvironmentError in Py 2.7 works across platform
132+ if exp .errno == errno .ENOENT :
133+ raise PersistenceNotFound (
134+ message = (
135+ "Persistence not initialized. "
136+ "You can recover by calling a save() first." ),
137+ location = self ._location ,
138+ )
139+ raise
95140
96141 def touch (self ):
97142 """To touch this file-based persistence without writing content into it"""
@@ -115,13 +160,28 @@ def __init__(self, location, entropy=''):
115160
116161 def save (self , content ):
117162 # type: (str) -> None
163+ data = self ._dp_agent .protect (content )
118164 with open (self ._location , 'wb+' ) as handle :
119- handle .write (self . _dp_agent . protect ( content ) )
165+ handle .write (data )
120166
121167 def load (self ):
122168 # type: () -> str
123- with open (self ._location , 'rb' ) as handle :
124- return self ._dp_agent .unprotect (handle .read ())
169+ try :
170+ with open (self ._location , 'rb' ) as handle :
171+ data = handle .read ()
172+ return self ._dp_agent .unprotect (data )
173+ except EnvironmentError as exp : # EnvironmentError in Py 2.7 works across platform
174+ if exp .errno == errno .ENOENT :
175+ raise PersistenceNotFound (
176+ message = (
177+ "Persistence not initialized. "
178+ "You can recover by calling a save() first." ),
179+ location = self ._location ,
180+ )
181+ logger .exception (
182+ "DPAPI error likely caused by file content not previously encrypted. "
183+ "App developer should migrate by calling save(plaintext) first." )
184+ raise
125185
126186
127187class KeychainPersistence (BasePersistence ):
@@ -136,9 +196,10 @@ def __init__(self, signal_location, service_name, account_name):
136196 """
137197 if not (service_name and account_name ): # It would hang on OSX
138198 raise ValueError ("service_name and account_name are required" )
139- from .osx import Keychain # pylint: disable=import-outside-toplevel
199+ from .osx import Keychain , KeychainError # pylint: disable=import-outside-toplevel
140200 self ._file_persistence = FilePersistence (signal_location ) # Favor composition
141201 self ._Keychain = Keychain # pylint: disable=invalid-name
202+ self ._KeychainError = KeychainError # pylint: disable=invalid-name
142203 self ._service_name = service_name
143204 self ._account_name = account_name
144205
@@ -150,8 +211,21 @@ def save(self, content):
150211
151212 def load (self ):
152213 with self ._Keychain () as locker :
153- return locker .get_generic_password (
154- self ._service_name , self ._account_name )
214+ try :
215+ return locker .get_generic_password (
216+ self ._service_name , self ._account_name )
217+ except self ._KeychainError as ex :
218+ if ex .exit_status == self ._KeychainError .ITEM_NOT_FOUND :
219+ # This happens when a load() is called before a save().
220+ # We map it into cross-platform error for unified catching.
221+ raise PersistenceNotFound (
222+ location = "Service:{} Account:{}" .format (
223+ self ._service_name , self ._account_name ),
224+ message = (
225+ "Keychain persistence not initialized. "
226+ "You can recover by call a save() first." ),
227+ )
228+ raise # We do not intend to hide any other underlying exceptions
155229
156230 def time_last_modified (self ):
157231 return self ._file_persistence .time_last_modified ()
@@ -188,7 +262,14 @@ def save(self, content):
188262 self ._file_persistence .touch () # For time_last_modified()
189263
190264 def load (self ):
191- return self ._agent .load ()
265+ data = self ._agent .load ()
266+ if data is None :
267+ # Lower level libsecret would return None when found nothing. Here
268+ # in persistence layer, we convert it to a unified error for consistence.
269+ raise PersistenceNotFound (message = (
270+ "Keyring persistence not initialized. "
271+ "You can recover by call a save() first." ))
272+ return data
192273
193274 def time_last_modified (self ):
194275 return self ._file_persistence .time_last_modified ()
0 commit comments