Skip to content

Commit c4091f7

Browse files
committed
Rewrite entire token_cache.py
1 parent c8f4774 commit c4091f7

File tree

5 files changed

+131
-164
lines changed

5 files changed

+131
-164
lines changed

.pylintrc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
[MESSAGES CONTROL]
2+
good-names=
3+
logger
24
disable=
35
trailing-newlines,
46
useless-object-inheritance

msal_extensions/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010
LibsecretPersistence,
1111
)
1212
from .cache_lock import CrossPlatLock
13+
from .token_cache import PersistedTokenCache
1314

1415
if sys.platform.startswith('win'):
1516
from .token_cache import WindowsTokenCache as TokenCache
1617
elif sys.platform.startswith('darwin'):
1718
from .token_cache import OSXTokenCache as TokenCache
1819
else:
19-
from .token_cache import UnencryptedTokenCache as TokenCache
20+
from .token_cache import FileTokenCache as TokenCache

msal_extensions/token_cache.py

Lines changed: 53 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -1,156 +1,89 @@
11
"""Generic functions and types for working with a TokenCache that is not platform specific."""
22
import os
3-
import sys
43
import warnings
54
import time
65
import errno
7-
import msal
8-
from .cache_lock import CrossPlatLock
6+
import logging
97

10-
if sys.platform.startswith('win'):
11-
from .windows import WindowsDataProtectionAgent
12-
elif sys.platform.startswith('darwin'):
13-
from .osx import Keychain
8+
import msal
149

15-
def _mkdir_p(path):
16-
"""Creates a directory, and any necessary parents.
10+
from .cache_lock import CrossPlatLock
11+
from .persistence import (
12+
_mkdir_p, FilePersistence,
13+
FilePersistenceWithDataProtection, KeychainPersistence)
1714

18-
This implementation based on a Stack Overflow question that can be found here:
19-
https://stackoverflow.com/questions/600268/mkdir-p-functionality-in-python
2015

21-
If the path provided is an existing file, this function raises an exception.
22-
:param path: The directory name that should be created.
23-
"""
24-
try:
25-
os.makedirs(path)
26-
except OSError as exp:
27-
if exp.errno == errno.EEXIST and os.path.isdir(path):
28-
pass
29-
else:
30-
raise
16+
logger = logging.getLogger(__name__)
3117

18+
class PersistedTokenCache(msal.SerializableTokenCache):
19+
"""A token cache using given persistence layer, coordinated by a file lock."""
3220

33-
class FileTokenCache(msal.SerializableTokenCache):
34-
"""Implements basic unprotected SerializableTokenCache to a plain-text file."""
35-
def __init__(self,
36-
cache_location,
37-
lock_location=None):
38-
super(FileTokenCache, self).__init__()
39-
self._cache_location = cache_location
40-
self._lock_location = lock_location or self._cache_location + '.lockfile'
21+
def __init__(self, persistence, lock_location=None):
22+
super(PersistedTokenCache, self).__init__()
23+
self._lock_location = (
24+
os.path.expanduser(lock_location) if lock_location
25+
else persistence.get_location() + ".lockfile")
26+
_mkdir_p(os.path.dirname(self._lock_location))
27+
self._persistence = persistence
4128
self._last_sync = 0 # _last_sync is a Unixtime
29+
self.is_encrypted = persistence.is_encrypted
4230

43-
self._cache_location = os.path.expanduser(self._cache_location)
44-
self._lock_location = os.path.expanduser(self._lock_location)
45-
46-
_mkdir_p(os.path.dirname(self._lock_location))
47-
_mkdir_p(os.path.dirname(self._cache_location))
48-
49-
def _needs_refresh(self):
50-
# type: () -> Bool
51-
"""
52-
Inspects the file holding the encrypted TokenCache to see if a read is necessary.
53-
:return: True if there are changes not reflected in memory, False otherwise.
54-
"""
31+
def _reload_if_necessary(self):
32+
# type: () -> None
33+
"""Reload cache from persistence layer, if necessary"""
5534
try:
56-
updated = os.path.getmtime(self._cache_location)
57-
return self._last_sync < updated
35+
if self._last_sync < self._persistence.time_last_modified():
36+
self.deserialize(self._persistence.load())
37+
self._last_sync = time.time()
5838
except IOError as exp:
5939
if exp.errno != errno.ENOENT:
60-
raise exp
61-
return False
62-
63-
def _write(self, contents):
64-
# type: (str) -> None
65-
"""Handles actually committing the serialized form of this TokenCache to persisted storage.
66-
For types derived of this, class that will be a file, which has the ability to track a last
67-
modified time.
68-
69-
:param contents: The serialized contents of a TokenCache
70-
"""
71-
with open(self._cache_location, 'w+') as handle:
72-
handle.write(contents)
73-
74-
def _read(self):
75-
# type: () -> str
76-
"""Fetches the contents of a file and invokes deserialization."""
77-
with open(self._cache_location, 'r') as handle:
78-
return handle.read()
40+
raise
41+
# Otherwise, from cache's perspective, a nonexistent file is a NO-OP
7942

8043
def modify(self, credential_type, old_entry, new_key_value_pairs=None):
8144
with CrossPlatLock(self._lock_location):
82-
if self._needs_refresh():
83-
try:
84-
self.deserialize(self._read())
85-
except IOError as exp:
86-
if exp.errno != errno.ENOENT:
87-
raise
88-
super(FileTokenCache, self).modify(
45+
self._reload_if_necessary()
46+
super(PersistedTokenCache, self).modify(
8947
credential_type,
9048
old_entry,
9149
new_key_value_pairs=new_key_value_pairs)
92-
self._write(self.serialize())
93-
self._last_sync = os.path.getmtime(self._cache_location)
50+
self._persistence.save(self.serialize())
51+
self._last_sync = time.time()
9452

9553
def find(self, credential_type, **kwargs): # pylint: disable=arguments-differ
9654
with CrossPlatLock(self._lock_location):
97-
if self._needs_refresh():
98-
try:
99-
self.deserialize(self._read())
100-
except IOError as exp:
101-
if exp.errno != errno.ENOENT:
102-
raise
103-
self._last_sync = time.time()
104-
return super(FileTokenCache, self).find(credential_type, **kwargs)
55+
self._reload_if_necessary()
56+
return super(PersistedTokenCache, self).find(credential_type, **kwargs)
10557

10658

107-
class UnencryptedTokenCache(FileTokenCache):
108-
"""An unprotected token cache to default to when no-platform specific option is available."""
109-
def __init__(self, cache_location, **kwargs):
110-
warnings.warn("You are using an unprotected token cache, "
111-
"because an encrypted option is not available for {}".format(sys.platform),
112-
RuntimeWarning)
113-
super(UnencryptedTokenCache, self).__init__(cache_location, **kwargs)
59+
class FileTokenCache(PersistedTokenCache):
60+
"""A token cache which uses plain text file to store your tokens."""
61+
def __init__(self, cache_location, **ignored): # pylint: disable=unused-argument
62+
warnings.warn("You are using an unprotected token cache", RuntimeWarning)
63+
warnings.warn("Use PersistedTokenCache(...) instead", DeprecationWarning)
64+
super(FileTokenCache, self).__init__(FilePersistence(cache_location))
11465

66+
UnencryptedTokenCache = FileTokenCache # For backward compatibility
11567

116-
class WindowsTokenCache(FileTokenCache):
117-
"""A SerializableTokenCache implementation which uses Win32 encryption APIs to protect your
118-
tokens.
119-
"""
120-
def __init__(self, cache_location, entropy='', **kwargs):
121-
super(WindowsTokenCache, self).__init__(cache_location, **kwargs)
122-
self._dp_agent = WindowsDataProtectionAgent(entropy=entropy)
12368

124-
def _write(self, contents):
125-
with open(self._cache_location, 'wb') as handle:
126-
handle.write(self._dp_agent.protect(contents))
69+
class WindowsTokenCache(PersistedTokenCache):
70+
"""A token cache which uses Windows DPAPI to encrypt your tokens."""
71+
def __init__(
72+
self, cache_location, entropy='',
73+
**ignored): # pylint: disable=unused-argument
74+
warnings.warn("Use PersistedTokenCache(...) instead", DeprecationWarning)
75+
super(WindowsTokenCache, self).__init__(
76+
FilePersistenceWithDataProtection(cache_location, entropy=entropy))
12777

128-
def _read(self):
129-
with open(self._cache_location, 'rb') as handle:
130-
cipher_text = handle.read()
131-
return self._dp_agent.unprotect(cipher_text)
132-
133-
134-
class OSXTokenCache(FileTokenCache):
135-
"""A SerializableTokenCache implementation which uses native Keychain libraries to protect your
136-
tokens.
137-
"""
13878

79+
class OSXTokenCache(PersistedTokenCache):
80+
"""A token cache which uses native Keychain libraries to encrypt your tokens."""
13981
def __init__(self,
14082
cache_location,
14183
service_name='Microsoft.Developer.IdentityService',
14284
account_name='MSALCache',
143-
**kwargs):
144-
super(OSXTokenCache, self).__init__(cache_location, **kwargs)
145-
self._service_name = service_name
146-
self._account_name = account_name
147-
148-
def _read(self):
149-
with Keychain() as locker:
150-
return locker.get_generic_password(self._service_name, self._account_name)
151-
152-
def _write(self, contents):
153-
with Keychain() as locker:
154-
locker.set_generic_password(self._service_name, self._account_name, contents)
155-
with open(self._cache_location, "w+") as handle:
156-
handle.write('{} {}'.format(os.getpid(), sys.argv[0]))
85+
**ignored): # pylint: disable=unused-argument
86+
warnings.warn("Use PersistedTokenCache(...) instead", DeprecationWarning)
87+
super(OSXTokenCache, self).__init__(
88+
KeychainPersistence(cache_location, service_name, account_name))
89+

sample/token_cache_sample.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import sys
2+
import logging
3+
import json
4+
5+
from msal_extensions import *
6+
7+
8+
def build_persistence(location, fallback_to_plaintext=False):
9+
"""Build a suitable persistence instance based your current OS"""
10+
if sys.platform.startswith('win'):
11+
return FilePersistenceWithDataProtection(location)
12+
if sys.platform.startswith('darwin'):
13+
return KeychainPersistence(location, "my_service_name", "my_account_name")
14+
if sys.platform.startswith('linux'):
15+
try:
16+
return LibsecretPersistence(
17+
# By using same location as the fall back option below,
18+
# this would override the unencrypted data stored by the
19+
# fall back option. It is probably OK, or even desirable
20+
# (in order to aggressively wipe out plain-text persisted data),
21+
# unless there would frequently be a desktop session and
22+
# a remote ssh session being active simultaneously.
23+
location,
24+
schema_name="my_schema_name",
25+
attributes={"my_attr1": "foo", "my_attr2": "bar"},
26+
)
27+
except: # pylint: disable=bare-except
28+
if not fallback_to_plaintext:
29+
raise
30+
logging.exception("Encryption unavailable. Opting in to plain text.")
31+
return FilePersistence(location)
32+
33+
persistence = build_persistence("token_cache.bin")
34+
print("Is this persistence encrypted?", persistence.is_encrypted)
35+
36+
cache = PersistedTokenCache(persistence)
37+
# Now you can use it in an msal application like this:
38+
# app = msal.PublicClientApplication("my_client_id", token_cache=cache)
39+

tests/test_agnostic_backend.py

Lines changed: 35 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,46 @@
11
import os
22
import shutil
33
import tempfile
4-
import pytest
4+
import sys
5+
56
import msal
7+
import pytest
68

9+
from msal_extensions import *
710

8-
def test_file_token_cache_roundtrip():
9-
from msal_extensions.token_cache import FileTokenCache
1011

12+
@pytest.fixture
13+
def temp_location():
14+
test_folder = tempfile.mkdtemp(prefix="test_token_cache_roundtrip")
15+
yield os.path.join(test_folder, 'token_cache.bin')
16+
shutil.rmtree(test_folder, ignore_errors=True)
17+
18+
19+
def _test_token_cache_roundtrip(cache):
1120
client_id = os.getenv('AZURE_CLIENT_ID')
1221
client_secret = os.getenv('AZURE_CLIENT_SECRET')
1322
if not (client_id and client_secret):
14-
pytest.skip('no credentials present to test FileTokenCache round-trip with.')
15-
16-
test_folder = tempfile.mkdtemp(prefix="msal_extension_test_file_token_cache_roundtrip")
17-
cache_file = os.path.join(test_folder, 'msal.cache')
18-
try:
19-
subject = FileTokenCache(cache_location=cache_file)
20-
app = msal.ConfidentialClientApplication(
21-
client_id=client_id,
22-
client_credential=client_secret,
23-
token_cache=subject)
24-
desired_scopes = ['https://graph.microsoft.com/.default']
25-
token1 = app.acquire_token_for_client(scopes=desired_scopes)
26-
os.utime(cache_file, None) # Mock having another process update the cache.
27-
token2 = app.acquire_token_silent(scopes=desired_scopes, account=None)
28-
assert token1['access_token'] == token2['access_token']
29-
finally:
30-
shutil.rmtree(test_folder, ignore_errors=True)
31-
32-
33-
def test_current_platform_cache_roundtrip():
23+
pytest.skip('no credentials present to test TokenCache round-trip with.')
24+
25+
app = msal.ConfidentialClientApplication(
26+
client_id=client_id,
27+
client_credential=client_secret,
28+
token_cache=cache)
29+
desired_scopes = ['https://graph.microsoft.com/.default']
30+
token1 = app.acquire_token_for_client(scopes=desired_scopes)
31+
os.utime( # Mock having another process update the cache
32+
cache._persistence.get_location(), None)
33+
token2 = app.acquire_token_silent(scopes=desired_scopes, account=None)
34+
assert token1['access_token'] == token2['access_token']
35+
36+
def test_file_token_cache_roundtrip(temp_location):
37+
from msal_extensions.token_cache import FileTokenCache
38+
_test_token_cache_roundtrip(FileTokenCache(temp_location))
39+
40+
def test_current_platform_cache_roundtrip_with_alias_class(temp_location):
3441
from msal_extensions import TokenCache
35-
client_id = os.getenv('AZURE_CLIENT_ID')
36-
client_secret = os.getenv('AZURE_CLIENT_SECRET')
37-
if not (client_id and client_secret):
38-
pytest.skip('no credentials present to test FileTokenCache round-trip with.')
39-
40-
test_folder = tempfile.mkdtemp(prefix="msal_extension_test_file_token_cache_roundtrip")
41-
cache_file = os.path.join(test_folder, 'msal.cache')
42-
try:
43-
subject = TokenCache(cache_location=cache_file)
44-
app = msal.ConfidentialClientApplication(
45-
client_id=client_id,
46-
client_credential=client_secret,
47-
token_cache=subject)
48-
desired_scopes = ['https://graph.microsoft.com/.default']
49-
token1 = app.acquire_token_for_client(scopes=desired_scopes)
50-
os.utime(cache_file, None) # Mock having another process update the cache.
51-
token2 = app.acquire_token_silent(scopes=desired_scopes, account=None)
52-
assert token1['access_token'] == token2['access_token']
53-
finally:
54-
shutil.rmtree(test_folder, ignore_errors=True)
42+
_test_token_cache_roundtrip(TokenCache(temp_location))
43+
44+
def test_persisted_token_cache(temp_location):
45+
_test_token_cache_roundtrip(PersistedTokenCache(FilePersistence(temp_location)))
46+

0 commit comments

Comments
 (0)