Skip to content

Commit 8605a95

Browse files
authored
Unified exception when load() is called before save() (#67)
* Unified exception when load() is called before save() * Go with an explicit PersistenceNotFound exception * Caller has a much simpler pattern to handle read errors * Change "where" to "location" * Test cases for PersistentNotFound
1 parent c537cf2 commit 8605a95

File tree

3 files changed

+128
-18
lines changed

3 files changed

+128
-18
lines changed

msal_extensions/persistence.py

Lines changed: 89 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import abc
1010
import os
1111
import errno
12+
import logging
1213
try:
1314
from pathlib import Path # Built-in in Python 3
1415
except:
@@ -21,6 +22,9 @@
2122
ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore
2223

2324

25+
logger = logging.getLogger(__name__)
26+
27+
2428
def _mkdir_p(path):
2529
"""Creates a directory, and any necessary parents.
2630
@@ -41,6 +45,16 @@ def _mkdir_p(path):
4145
raise
4246

4347

48+
# We do not aim to wrap every os-specific exception.
49+
# Here we define only the most common one,
50+
# otherwise caller would need to catch os-specific persistence exceptions.
51+
class PersistenceNotFound(OSError):
52+
def __init__(
53+
self,
54+
err_no=errno.ENOENT, message="Persistence not found", location=None):
55+
super(PersistenceNotFound, self).__init__(err_no, message, location)
56+
57+
4458
class BasePersistence(ABC):
4559
"""An abstract persistence defining the common interface of this family"""
4660

@@ -55,12 +69,18 @@ def save(self, content):
5569
@abc.abstractmethod
5670
def load(self):
5771
# type: () -> str
58-
"""Load content from this persistence"""
72+
"""Load content from this persistence.
73+
74+
Could raise PersistenceNotFound if no save() was called before.
75+
"""
5976
raise NotImplementedError
6077

6178
@abc.abstractmethod
6279
def time_last_modified(self):
63-
"""Get the last time when this persistence has been modified"""
80+
"""Get the last time when this persistence has been modified.
81+
82+
Could raise PersistenceNotFound if no save() was called before.
83+
"""
6484
raise NotImplementedError
6585

6686
@abc.abstractmethod
@@ -87,11 +107,32 @@ def save(self, content):
87107
def load(self):
88108
# type: () -> str
89109
"""Load content from this persistence"""
90-
with open(self._location, 'r') as handle:
91-
return handle.read()
110+
try:
111+
with open(self._location, 'r') as handle:
112+
return handle.read()
113+
except EnvironmentError as exp: # EnvironmentError in Py 2.7 works across platform
114+
if exp.errno == errno.ENOENT:
115+
raise PersistenceNotFound(
116+
message=(
117+
"Persistence not initialized. "
118+
"You can recover by calling a save() first."),
119+
location=self._location,
120+
)
121+
raise
122+
92123

93124
def time_last_modified(self):
94-
return os.path.getmtime(self._location)
125+
try:
126+
return os.path.getmtime(self._location)
127+
except EnvironmentError as exp: # EnvironmentError in Py 2.7 works across platform
128+
if exp.errno == errno.ENOENT:
129+
raise PersistenceNotFound(
130+
message=(
131+
"Persistence not initialized. "
132+
"You can recover by calling a save() first."),
133+
location=self._location,
134+
)
135+
raise
95136

96137
def touch(self):
97138
"""To touch this file-based persistence without writing content into it"""
@@ -115,13 +156,28 @@ def __init__(self, location, entropy=''):
115156

116157
def save(self, content):
117158
# type: (str) -> None
159+
data = self._dp_agent.protect(content)
118160
with open(self._location, 'wb+') as handle:
119-
handle.write(self._dp_agent.protect(content))
161+
handle.write(data)
120162

121163
def load(self):
122164
# type: () -> str
123-
with open(self._location, 'rb') as handle:
124-
return self._dp_agent.unprotect(handle.read())
165+
try:
166+
with open(self._location, 'rb') as handle:
167+
data = handle.read()
168+
return self._dp_agent.unprotect(data)
169+
except EnvironmentError as exp: # EnvironmentError in Py 2.7 works across platform
170+
if exp.errno == errno.ENOENT:
171+
raise PersistenceNotFound(
172+
message=(
173+
"Persistence not initialized. "
174+
"You can recover by calling a save() first."),
175+
location=self._location,
176+
)
177+
logger.exception(
178+
"DPAPI error likely caused by file content not previously encrypted. "
179+
"App developer should migrate by calling save(plaintext) first.")
180+
raise
125181

126182

127183
class KeychainPersistence(BasePersistence):
@@ -136,9 +192,10 @@ def __init__(self, signal_location, service_name, account_name):
136192
"""
137193
if not (service_name and account_name): # It would hang on OSX
138194
raise ValueError("service_name and account_name are required")
139-
from .osx import Keychain # pylint: disable=import-outside-toplevel
195+
from .osx import Keychain, KeychainError # pylint: disable=import-outside-toplevel
140196
self._file_persistence = FilePersistence(signal_location) # Favor composition
141197
self._Keychain = Keychain # pylint: disable=invalid-name
198+
self._KeychainError = KeychainError # pylint: disable=invalid-name
142199
self._service_name = service_name
143200
self._account_name = account_name
144201

@@ -150,8 +207,21 @@ def save(self, content):
150207

151208
def load(self):
152209
with self._Keychain() as locker:
153-
return locker.get_generic_password(
154-
self._service_name, self._account_name)
210+
try:
211+
return locker.get_generic_password(
212+
self._service_name, self._account_name)
213+
except self._KeychainError as ex:
214+
if ex.exit_status == self._KeychainError.ITEM_NOT_FOUND:
215+
# This happens when a load() is called before a save().
216+
# We map it into cross-platform error for unified catching.
217+
raise PersistenceNotFound(
218+
location="Service:{} Account:{}".format(
219+
self._service_name, self._account_name),
220+
message=(
221+
"Keychain persistence not initialized. "
222+
"You can recover by call a save() first."),
223+
)
224+
raise # We do not intend to hide any other underlying exceptions
155225

156226
def time_last_modified(self):
157227
return self._file_persistence.time_last_modified()
@@ -188,7 +258,14 @@ def save(self, content):
188258
self._file_persistence.touch() # For time_last_modified()
189259

190260
def load(self):
191-
return self._agent.load()
261+
data = self._agent.load()
262+
if data is None:
263+
# Lower level libsecret would return None when found nothing. Here
264+
# in persistence layer, we convert it to a unified error for consistence.
265+
raise PersistenceNotFound(message=(
266+
"Keyring persistence not initialized. "
267+
"You can recover by call a save() first."))
268+
return data
192269

193270
def time_last_modified(self):
194271
return self._file_persistence.time_last_modified()

msal_extensions/token_cache.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22
import os
33
import warnings
44
import time
5-
import errno
65
import logging
76

87
import msal
98

109
from .cache_lock import CrossPlatLock
1110
from .persistence import (
12-
_mkdir_p, FilePersistence,
11+
_mkdir_p, PersistenceNotFound, FilePersistence,
1312
FilePersistenceWithDataProtection, KeychainPersistence)
1413

1514

@@ -35,10 +34,10 @@ def _reload_if_necessary(self):
3534
if self._last_sync < self._persistence.time_last_modified():
3635
self.deserialize(self._persistence.load())
3736
self._last_sync = time.time()
38-
except EnvironmentError as exp:
39-
if exp.errno != errno.ENOENT:
40-
raise
41-
# Otherwise, from cache's perspective, a nonexistent file is a NO-OP
37+
except PersistenceNotFound:
38+
# From cache's perspective, a nonexistent persistence is a NO-OP.
39+
pass
40+
# However, existing data unable to be decrypted will still be bubbled up.
4241

4342
def modify(self, credential_type, old_entry, new_key_value_pairs=None):
4443
with CrossPlatLock(self._lock_location):

tests/test_persistence.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,45 @@ def _test_persistence_roundtrip(persistence):
2626
persistence.save(payload)
2727
assert persistence.load() == payload
2828

29+
def _test_nonexistent_persistence(persistence):
30+
with pytest.raises(PersistenceNotFound):
31+
persistence.load()
32+
with pytest.raises(PersistenceNotFound):
33+
persistence.time_last_modified()
34+
2935
def test_file_persistence(temp_location):
3036
_test_persistence_roundtrip(FilePersistence(temp_location))
3137

38+
def test_nonexistent_file_persistence(temp_location):
39+
_test_nonexistent_persistence(FilePersistence(temp_location))
40+
3241
@pytest.mark.skipif(
3342
is_running_on_travis_ci or not sys.platform.startswith('win'),
3443
reason="Requires Windows Desktop")
3544
def test_file_persistence_with_data_protection(temp_location):
3645
_test_persistence_roundtrip(FilePersistenceWithDataProtection(temp_location))
3746

47+
@pytest.mark.skipif(
48+
is_running_on_travis_ci or not sys.platform.startswith('win'),
49+
reason="Requires Windows Desktop")
50+
def test_nonexistent_file_persistence_with_data_protection(temp_location):
51+
_test_nonexistent_persistence(FilePersistenceWithDataProtection(temp_location))
52+
3853
@pytest.mark.skipif(
3954
not sys.platform.startswith('darwin'),
4055
reason="Requires OSX. Whether running on TRAVIS CI does not seem to matter.")
4156
def test_keychain_persistence(temp_location):
4257
_test_persistence_roundtrip(KeychainPersistence(
4358
temp_location, "my_service_name", "my_account_name"))
4459

60+
@pytest.mark.skipif(
61+
not sys.platform.startswith('darwin'),
62+
reason="Requires OSX. Whether running on TRAVIS CI does not seem to matter.")
63+
def test_nonexistent_keychain_persistence(temp_location):
64+
random_service_name = random_account_name = str(id(temp_location))
65+
_test_nonexistent_persistence(
66+
KeychainPersistence(temp_location, random_service_name, random_account_name))
67+
4568
@pytest.mark.skipif(
4669
is_running_on_travis_ci or not sys.platform.startswith('linux'),
4770
reason="Requires Linux Desktop. Headless or SSH session won't work.")
@@ -52,3 +75,14 @@ def test_libsecret_persistence(temp_location):
5275
{"my_attr_1": "foo", "my_attr_2": "bar"},
5376
))
5477

78+
@pytest.mark.skipif(
79+
is_running_on_travis_ci or not sys.platform.startswith('linux'),
80+
reason="Requires Linux Desktop. Headless or SSH session won't work.")
81+
def test_nonexistent_libsecret_persistence(temp_location):
82+
random_schema_name = random_value = str(id(temp_location))
83+
_test_nonexistent_persistence(LibsecretPersistence(
84+
temp_location,
85+
random_schema_name,
86+
{"my_attr_1": random_value, "my_attr_2": random_value},
87+
))
88+

0 commit comments

Comments
 (0)