|
1 | 1 | """Generic functions and types for working with a TokenCache that is not platform specific.""" |
2 | 2 | import os |
3 | | -import sys |
4 | 3 | import warnings |
5 | 4 | import time |
6 | 5 | import errno |
7 | | -import msal |
8 | | -from .cache_lock import CrossPlatLock |
| 6 | +import logging |
9 | 7 |
|
10 | | -if sys.platform.startswith('win'): |
11 | | - from .windows import WindowsDataProtectionAgent |
12 | | -elif sys.platform.startswith('darwin'): |
13 | | - from .osx import Keychain |
| 8 | +import msal |
14 | 9 |
|
15 | | -def _mkdir_p(path): |
16 | | - """Creates a directory, and any necessary parents. |
| 10 | +from .cache_lock import CrossPlatLock |
| 11 | +from .persistence import ( |
| 12 | + _mkdir_p, FilePersistence, |
| 13 | + FilePersistenceWithDataProtection, KeychainPersistence) |
17 | 14 |
|
18 | | - This implementation based on a Stack Overflow question that can be found here: |
19 | | - https://stackoverflow.com/questions/600268/mkdir-p-functionality-in-python |
20 | 15 |
|
21 | | - If the path provided is an existing file, this function raises an exception. |
22 | | - :param path: The directory name that should be created. |
23 | | - """ |
24 | | - try: |
25 | | - os.makedirs(path) |
26 | | - except OSError as exp: |
27 | | - if exp.errno == errno.EEXIST and os.path.isdir(path): |
28 | | - pass |
29 | | - else: |
30 | | - raise |
| 16 | +logger = logging.getLogger(__name__) |
31 | 17 |
|
| 18 | +class PersistedTokenCache(msal.SerializableTokenCache): |
| 19 | + """A token cache using given persistence layer, coordinated by a file lock.""" |
32 | 20 |
|
33 | | -class FileTokenCache(msal.SerializableTokenCache): |
34 | | - """Implements basic unprotected SerializableTokenCache to a plain-text file.""" |
35 | | - def __init__(self, |
36 | | - cache_location, |
37 | | - lock_location=None): |
38 | | - super(FileTokenCache, self).__init__() |
39 | | - self._cache_location = cache_location |
40 | | - self._lock_location = lock_location or self._cache_location + '.lockfile' |
| 21 | + def __init__(self, persistence, lock_location=None): |
| 22 | + super(PersistedTokenCache, self).__init__() |
| 23 | + self._lock_location = ( |
| 24 | + os.path.expanduser(lock_location) if lock_location |
| 25 | + else persistence.get_location() + ".lockfile") |
| 26 | + _mkdir_p(os.path.dirname(self._lock_location)) |
| 27 | + self._persistence = persistence |
41 | 28 | self._last_sync = 0 # _last_sync is a Unixtime |
| 29 | + self.is_encrypted = persistence.is_encrypted |
42 | 30 |
|
43 | | - self._cache_location = os.path.expanduser(self._cache_location) |
44 | | - self._lock_location = os.path.expanduser(self._lock_location) |
45 | | - |
46 | | - _mkdir_p(os.path.dirname(self._lock_location)) |
47 | | - _mkdir_p(os.path.dirname(self._cache_location)) |
48 | | - |
49 | | - def _needs_refresh(self): |
50 | | - # type: () -> Bool |
51 | | - """ |
52 | | - Inspects the file holding the encrypted TokenCache to see if a read is necessary. |
53 | | - :return: True if there are changes not reflected in memory, False otherwise. |
54 | | - """ |
| 31 | + def _reload_if_necessary(self): |
| 32 | + # type: () -> None |
| 33 | + """Reload cache from persistence layer, if necessary""" |
55 | 34 | try: |
56 | | - updated = os.path.getmtime(self._cache_location) |
57 | | - return self._last_sync < updated |
| 35 | + if self._last_sync < self._persistence.time_last_modified(): |
| 36 | + self.deserialize(self._persistence.load()) |
| 37 | + self._last_sync = time.time() |
58 | 38 | except IOError as exp: |
59 | 39 | if exp.errno != errno.ENOENT: |
60 | | - raise exp |
61 | | - return False |
62 | | - |
63 | | - def _write(self, contents): |
64 | | - # type: (str) -> None |
65 | | - """Handles actually committing the serialized form of this TokenCache to persisted storage. |
66 | | - For types derived of this, class that will be a file, which has the ability to track a last |
67 | | - modified time. |
68 | | -
|
69 | | - :param contents: The serialized contents of a TokenCache |
70 | | - """ |
71 | | - with open(self._cache_location, 'w+') as handle: |
72 | | - handle.write(contents) |
73 | | - |
74 | | - def _read(self): |
75 | | - # type: () -> str |
76 | | - """Fetches the contents of a file and invokes deserialization.""" |
77 | | - with open(self._cache_location, 'r') as handle: |
78 | | - return handle.read() |
| 40 | + raise |
| 41 | + # Otherwise, from cache's perspective, a nonexistent file is a NO-OP |
79 | 42 |
|
80 | 43 | def modify(self, credential_type, old_entry, new_key_value_pairs=None): |
81 | 44 | with CrossPlatLock(self._lock_location): |
82 | | - if self._needs_refresh(): |
83 | | - try: |
84 | | - self.deserialize(self._read()) |
85 | | - except IOError as exp: |
86 | | - if exp.errno != errno.ENOENT: |
87 | | - raise |
88 | | - super(FileTokenCache, self).modify( |
| 45 | + self._reload_if_necessary() |
| 46 | + super(PersistedTokenCache, self).modify( |
89 | 47 | credential_type, |
90 | 48 | old_entry, |
91 | 49 | new_key_value_pairs=new_key_value_pairs) |
92 | | - self._write(self.serialize()) |
93 | | - self._last_sync = os.path.getmtime(self._cache_location) |
| 50 | + self._persistence.save(self.serialize()) |
| 51 | + self._last_sync = time.time() |
94 | 52 |
|
95 | 53 | def find(self, credential_type, **kwargs): # pylint: disable=arguments-differ |
96 | 54 | with CrossPlatLock(self._lock_location): |
97 | | - if self._needs_refresh(): |
98 | | - try: |
99 | | - self.deserialize(self._read()) |
100 | | - except IOError as exp: |
101 | | - if exp.errno != errno.ENOENT: |
102 | | - raise |
103 | | - self._last_sync = time.time() |
104 | | - return super(FileTokenCache, self).find(credential_type, **kwargs) |
| 55 | + self._reload_if_necessary() |
| 56 | + return super(PersistedTokenCache, self).find(credential_type, **kwargs) |
105 | 57 |
|
106 | 58 |
|
107 | | -class UnencryptedTokenCache(FileTokenCache): |
108 | | - """An unprotected token cache to default to when no-platform specific option is available.""" |
109 | | - def __init__(self, cache_location, **kwargs): |
110 | | - warnings.warn("You are using an unprotected token cache, " |
111 | | - "because an encrypted option is not available for {}".format(sys.platform), |
112 | | - RuntimeWarning) |
113 | | - super(UnencryptedTokenCache, self).__init__(cache_location, **kwargs) |
| 59 | +class FileTokenCache(PersistedTokenCache): |
| 60 | + """A token cache which uses plain text file to store your tokens.""" |
| 61 | + def __init__(self, cache_location, **ignored): # pylint: disable=unused-argument |
| 62 | + warnings.warn("You are using an unprotected token cache", RuntimeWarning) |
| 63 | + warnings.warn("Use PersistedTokenCache(...) instead", DeprecationWarning) |
| 64 | + super(FileTokenCache, self).__init__(FilePersistence(cache_location)) |
114 | 65 |
|
| 66 | +UnencryptedTokenCache = FileTokenCache # For backward compatibility |
115 | 67 |
|
116 | | -class WindowsTokenCache(FileTokenCache): |
117 | | - """A SerializableTokenCache implementation which uses Win32 encryption APIs to protect your |
118 | | - tokens. |
119 | | - """ |
120 | | - def __init__(self, cache_location, entropy='', **kwargs): |
121 | | - super(WindowsTokenCache, self).__init__(cache_location, **kwargs) |
122 | | - self._dp_agent = WindowsDataProtectionAgent(entropy=entropy) |
123 | 68 |
|
124 | | - def _write(self, contents): |
125 | | - with open(self._cache_location, 'wb') as handle: |
126 | | - handle.write(self._dp_agent.protect(contents)) |
| 69 | +class WindowsTokenCache(PersistedTokenCache): |
| 70 | + """A token cache which uses Windows DPAPI to encrypt your tokens.""" |
| 71 | + def __init__( |
| 72 | + self, cache_location, entropy='', |
| 73 | + **ignored): # pylint: disable=unused-argument |
| 74 | + warnings.warn("Use PersistedTokenCache(...) instead", DeprecationWarning) |
| 75 | + super(WindowsTokenCache, self).__init__( |
| 76 | + FilePersistenceWithDataProtection(cache_location, entropy=entropy)) |
127 | 77 |
|
128 | | - def _read(self): |
129 | | - with open(self._cache_location, 'rb') as handle: |
130 | | - cipher_text = handle.read() |
131 | | - return self._dp_agent.unprotect(cipher_text) |
132 | | - |
133 | | - |
134 | | -class OSXTokenCache(FileTokenCache): |
135 | | - """A SerializableTokenCache implementation which uses native Keychain libraries to protect your |
136 | | - tokens. |
137 | | - """ |
138 | 78 |
|
| 79 | +class OSXTokenCache(PersistedTokenCache): |
| 80 | + """A token cache which uses native Keychain libraries to encrypt your tokens.""" |
139 | 81 | def __init__(self, |
140 | 82 | cache_location, |
141 | 83 | service_name='Microsoft.Developer.IdentityService', |
142 | 84 | account_name='MSALCache', |
143 | | - **kwargs): |
144 | | - super(OSXTokenCache, self).__init__(cache_location, **kwargs) |
145 | | - self._service_name = service_name |
146 | | - self._account_name = account_name |
147 | | - |
148 | | - def _read(self): |
149 | | - with Keychain() as locker: |
150 | | - return locker.get_generic_password(self._service_name, self._account_name) |
151 | | - |
152 | | - def _write(self, contents): |
153 | | - with Keychain() as locker: |
154 | | - locker.set_generic_password(self._service_name, self._account_name, contents) |
155 | | - with open(self._cache_location, "w+") as handle: |
156 | | - handle.write('{} {}'.format(os.getpid(), sys.argv[0])) |
| 85 | + **ignored): # pylint: disable=unused-argument |
| 86 | + warnings.warn("Use PersistedTokenCache(...) instead", DeprecationWarning) |
| 87 | + super(OSXTokenCache, self).__init__( |
| 88 | + KeychainPersistence(cache_location, service_name, account_name)) |
| 89 | + |
0 commit comments