|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | import logging
|
| 4 | +import re |
| 5 | +from abc import ABC, abstractmethod |
| 6 | +from base64 import b64encode |
4 | 7 | from ssl import create_default_context
|
5 | 8 |
|
6 | 9 | import aiohttp
|
| 10 | +import requests.auth |
| 11 | +from requests.utils import parse_dict_header |
7 | 12 |
|
8 | 13 | from . import DOCS_HOME
|
9 | 14 | from . import __version__
|
@@ -36,25 +41,77 @@ def _detect_faulty_requests(): # pragma: no cover
|
36 | 41 | del _detect_faulty_requests
|
37 | 42 |
|
38 | 43 |
|
| 44 | +class AuthMethod(ABC): |
| 45 | + def __init__(self, username, password): |
| 46 | + self.username = username |
| 47 | + self.password = password |
| 48 | + |
| 49 | + @abstractmethod |
| 50 | + def handle_401(self, response): |
| 51 | + raise NotImplementedError |
| 52 | + |
| 53 | + @abstractmethod |
| 54 | + def get_auth_header(self, method, url): |
| 55 | + raise NotImplementedError |
| 56 | + |
| 57 | + |
| 58 | +class BasicAuthMethod(AuthMethod): |
| 59 | + def handle_401(self, _response): |
| 60 | + pass |
| 61 | + |
| 62 | + def get_auth_header(self, _method, _url): |
| 63 | + auth_str = f"{self.username}:{self.password}" |
| 64 | + return "Basic " + b64encode(auth_str.encode('utf-8')).decode("utf-8") |
| 65 | + |
| 66 | + |
| 67 | +class DigestAuthMethod(AuthMethod): |
| 68 | + # make class var to 'cache' the state, which is more efficient because otherwise |
| 69 | + # each request would first require another 'initialization' request. |
| 70 | + _auth_helpers = {} |
| 71 | + |
| 72 | + def __init__(self, username, password): |
| 73 | + super().__init__(username, password) |
| 74 | + |
| 75 | + self._auth_helper = self._auth_helpers.get( |
| 76 | + (username, password), |
| 77 | + requests.auth.HTTPDigestAuth(username, password) |
| 78 | + ) |
| 79 | + self._auth_helpers[(username, password)] = self._auth_helper |
| 80 | + |
| 81 | + @property |
| 82 | + def auth_helper_vars(self): |
| 83 | + return self._auth_helper._thread_local |
| 84 | + |
| 85 | + def handle_401(self, response): |
| 86 | + s_auth = response.headers.get("www-authenticate", "") |
| 87 | + |
| 88 | + if "digest" in s_auth.lower(): |
| 89 | + # Original source: |
| 90 | + # https://github.com/psf/requests/blob/f12ccbef6d6b95564da8d22e280d28c39d53f0e9/src/requests/auth.py#L262-L263 |
| 91 | + pat = re.compile(r"digest ", flags=re.IGNORECASE) |
| 92 | + self.auth_helper_vars.chal = parse_dict_header(pat.sub("", s_auth, count=1)) |
| 93 | + |
| 94 | + def get_auth_header(self, method, url): |
| 95 | + self._auth_helper.init_per_thread_state() |
| 96 | + |
| 97 | + if not self.auth_helper_vars.chal: |
| 98 | + # Need to do init request first |
| 99 | + return '' |
| 100 | + |
| 101 | + return self._auth_helper.build_digest_header(method, url) |
| 102 | + |
| 103 | + |
39 | 104 | def prepare_auth(auth, username, password):
|
40 | 105 | if username and password:
|
41 | 106 | if auth == "basic" or auth is None:
|
42 |
| - return aiohttp.BasicAuth(username, password) |
| 107 | + return BasicAuthMethod(username, password) |
43 | 108 | elif auth == "digest":
|
44 |
| - from requests.auth import HTTPDigestAuth |
45 |
| - |
46 |
| - return HTTPDigestAuth(username, password) |
| 109 | + return DigestAuthMethod(username, password) |
47 | 110 | elif auth == "guess":
|
48 |
| - try: |
49 |
| - from requests_toolbelt.auth.guess import GuessAuth |
50 |
| - except ImportError: |
51 |
| - raise exceptions.UserError( |
52 |
| - "Your version of requests_toolbelt is too " |
53 |
| - "old for `guess` authentication. At least " |
54 |
| - "version 0.4.0 is required." |
55 |
| - ) |
56 |
| - else: |
57 |
| - return GuessAuth(username, password) |
| 111 | + raise exceptions.UserError(f"'Guess' authentication is not supported in this version of vdirsyncer. \n" |
| 112 | + f"Please explicitly specify either 'basic' or 'digest' auth instead. \n" |
| 113 | + f"See the following issue for more information: " |
| 114 | + f"https://github.com/pimutils/vdirsyncer/issues/1015") |
58 | 115 | else:
|
59 | 116 | raise exceptions.UserError(f"Unknown authentication method: {auth}")
|
60 | 117 | elif auth:
|
@@ -97,14 +154,17 @@ async def request(
|
97 | 154 | method,
|
98 | 155 | url,
|
99 | 156 | session,
|
| 157 | + auth, |
100 | 158 | latin1_fallback=True,
|
101 | 159 | **kwargs,
|
102 | 160 | ):
|
103 |
| - """Wrapper method for requests, to ease logging and mocking. |
| 161 | + """Wrapper method for requests, to ease logging and mocking as well as to |
| 162 | + support auth methods currently unsupported by aiohttp. |
104 | 163 |
|
105 |
| - Parameters should be the same as for ``aiohttp.request``, as well as: |
| 164 | + Parameters should be the same as for ``aiohttp.request``, except: |
106 | 165 |
|
107 | 166 | :param session: A requests session object to use.
|
| 167 | + :param auth: The HTTP ``AuthMethod`` to use for authentication. |
108 | 168 | :param verify_fingerprint: Optional. SHA256 of the expected server certificate.
|
109 | 169 | :param latin1_fallback: RFC-2616 specifies the default Content-Type of
|
110 | 170 | text/* to be latin1, which is not always correct, but exactly what
|
@@ -134,7 +194,21 @@ async def request(
|
134 | 194 | ssl_context.load_cert_chain(*cert)
|
135 | 195 | kwargs["ssl"] = ssl_context
|
136 | 196 |
|
137 |
| - response = await session.request(method, url, **kwargs) |
| 197 | + headers = kwargs.pop("headers", {}) |
| 198 | + num_401 = 0 |
| 199 | + while num_401 < 2: |
| 200 | + headers["Authorization"] = auth.get_auth_header(method, url) |
| 201 | + response = await session.request(method, url, headers=headers, **kwargs) |
| 202 | + |
| 203 | + if response.ok: |
| 204 | + break |
| 205 | + |
| 206 | + if response.status == 401: |
| 207 | + num_401 += 1 |
| 208 | + auth.handle_401(response) |
| 209 | + else: |
| 210 | + # some other error, will be handled later on |
| 211 | + break |
138 | 212 |
|
139 | 213 | # See https://github.com/kennethreitz/requests/issues/2042
|
140 | 214 | content_type = response.headers.get("Content-Type", "")
|
|
0 commit comments