Skip to content

Commit 875d1d3

Browse files
authored
Merge pull request #17172 from github/redsun82/bazel-lfs
Bazel: make `git_lfs_probe.py` try all available endpoints
2 parents de40dfd + e451f2b commit 875d1d3

File tree

1 file changed

+132
-77
lines changed

1 file changed

+132
-77
lines changed

misc/bazel/internal/git_lfs_probe.py

Lines changed: 132 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,19 @@
77
* the sha256 hash, a space character and a transient download link obtained via the LFS protocol otherwise
88
If --hash-only is provided, the transient URL will not be fetched and printed
99
"""
10-
10+
import dataclasses
1111
import sys
1212
import pathlib
1313
import subprocess
1414
import os
1515
import shutil
1616
import json
17+
import typing
1718
import urllib.request
1819
from urllib.parse import urlparse
1920
import re
2021
import base64
2122
from dataclasses import dataclass
22-
from typing import Dict
2323
import argparse
2424

2525

@@ -32,76 +32,124 @@ def options():
3232

3333
@dataclass
3434
class Endpoint:
35+
name: str
3536
href: str
36-
headers: Dict[str, str]
37+
ssh: typing.Optional[str] = None
38+
headers: typing.Dict[str, str] = dataclasses.field(default_factory=dict)
3739

38-
def update_headers(self, d: Dict[str, str]):
39-
self.headers.update((k.capitalize(), v) for k, v in d.items())
40+
def update_headers(self, d: typing.Iterable[typing.Tuple[str, str]]):
41+
self.headers.update((k.capitalize(), v) for k, v in d)
4042

4143

4244
opts = options()
4345
sources = [p.resolve() for p in opts.sources]
4446
source_dir = pathlib.Path(os.path.commonpath(src.parent for src in sources))
45-
source_dir = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], cwd=source_dir, text=True).strip()
47+
source_dir = subprocess.check_output(
48+
["git", "rev-parse", "--show-toplevel"], cwd=source_dir, text=True
49+
).strip()
4650

4751

48-
def get_env(s, sep="="):
49-
ret = {}
50-
for m in re.finditer(fr'(.*?){sep}(.*)', s, re.M):
51-
ret.setdefault(*m.groups())
52-
return ret
52+
def get_env(s: str, sep: str = "=") -> typing.Iterable[typing.Tuple[str, str]]:
53+
for m in re.finditer(rf"(.*?){sep}(.*)", s, re.M):
54+
yield m.groups()
5355

5456

5557
def git(*args, **kwargs):
56-
return subprocess.run(("git",) + args, stdout=subprocess.PIPE, text=True, cwd=source_dir, **kwargs).stdout.strip()
57-
58-
59-
def get_endpoint():
60-
lfs_env_items = iter(get_env(subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir)).items())
61-
endpoint = next(v for k, v in lfs_env_items if k.startswith('Endpoint'))
62-
endpoint, _, _ = endpoint.partition(' ')
63-
# only take the ssh endpoint if it follows directly after the first endpoint we found
64-
# in a situation like
65-
# Endpoint (a)=...
66-
# Endpoint (b)=...
67-
# SSH=...
68-
# we want to ignore the SSH endpoint, as it's not linked to the default (a) endpoint
69-
following_key, following_value = next(lfs_env_items, (None, None))
70-
ssh_endpoint = following_value if following_key == " SSH" else None
71-
72-
endpoint = Endpoint(endpoint, {
73-
"Content-Type": "application/vnd.git-lfs+json",
74-
"Accept": "application/vnd.git-lfs+json",
75-
})
76-
if ssh_endpoint:
77-
# see https://github.com/git-lfs/git-lfs/blob/main/docs/api/authentication.md
78-
server, _, path = ssh_endpoint.partition(":")
79-
ssh_command = shutil.which(os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh")))
80-
assert ssh_command, "no ssh command found"
81-
resp = json.loads(subprocess.check_output([ssh_command,
82-
"-oStrictHostKeyChecking=accept-new",
83-
server,
84-
"git-lfs-authenticate",
85-
path,
86-
"download"]))
87-
endpoint.href = resp.get("href", endpoint)
88-
endpoint.update_headers(resp.get("header", {}))
89-
url = urlparse(endpoint.href)
90-
# this is how actions/checkout persist credentials
91-
# see https://github.com/actions/checkout/blob/44c2b7a8a4ea60a981eaca3cf939b5f4305c123b/src/git-auth-helper.ts#L56-L63
92-
auth = git("config", f"http.{url.scheme}://{url.netloc}/.extraheader")
93-
endpoint.update_headers(get_env(auth, sep=": "))
94-
if os.environ.get("GITHUB_TOKEN"):
95-
endpoint.headers["Authorization"] = f"token {os.environ['GITHUB_TOKEN']}"
96-
if "Authorization" not in endpoint.headers:
97-
# last chance: use git credentials (possibly backed by a credential helper like the one installed by gh)
98-
# see https://git-scm.com/docs/git-credential
99-
credentials = get_env(git("credential", "fill", check=True,
100-
# drop leading / from url.path
101-
input=f"protocol={url.scheme}\nhost={url.netloc}\npath={url.path[1:]}\n"))
102-
auth = base64.b64encode(f'{credentials["username"]}:{credentials["password"]}'.encode()).decode('ascii')
103-
endpoint.headers["Authorization"] = f"Basic {auth}"
104-
return endpoint
58+
proc = subprocess.run(
59+
("git",) + args, stdout=subprocess.PIPE, text=True, cwd=source_dir, **kwargs
60+
)
61+
return proc.stdout.strip() if proc.returncode == 0 else None
62+
63+
64+
endpoint_re = re.compile(r"^Endpoint(?: \((.*)\))?$")
65+
66+
67+
def get_endpoint_addresses() -> typing.Iterable[Endpoint]:
68+
"""Get all lfs endpoints, including SSH if present"""
69+
lfs_env_items = get_env(
70+
subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir)
71+
)
72+
current_endpoint = None
73+
for k, v in lfs_env_items:
74+
m = endpoint_re.match(k)
75+
if m:
76+
if current_endpoint:
77+
yield current_endpoint
78+
href, _, _ = v.partition(" ")
79+
current_endpoint = Endpoint(name=m[1] or "default", href=href)
80+
elif k == " SSH" and current_endpoint:
81+
current_endpoint.ssh = v
82+
if current_endpoint:
83+
yield current_endpoint
84+
85+
86+
def get_endpoints() -> typing.Iterable[Endpoint]:
87+
for endpoint in get_endpoint_addresses():
88+
endpoint.headers = {
89+
"Content-Type": "application/vnd.git-lfs+json",
90+
"Accept": "application/vnd.git-lfs+json",
91+
}
92+
if endpoint.ssh:
93+
# see https://github.com/git-lfs/git-lfs/blob/main/docs/api/authentication.md
94+
server, _, path = endpoint.ssh.partition(":")
95+
ssh_command = shutil.which(
96+
os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh"))
97+
)
98+
assert ssh_command, "no ssh command found"
99+
cmd = [
100+
ssh_command,
101+
"-oStrictHostKeyChecking=accept-new",
102+
server,
103+
"git-lfs-authenticate",
104+
path,
105+
"download",
106+
]
107+
try:
108+
res = subprocess.run(cmd, stdout=subprocess.PIPE, timeout=15)
109+
except subprocess.TimeoutExpired:
110+
print(
111+
f"WARNING: ssh timed out when connecting to {server}, ignoring {endpoint.name} endpoint",
112+
file=sys.stderr,
113+
)
114+
continue
115+
if res.returncode != 0:
116+
print(
117+
f"WARNING: ssh failed when connecting to {server}, ignoring {endpoint.name} endpoint",
118+
file=sys.stderr,
119+
)
120+
continue
121+
ssh_resp = json.loads(res.stdout)
122+
endpoint.href = ssh_resp.get("href", endpoint)
123+
endpoint.update_headers(ssh_resp.get("header", {}).items())
124+
url = urlparse(endpoint.href)
125+
# this is how actions/checkout persist credentials
126+
# see https://github.com/actions/checkout/blob/44c2b7a8a4ea60a981eaca3cf939b5f4305c123b/src/git-auth-helper.ts#L56-L63
127+
auth = git("config", f"http.{url.scheme}://{url.netloc}/.extraheader") or ""
128+
endpoint.update_headers(get_env(auth, sep=": "))
129+
if os.environ.get("GITHUB_TOKEN"):
130+
endpoint.headers["Authorization"] = f"token {os.environ['GITHUB_TOKEN']}"
131+
if "Authorization" not in endpoint.headers:
132+
# last chance: use git credentials (possibly backed by a credential helper like the one installed by gh)
133+
# see https://git-scm.com/docs/git-credential
134+
credentials = git(
135+
"credential",
136+
"fill",
137+
check=True,
138+
# drop leading / from url.path
139+
input=f"protocol={url.scheme}\nhost={url.netloc}\npath={url.path[1:]}\n",
140+
)
141+
if credentials is None:
142+
print(
143+
f"WARNING: no authorization method found, ignoring {data.name} endpoint",
144+
file=sys.stderr,
145+
)
146+
continue
147+
credentials = dict(get_env(credentials))
148+
auth = base64.b64encode(
149+
f'{credentials["username"]}:{credentials["password"]}'.encode()
150+
).decode("ascii")
151+
endpoint.headers["Authorization"] = f"Basic {auth}"
152+
yield endpoint
105153

106154

107155
# see https://github.com/git-lfs/git-lfs/blob/310d1b4a7d01e8d9d884447df4635c7a9c7642c2/docs/api/basic-transfers.md
@@ -115,37 +163,44 @@ def get_locations(objects):
115163
for i in indexes:
116164
ret[i] = objects[i]["oid"]
117165
return ret
118-
endpoint = get_endpoint()
119166
data = {
120167
"operation": "download",
121168
"transfers": ["basic"],
122169
"objects": [objects[i] for i in indexes],
123170
"hash_algo": "sha256",
124171
}
125-
req = urllib.request.Request(
126-
f"{endpoint.href}/objects/batch",
127-
headers=endpoint.headers,
128-
data=json.dumps(data).encode("ascii"),
129-
)
130-
with urllib.request.urlopen(req) as resp:
131-
data = json.load(resp)
132-
assert len(data["objects"]) == len(indexes), f"received {len(data)} objects, expected {len(indexes)}"
133-
for i, resp in zip(indexes, data["objects"]):
134-
ret[i] = f'{resp["oid"]} {resp["actions"]["download"]["href"]}'
135-
return ret
172+
for endpoint in get_endpoints():
173+
req = urllib.request.Request(
174+
f"{endpoint.href}/objects/batch",
175+
headers=endpoint.headers,
176+
data=json.dumps(data).encode("ascii"),
177+
)
178+
try:
179+
with urllib.request.urlopen(req) as resp:
180+
data = json.load(resp)
181+
except urllib.request.HTTPError as e:
182+
print(f"WARNING: encountered HTTPError {e}, ignoring endpoint {e.name}")
183+
continue
184+
assert len(data["objects"]) == len(
185+
indexes
186+
), f"received {len(data)} objects, expected {len(indexes)}"
187+
for i, resp in zip(indexes, data["objects"]):
188+
ret[i] = f'{resp["oid"]} {resp["actions"]["download"]["href"]}'
189+
return ret
190+
raise Exception(f"no valid endpoint found")
136191

137192

138193
def get_lfs_object(path):
139-
with open(path, 'rb') as fileobj:
194+
with open(path, "rb") as fileobj:
140195
lfs_header = "version https://git-lfs.github.com/spec".encode()
141196
actual_header = fileobj.read(len(lfs_header))
142197
sha256 = size = None
143198
if lfs_header != actual_header:
144199
return None
145-
data = get_env(fileobj.read().decode('ascii'), sep=' ')
146-
assert data['oid'].startswith('sha256:'), f"unknown oid type: {data['oid']}"
147-
_, _, sha256 = data['oid'].partition(':')
148-
size = int(data['size'])
200+
data = dict(get_env(fileobj.read().decode("ascii"), sep=" "))
201+
assert data["oid"].startswith("sha256:"), f"unknown oid type: {data['oid']}"
202+
_, _, sha256 = data["oid"].partition(":")
203+
size = int(data["size"])
149204
return {"oid": sha256, "size": size}
150205

151206

0 commit comments

Comments
 (0)