Skip to content

Commit c576a11

Browse files
author
Paolo Tranquilli
committed
Bazel: make git_lfs_probe.py try all available endpoints
1 parent b63bd2a commit c576a11

File tree

1 file changed

+111
-85
lines changed

1 file changed

+111
-85
lines changed

misc/bazel/internal/git_lfs_probe.py

Lines changed: 111 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,19 @@
77
* the sha256 hash, a space character and a transient download link obtained via the LFS protocol otherwise
88
If --hash-only is provided, the transient URL will not be fetched and printed
99
"""
10-
10+
import dataclasses
1111
import sys
1212
import pathlib
1313
import subprocess
1414
import os
1515
import shutil
1616
import json
17+
import typing
1718
import urllib.request
1819
from urllib.parse import urlparse
1920
import re
2021
import base64
2122
from dataclasses import dataclass
22-
from typing import Dict
2323
import argparse
2424

2525

@@ -32,11 +32,13 @@ def options():
3232

3333
@dataclass
3434
class Endpoint:
35+
name: str
3536
href: str
36-
headers: Dict[str, str]
37+
ssh: str | None = None
38+
headers: typing.Dict[str, str] = dataclasses.field(default_factory=dict)
3739

38-
def update_headers(self, d: Dict[str, str]):
39-
self.headers.update((k.capitalize(), v) for k, v in d.items())
40+
def update_headers(self, d: typing.Iterable[tuple[str, str]]):
41+
self.headers.update((k.capitalize(), v) for k, v in d)
4042

4143

4244
opts = options()
@@ -47,88 +49,107 @@ def update_headers(self, d: Dict[str, str]):
4749
).strip()
4850

4951

50-
def get_env(s, sep="="):
51-
ret = {}
52+
def get_env(s: str, sep: str = "=") -> typing.Iterable[tuple[str, str]]:
5253
for m in re.finditer(rf"(.*?){sep}(.*)", s, re.M):
53-
ret.setdefault(*m.groups())
54-
return ret
54+
yield m.groups()
5555

5656

5757
def git(*args, **kwargs):
58-
return subprocess.run(
58+
proc = subprocess.run(
5959
("git",) + args, stdout=subprocess.PIPE, text=True, cwd=source_dir, **kwargs
60-
).stdout.strip()
60+
)
61+
return proc.stdout.strip() if proc.returncode == 0 else None
62+
6163

64+
endpoint_re = re.compile(r"^Endpoint(?: \((.*)\))?$")
6265

63-
def get_endpoint():
64-
lfs_env_items = iter(
65-
get_env(
66-
subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir)
67-
).items()
66+
67+
def get_endpoint_addresses() -> typing.Iterable[Endpoint]:
68+
"""Get all lfs endpoints, including SSH if present"""
69+
lfs_env_items = get_env(
70+
subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir)
6871
)
69-
endpoint = next(v for k, v in lfs_env_items if k.startswith("Endpoint"))
70-
endpoint, _, _ = endpoint.partition(" ")
71-
# only take the ssh endpoint if it follows directly after the first endpoint we found
72-
# in a situation like
73-
# Endpoint (a)=...
74-
# Endpoint (b)=...
75-
# SSH=...
76-
# we want to ignore the SSH endpoint, as it's not linked to the default (a) endpoint
77-
following_key, following_value = next(lfs_env_items, (None, None))
78-
ssh_endpoint = following_value if following_key == " SSH" else None
79-
80-
endpoint = Endpoint(
81-
endpoint,
82-
{
72+
current_endpoint = None
73+
for k, v in lfs_env_items:
74+
m = endpoint_re.match(k)
75+
if m:
76+
if current_endpoint:
77+
yield current_endpoint
78+
href, _, _ = v.partition(" ")
79+
current_endpoint = Endpoint(name=m[1] or "default", href=href)
80+
elif k == " SSH" and current_endpoint:
81+
current_endpoint.ssh = v
82+
if current_endpoint:
83+
yield current_endpoint
84+
85+
86+
def get_endpoints() -> typing.Iterable[Endpoint]:
87+
for endpoint in get_endpoint_addresses():
88+
endpoint.headers = {
8389
"Content-Type": "application/vnd.git-lfs+json",
8490
"Accept": "application/vnd.git-lfs+json",
85-
},
86-
)
87-
if ssh_endpoint:
88-
# see https://github.com/git-lfs/git-lfs/blob/main/docs/api/authentication.md
89-
server, _, path = ssh_endpoint.partition(":")
90-
ssh_command = shutil.which(
91-
os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh"))
92-
)
93-
assert ssh_command, "no ssh command found"
94-
resp = json.loads(
95-
subprocess.check_output(
96-
[
97-
ssh_command,
98-
"-oStrictHostKeyChecking=accept-new",
99-
server,
100-
"git-lfs-authenticate",
101-
path,
102-
"download",
103-
]
91+
}
92+
if endpoint.ssh:
93+
# see https://github.com/git-lfs/git-lfs/blob/main/docs/api/authentication.md
94+
server, _, path = endpoint.ssh.partition(":")
95+
ssh_command = shutil.which(
96+
os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh"))
10497
)
105-
)
106-
endpoint.href = resp.get("href", endpoint)
107-
endpoint.update_headers(resp.get("header", {}))
108-
url = urlparse(endpoint.href)
109-
# this is how actions/checkout persist credentials
110-
# see https://github.com/actions/checkout/blob/44c2b7a8a4ea60a981eaca3cf939b5f4305c123b/src/git-auth-helper.ts#L56-L63
111-
auth = git("config", f"http.{url.scheme}://{url.netloc}/.extraheader")
112-
endpoint.update_headers(get_env(auth, sep=": "))
113-
if os.environ.get("GITHUB_TOKEN"):
114-
endpoint.headers["Authorization"] = f"token {os.environ['GITHUB_TOKEN']}"
115-
if "Authorization" not in endpoint.headers:
116-
# last chance: use git credentials (possibly backed by a credential helper like the one installed by gh)
117-
# see https://git-scm.com/docs/git-credential
118-
credentials = get_env(
119-
git(
98+
assert ssh_command, "no ssh command found"
99+
cmd = [
100+
ssh_command,
101+
"-oStrictHostKeyChecking=accept-new",
102+
server,
103+
"git-lfs-authenticate",
104+
path,
105+
"download",
106+
]
107+
try:
108+
res = subprocess.run(cmd, stdout=subprocess.PIPE, timeout=15)
109+
except subprocess.TimeoutExpired:
110+
print(
111+
f"WARNING: ssh timed out when connecting to {server}, ignoring {endpoint.name} endpoint",
112+
file=sys.stderr,
113+
)
114+
continue
115+
if res.returncode != 0:
116+
print(
117+
f"WARNING: ssh failed when connecting to {server}, ignoring {endpoint.name} endpoint",
118+
file=sys.stderr,
119+
)
120+
continue
121+
ssh_resp = json.loads(res.stdout)
122+
endpoint.href = ssh_resp.get("href", endpoint)
123+
endpoint.update_headers(ssh_resp.get("header", {}).items())
124+
url = urlparse(endpoint.href)
125+
# this is how actions/checkout persist credentials
126+
# see https://github.com/actions/checkout/blob/44c2b7a8a4ea60a981eaca3cf939b5f4305c123b/src/git-auth-helper.ts#L56-L63
127+
auth = git("config", f"http.{url.scheme}://{url.netloc}/.extraheader") or ""
128+
endpoint.update_headers(get_env(auth, sep=": "))
129+
if os.environ.get("GITHUB_TOKEN"):
130+
endpoint.headers["Authorization"] = f"token {os.environ['GITHUB_TOKEN']}"
131+
if "Authorization" not in endpoint.headers:
132+
# last chance: use git credentials (possibly backed by a credential helper like the one installed by gh)
133+
# see https://git-scm.com/docs/git-credential
134+
credentials = git(
120135
"credential",
121136
"fill",
122137
check=True,
123138
# drop leading / from url.path
124139
input=f"protocol={url.scheme}\nhost={url.netloc}\npath={url.path[1:]}\n",
125140
)
126-
)
127-
auth = base64.b64encode(
128-
f'{credentials["username"]}:{credentials["password"]}'.encode()
129-
).decode("ascii")
130-
endpoint.headers["Authorization"] = f"Basic {auth}"
131-
return endpoint
141+
if credentials is None:
142+
print(
143+
f"WARNING: no authorization method found, ignoring {data.name} endpoint",
144+
file=sys.stderr,
145+
)
146+
continue
147+
credentials = dict(get_env(credentials))
148+
auth = base64.b64encode(
149+
f'{credentials["username"]}:{credentials["password"]}'.encode()
150+
).decode("ascii")
151+
endpoint.headers["Authorization"] = f"Basic {auth}"
152+
yield endpoint
132153

133154

134155
# see https://github.com/git-lfs/git-lfs/blob/310d1b4a7d01e8d9d884447df4635c7a9c7642c2/docs/api/basic-transfers.md
@@ -142,26 +163,31 @@ def get_locations(objects):
142163
for i in indexes:
143164
ret[i] = objects[i]["oid"]
144165
return ret
145-
endpoint = get_endpoint()
146166
data = {
147167
"operation": "download",
148168
"transfers": ["basic"],
149169
"objects": [objects[i] for i in indexes],
150170
"hash_algo": "sha256",
151171
}
152-
req = urllib.request.Request(
153-
f"{endpoint.href}/objects/batch",
154-
headers=endpoint.headers,
155-
data=json.dumps(data).encode("ascii"),
156-
)
157-
with urllib.request.urlopen(req) as resp:
158-
data = json.load(resp)
159-
assert len(data["objects"]) == len(
160-
indexes
161-
), f"received {len(data)} objects, expected {len(indexes)}"
162-
for i, resp in zip(indexes, data["objects"]):
163-
ret[i] = f'{resp["oid"]} {resp["actions"]["download"]["href"]}'
164-
return ret
172+
for endpoint in get_endpoints():
173+
req = urllib.request.Request(
174+
f"{endpoint.href}/objects/batch",
175+
headers=endpoint.headers,
176+
data=json.dumps(data).encode("ascii"),
177+
)
178+
try:
179+
with urllib.request.urlopen(req) as resp:
180+
data = json.load(resp)
181+
except urllib.request.HTTPError as e:
182+
print(f"WARNING: encountered HTTPError {e}, ignoring endpoint {e.name}")
183+
continue
184+
assert len(data["objects"]) == len(
185+
indexes
186+
), f"received {len(data)} objects, expected {len(indexes)}"
187+
for i, resp in zip(indexes, data["objects"]):
188+
ret[i] = f'{resp["oid"]} {resp["actions"]["download"]["href"]}'
189+
return ret
190+
raise Exception(f"no valid endpoint found")
165191

166192

167193
def get_lfs_object(path):
@@ -171,7 +197,7 @@ def get_lfs_object(path):
171197
sha256 = size = None
172198
if lfs_header != actual_header:
173199
return None
174-
data = get_env(fileobj.read().decode("ascii"), sep=" ")
200+
data = dict(get_env(fileobj.read().decode("ascii"), sep=" "))
175201
assert data["oid"].startswith("sha256:"), f"unknown oid type: {data['oid']}"
176202
_, _, sha256 = data["oid"].partition(":")
177203
size = int(data["size"])

0 commit comments

Comments
 (0)