Skip to content

Commit 611b866

Browse files
malmelooWhyNotHugo
authored andcommitted
Implement digest auth
1 parent 8550475 commit 611b866

File tree

1 file changed

+91
-17
lines changed

1 file changed

+91
-17
lines changed

vdirsyncer/http.py

Lines changed: 91 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
from __future__ import annotations
22

33
import logging
4+
import re
5+
from abc import ABC, abstractmethod
6+
from base64 import b64encode
47
from ssl import create_default_context
58

69
import aiohttp
10+
import requests.auth
11+
from requests.utils import parse_dict_header
712

813
from . import DOCS_HOME
914
from . import __version__
@@ -36,25 +41,77 @@ def _detect_faulty_requests(): # pragma: no cover
3641
del _detect_faulty_requests
3742

3843

44+
class AuthMethod(ABC):
45+
def __init__(self, username, password):
46+
self.username = username
47+
self.password = password
48+
49+
@abstractmethod
50+
def handle_401(self, response):
51+
raise NotImplementedError
52+
53+
@abstractmethod
54+
def get_auth_header(self, method, url):
55+
raise NotImplementedError
56+
57+
58+
class BasicAuthMethod(AuthMethod):
59+
def handle_401(self, _response):
60+
pass
61+
62+
def get_auth_header(self, _method, _url):
63+
auth_str = f"{self.username}:{self.password}"
64+
return "Basic " + b64encode(auth_str.encode('utf-8')).decode("utf-8")
65+
66+
67+
class DigestAuthMethod(AuthMethod):
68+
# make class var to 'cache' the state, which is more efficient because otherwise
69+
# each request would first require another 'initialization' request.
70+
_auth_helpers = {}
71+
72+
def __init__(self, username, password):
73+
super().__init__(username, password)
74+
75+
self._auth_helper = self._auth_helpers.get(
76+
(username, password),
77+
requests.auth.HTTPDigestAuth(username, password)
78+
)
79+
self._auth_helpers[(username, password)] = self._auth_helper
80+
81+
@property
82+
def auth_helper_vars(self):
83+
return self._auth_helper._thread_local
84+
85+
def handle_401(self, response):
86+
s_auth = response.headers.get("www-authenticate", "")
87+
88+
if "digest" in s_auth.lower():
89+
# Original source:
90+
# https://github.com/psf/requests/blob/f12ccbef6d6b95564da8d22e280d28c39d53f0e9/src/requests/auth.py#L262-L263
91+
pat = re.compile(r"digest ", flags=re.IGNORECASE)
92+
self.auth_helper_vars.chal = parse_dict_header(pat.sub("", s_auth, count=1))
93+
94+
def get_auth_header(self, method, url):
95+
self._auth_helper.init_per_thread_state()
96+
97+
if not self.auth_helper_vars.chal:
98+
# Need to do init request first
99+
return ''
100+
101+
return self._auth_helper.build_digest_header(method, url)
102+
103+
39104
def prepare_auth(auth, username, password):
40105
if username and password:
41106
if auth == "basic" or auth is None:
42-
return aiohttp.BasicAuth(username, password)
107+
return BasicAuthMethod(username, password)
43108
elif auth == "digest":
44-
from requests.auth import HTTPDigestAuth
45-
46-
return HTTPDigestAuth(username, password)
109+
return DigestAuthMethod(username, password)
47110
elif auth == "guess":
48-
try:
49-
from requests_toolbelt.auth.guess import GuessAuth
50-
except ImportError:
51-
raise exceptions.UserError(
52-
"Your version of requests_toolbelt is too "
53-
"old for `guess` authentication. At least "
54-
"version 0.4.0 is required."
55-
)
56-
else:
57-
return GuessAuth(username, password)
111+
raise exceptions.UserError(f"'Guess' authentication is not supported in this version of vdirsyncer. \n"
112+
f"Please explicitly specify either 'basic' or 'digest' auth instead. \n"
113+
f"See the following issue for more information: "
114+
f"https://github.com/pimutils/vdirsyncer/issues/1015")
58115
else:
59116
raise exceptions.UserError(f"Unknown authentication method: {auth}")
60117
elif auth:
@@ -97,14 +154,17 @@ async def request(
97154
method,
98155
url,
99156
session,
157+
auth,
100158
latin1_fallback=True,
101159
**kwargs,
102160
):
103-
"""Wrapper method for requests, to ease logging and mocking.
161+
"""Wrapper method for requests, to ease logging and mocking as well as to
162+
support auth methods currently unsupported by aiohttp.
104163
105-
Parameters should be the same as for ``aiohttp.request``, as well as:
164+
Parameters should be the same as for ``aiohttp.request``, except:
106165
107166
:param session: A requests session object to use.
167+
:param auth: The HTTP ``AuthMethod`` to use for authentication.
108168
:param verify_fingerprint: Optional. SHA256 of the expected server certificate.
109169
:param latin1_fallback: RFC-2616 specifies the default Content-Type of
110170
text/* to be latin1, which is not always correct, but exactly what
@@ -134,7 +194,21 @@ async def request(
134194
ssl_context.load_cert_chain(*cert)
135195
kwargs["ssl"] = ssl_context
136196

137-
response = await session.request(method, url, **kwargs)
197+
headers = kwargs.pop("headers", {})
198+
num_401 = 0
199+
while num_401 < 2:
200+
headers["Authorization"] = auth.get_auth_header(method, url)
201+
response = await session.request(method, url, headers=headers, **kwargs)
202+
203+
if response.ok:
204+
break
205+
206+
if response.status == 401:
207+
num_401 += 1
208+
auth.handle_401(response)
209+
else:
210+
# some other error, will be handled later on
211+
break
138212

139213
# See https://github.com/kennethreitz/requests/issues/2042
140214
content_type = response.headers.get("Content-Type", "")

0 commit comments

Comments
 (0)