
Commit e6397eb

feat: upload images with keep-alive (#772)
* use requests.Session to keep http connections alive
* skip fetch offset if it is uuid
* parallel upload images in the same sessions
* sanitize cache path which contains sensitive segments
* retry for image uploading
* handle keyboardinterrupted
* add --num_upload_workers
* log upload name
* skip caching filehandles when dry run is enabled
* send progress to image uploading
* shutdown when sys.version >= 3.13
* fix typing for override
* fix types for python3.99999
* fix tests
* fix invalid offset
* print upload time
1 parent dc0ad06 commit e6397eb
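The headline change is holding one requests.Session open for the whole run so that HTTPS connections are reused instead of being re-established per request. A minimal, standalone illustration of that keep-alive pattern (the URL and header value are placeholders, not the tool's actual endpoints):

import requests

# A Session keeps TCP/TLS connections in a pool and reuses them for later
# requests to the same host (HTTP keep-alive), skipping repeated handshakes.
with requests.Session() as session:
    session.headers["Authorization"] = "OAuth <placeholder-token>"
    for payload in (b"image-1", b"image-2", b"image-3"):
        # All three POSTs travel over the same pooled connection.
        resp = session.post("https://httpbin.org/post", data=payload, timeout=30)
        resp.raise_for_status()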

10 files changed (+672 / -498 lines)

mapillary_tools/api_v4.py

Lines changed: 45 additions & 285 deletions
Large diffs are not rendered by default.
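The api_v4.py diff is not shown here, but the updated call sites in authenticate.py below reveal its new shape: instead of passing a token to every function, callers first create a session and hand that around. The sketch below is a guess at that pattern; only the names create_client_session, create_user_session, fetch_user_or_me, get_upload_token, and fetch_organization come from this commit, while the endpoint path, header scheme, and bodies are assumptions:

# Hypothetical sketch only; the real api_v4.py is collapsed in this view.
import requests

MAPILLARY_GRAPH_API_ENDPOINT = "https://graph.mapillary.com"  # host seen in an http.py log message


def create_client_session() -> requests.Session:
    # Unauthenticated session used before login (e.g. for get_upload_token).
    return requests.Session()


def create_user_session(user_access_token: str) -> requests.Session:
    # Attach the user token once; every request made through this session is
    # then authenticated and reuses the same keep-alive connection.
    session = requests.Session()
    session.headers["Authorization"] = f"OAuth {user_access_token}"  # header scheme assumed
    return session


def fetch_user_or_me(session: requests.Session) -> requests.Response:
    resp = session.get(f"{MAPILLARY_GRAPH_API_ENDPOINT}/me", timeout=60)  # path assumed
    resp.raise_for_status()
    return resp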

mapillary_tools/authenticate.py

Lines changed: 23 additions & 24 deletions
@@ -10,7 +10,7 @@

 import requests

-from . import api_v4, config, constants, exceptions
+from . import api_v4, config, constants, exceptions, http


 LOG = logging.getLogger(__name__)
@@ -77,11 +77,11 @@ def authenticate(
     # TODO: print more user information
     if profile_name in all_user_items:
         LOG.info(
-            'Profile "%s" updated: %s', profile_name, api_v4._sanitize(user_items)
+            'Profile "%s" updated: %s', profile_name, http._sanitize(user_items)
         )
     else:
         LOG.info(
-            'Profile "%s" created: %s', profile_name, api_v4._sanitize(user_items)
+            'Profile "%s" created: %s', profile_name, http._sanitize(user_items)
        )


@@ -134,9 +134,8 @@ def fetch_user_items(
     )

     if organization_key is not None:
-        resp = api_v4.fetch_organization(
-            user_items["user_upload_token"], organization_key
-        )
+        with api_v4.create_user_session(user_items["user_upload_token"]) as session:
+            resp = api_v4.fetch_organization(session, organization_key)
         data = api_v4.jsonify_response(resp)
         LOG.info(
             f"Uploading to organization: {data.get('name')} (ID: {data.get('id')})"
@@ -173,16 +172,15 @@ def _verify_user_auth(user_items: config.UserItem) -> config.UserItem:
     if constants._AUTH_VERIFICATION_DISABLED:
         return user_items

-    try:
-        resp = api_v4.fetch_user_or_me(
-            user_access_token=user_items["user_upload_token"]
-        )
-    except requests.HTTPError as ex:
-        if api_v4.is_auth_error(ex.response):
-            message = api_v4.extract_auth_error_message(ex.response)
-            raise exceptions.MapillaryUploadUnauthorizedError(message)
-        else:
-            raise ex
+    with api_v4.create_user_session(user_items["user_upload_token"]) as session:
+        try:
+            resp = api_v4.fetch_user_or_me(session)
+        except requests.HTTPError as ex:
+            if api_v4.is_auth_error(ex.response):
+                message = api_v4.extract_auth_error_message(ex.response)
+                raise exceptions.MapillaryUploadUnauthorizedError(message)
+            else:
+                raise ex

     data = api_v4.jsonify_response(resp)

@@ -276,16 +274,17 @@ def _prompt_login(
         if user_password:
             break

-        try:
-            resp = api_v4.get_upload_token(user_email, user_password)
-        except requests.HTTPError as ex:
-            if not _enabled:
-                raise ex
+        with api_v4.create_client_session() as session:
+            try:
+                resp = api_v4.get_upload_token(session, user_email, user_password)
+            except requests.HTTPError as ex:
+                if not _enabled:
+                    raise ex

-            if _is_login_retryable(ex):
-                return _prompt_login()
+                if _is_login_retryable(ex):
+                    return _prompt_login()

-            raise ex
+                raise ex

     data = api_v4.jsonify_response(resp)

mapillary_tools/commands/upload.py

Lines changed: 7 additions & 0 deletions
@@ -23,6 +23,13 @@ def add_common_upload_options(group):
         default=None,
         required=False,
     )
+    group.add_argument(
+        "--num_upload_workers",
+        help="Number of concurrent upload workers for uploading images. [default: %(default)s]",
+        default=constants.MAX_IMAGE_UPLOAD_WORKERS,
+        type=int,
+        required=False,
+    )
     group.add_argument(
         "--reupload",
         help="Re-upload data that has already been uploaded.",

mapillary_tools/constants.py

Lines changed: 5 additions & 3 deletions
@@ -154,16 +154,18 @@ def _parse_scaled_integers(
 # The minimal upload speed is used to calculate the read timeout to avoid upload hanging:
 # timeout = upload_size / MIN_UPLOAD_SPEED
 MIN_UPLOAD_SPEED: int | None = _parse_filesize(
-    os.getenv(_ENV_PREFIX + "MIN_UPLOAD_SPEED", "50K")  # 50 KiB/s
+    os.getenv(_ENV_PREFIX + "MIN_UPLOAD_SPEED", "50K")  # 50 Kb/s
 )
+# Maximum number of parallel workers for uploading images within a single sequence.
+# NOTE: Sequences themselves are uploaded sequentially, not in parallel.
 MAX_IMAGE_UPLOAD_WORKERS: int = int(
-    os.getenv(_ENV_PREFIX + "MAX_IMAGE_UPLOAD_WORKERS", 64)
+    os.getenv(_ENV_PREFIX + "MAX_IMAGE_UPLOAD_WORKERS", 4)
 )
 # The chunk size in MB (see chunked transfer encoding https://en.wikipedia.org/wiki/Chunked_transfer_encoding)
 # for uploading data to MLY upload service.
 # Changing this size does not change the number of requests nor affect upload performance,
 # but it affects the responsiveness of the upload progress bar
-UPLOAD_CHUNK_SIZE_MB: float = float(os.getenv(_ENV_PREFIX + "UPLOAD_CHUNK_SIZE_MB", 1))
+UPLOAD_CHUNK_SIZE_MB: float = float(os.getenv(_ENV_PREFIX + "UPLOAD_CHUNK_SIZE_MB", 2))
 MAX_UPLOAD_RETRIES: int = int(os.getenv(_ENV_PREFIX + "MAX_UPLOAD_RETRIES", 200))
 MAPILLARY__ENABLE_UPLOAD_HISTORY_FOR_DRY_RUN: bool = _yes_or_no(
     os.getenv("MAPILLARY__ENABLE_UPLOAD_HISTORY_FOR_DRY_RUN", "NO")

mapillary_tools/http.py

Lines changed: 211 additions & 0 deletions
New file (shown in full):

from __future__ import annotations

import logging

import ssl
import sys
import typing as T
from json import dumps

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

import requests
from requests.adapters import HTTPAdapter


LOG = logging.getLogger(__name__)


class HTTPSystemCertsAdapter(HTTPAdapter):
    """
    This adapter uses the system's certificate store instead of the certifi module.

    The implementation is based on the project https://pypi.org/project/pip-system-certs/,
    which has a system-wide effect.
    """

    def init_poolmanager(self, *args, **kwargs):
        ssl_context = ssl.create_default_context()
        ssl_context.load_default_certs()
        kwargs["ssl_context"] = ssl_context

        super().init_poolmanager(*args, **kwargs)

    def cert_verify(self, *args, **kwargs):
        super().cert_verify(*args, **kwargs)

        # By default Python requests uses the ca_certs from the certifi module
        # But we want to use the certificate store instead.
        # By clearing the ca_certs variable we force it to fall back on that behaviour (handled in urllib3)
        if "conn" in kwargs:
            conn = kwargs["conn"]
        else:
            conn = args[0]

        conn.ca_certs = None


class Session(requests.Session):
    # NOTE: This is a global flag that affects all Session instances
    USE_SYSTEM_CERTS: T.ClassVar[bool] = False
    # Instance variables
    disable_logging_request: bool = False
    disable_logging_response: bool = False
    # Avoid mounting twice
    _mounted: bool = False

    @override
    def request(self, method: str | bytes, url: str | bytes, *args, **kwargs):
        self._log_debug_request(method, url, *args, **kwargs)

        if Session.USE_SYSTEM_CERTS:
            if not self._mounted:
                self.mount("https://", HTTPSystemCertsAdapter())
                self._mounted = True
            resp = super().request(method, url, *args, **kwargs)
        else:
            try:
                resp = super().request(method, url, *args, **kwargs)
            except requests.exceptions.SSLError as ex:
                if "SSLCertVerificationError" not in str(ex):
                    raise ex
                Session.USE_SYSTEM_CERTS = True
                # HTTPSConnectionPool(host='graph.mapillary.com', port=443): Max retries exceeded with url: /login (Caused by SSLError(SSLCertVerificationError(1, '[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1018)')))
                LOG.warning(
                    "SSL error occurred, falling back to system SSL certificates: %s",
                    ex,
                )
                return self.request(method, url, *args, **kwargs)

        self._log_debug_response(resp)

        return resp

    def _log_debug_request(self, method: str | bytes, url: str | bytes, **kwargs):
        if self.disable_logging_request:
            return

        if logging.getLogger().getEffectiveLevel() <= logging.DEBUG:
            return

        if isinstance(method, str) and isinstance(url, str):
            msg = f"HTTP {method} {url}"
        else:
            msg = f"HTTP {method!r} {url!r}"

        if Session.USE_SYSTEM_CERTS:
            msg += " (w/sys_certs)"

        json = kwargs.get("json")
        if json is not None:
            t = _truncate(dumps(_sanitize(json)))
            msg += f" JSON={t}"

        params = kwargs.get("params")
        if params is not None:
            msg += f" PARAMS={_sanitize(params)}"

        headers = kwargs.get("headers")
        if headers is not None:
            msg += f" HEADERS={_sanitize(headers)}"

        timeout = kwargs.get("timeout")
        if timeout is not None:
            msg += f" TIMEOUT={timeout}"

        msg = msg.replace("\n", "\\n")

        LOG.debug(msg)

    def _log_debug_response(self, resp: requests.Response):
        if self.disable_logging_response:
            return

        if logging.getLogger().getEffectiveLevel() <= logging.DEBUG:
            return

        elapsed = resp.elapsed.total_seconds() * 1000  # Convert to milliseconds
        msg = f"HTTP {resp.status_code} {resp.reason} ({elapsed:.0f} ms): {str(_truncate_response_content(resp))}"

        LOG.debug(msg)


def readable_http_error(ex: requests.HTTPError) -> str:
    return readable_http_response(ex.response)


def readable_http_response(resp: requests.Response) -> str:
    return f"{resp.request.method} {resp.url} => {resp.status_code} {resp.reason}: {str(_truncate_response_content(resp))}"


@T.overload
def _truncate(s: bytes, limit: int = 256) -> bytes | str: ...


@T.overload
def _truncate(s: str, limit: int = 256) -> str: ...


def _truncate(s, limit=256):
    if limit < len(s):
        if isinstance(s, bytes):
            try:
                s = s.decode("utf-8")
            except UnicodeDecodeError:
                pass
        remaining = len(s) - limit
        if isinstance(s, bytes):
            return s[:limit] + f"...({remaining} bytes truncated)".encode("utf-8")
        else:
            return str(s[:limit]) + f"...({remaining} chars truncated)"
    else:
        return s


def _sanitize(headers: T.Mapping[T.Any, T.Any]) -> T.Mapping[T.Any, T.Any]:
    new_headers = {}

    for k, v in headers.items():
        if k.lower() in [
            "authorization",
            "cookie",
            "x-fb-access-token",
            "access-token",
            "access_token",
            "password",
            "user_upload_token",
        ]:
            new_headers[k] = "[REDACTED]"
        else:
            if isinstance(v, (str, bytes)):
                new_headers[k] = T.cast(T.Any, _truncate(v))
            else:
                new_headers[k] = v

    return new_headers


def _truncate_response_content(resp: requests.Response) -> str | bytes:
    try:
        json_data = resp.json()
    except requests.JSONDecodeError:
        if resp.content is not None:
            data = _truncate(resp.content)
        else:
            data = ""
    else:
        if isinstance(json_data, dict):
            data = _truncate(dumps(_sanitize(json_data)))
        else:
            data = _truncate(str(json_data))

    if isinstance(data, bytes):
        return data.replace(b"\n", b"\\n")
    elif isinstance(data, str):
        return data.replace("\n", "\\n")

    return data
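A short usage sketch of the new module as it would be imported from an installed package; the target URL is a neutral placeholder, and the printed dict shows how _sanitize redacts credential-like keys before they can reach a debug log:

import logging

from mapillary_tools import http

logging.basicConfig(level=logging.DEBUG)

print(http._sanitize({"Authorization": "OAuth secret", "Accept": "application/json"}))
# {'Authorization': '[REDACTED]', 'Accept': 'application/json'}

with http.Session() as session:
    # On an SSLCertVerificationError the session transparently retries with
    # HTTPSystemCertsAdapter (the system certificate store).
    resp = session.get("https://httpbin.org/get", timeout=10)
    print(http.readable_http_response(resp))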
