Skip to content

Commit ecdf623

Browse files
committed
Bazel: clean up git_lfs_probe.py
1 parent 96d69ca commit ecdf623

File tree

1 file changed

+56
-56
lines changed

1 file changed

+56
-56
lines changed

misc/bazel/internal/git_lfs_probe.py

Lines changed: 56 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -16,74 +16,78 @@
1616
import urllib.request
1717
from urllib.parse import urlparse
1818
import re
19+
import base64
20+
from dataclasses import dataclass
21+
22+
23+
@dataclass
24+
class Endpoint:
25+
href: str
26+
headers: dict[str, str]
27+
28+
def update_headers(self, d: dict[str, str]):
29+
self.headers.update((k.capitalize(), v) for k, v in d.items())
30+
1931

2032
sources = [pathlib.Path(arg).resolve() for arg in sys.argv[1:]]
2133
source_dir = pathlib.Path(os.path.commonpath(src.parent for src in sources))
2234
source_dir = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], cwd=source_dir, text=True).strip()
2335

2436

37+
def get_env(s, sep="="):
38+
ret = {}
39+
for m in re.finditer(fr'(.*?){sep}(.*)', s, re.M):
40+
ret.setdefault(*m.groups())
41+
return ret
42+
43+
44+
def git(*args, **kwargs):
45+
return subprocess.run(("git",) + args, stdout=subprocess.PIPE, text=True, cwd=source_dir, **kwargs).stdout.strip()
46+
47+
2548
def get_endpoint():
26-
lfs_env = subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir)
27-
endpoint = ssh_server = ssh_path = None
28-
endpoint_re = re.compile(r'Endpoint(?: \(\S+\))?=(\S+)')
29-
ssh_re = re.compile(r'\s*SSH=(\S*):(.*)')
30-
credentials_re = re.compile(r'^password=(.*)$', re.M)
31-
for line in lfs_env.splitlines():
32-
m = endpoint_re.match(line)
33-
if m:
34-
if endpoint is None:
35-
endpoint = m[1]
36-
else:
37-
break
38-
m = ssh_re.match(line)
39-
if m:
40-
ssh_server, ssh_path = m.groups()
41-
break
42-
assert endpoint, f"no Endpoint= line found in git lfs env:\n{lfs_env}"
43-
headers = {
49+
lfs_env = get_env(subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir))
50+
endpoint = next(v for k, v in lfs_env.items() if k.startswith('Endpoint'))
51+
endpoint, _, _ = endpoint.partition(' ')
52+
ssh_endpoint = lfs_env.get(" SSH")
53+
endpoint = Endpoint(endpoint, {
4454
"Content-Type": "application/vnd.git-lfs+json",
4555
"Accept": "application/vnd.git-lfs+json",
46-
}
47-
if ssh_server:
56+
})
57+
if ssh_endpoint:
58+
# see https://github.com/git-lfs/git-lfs/blob/main/docs/api/authentication.md
59+
server, _, path = ssh_endpoint.partition(":")
4860
ssh_command = shutil.which(os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh")))
4961
assert ssh_command, "no ssh command found"
50-
with subprocess.Popen([ssh_command, ssh_server, "git-lfs-authenticate", ssh_path, "download"],
51-
stdout=subprocess.PIPE) as ssh:
52-
resp = json.load(ssh.stdout)
53-
assert ssh.wait() == 0, "ssh command failed"
54-
endpoint = resp.get("href", endpoint)
55-
for k, v in resp.get("header", {}).items():
56-
headers[k.capitalize()] = v
57-
url = urlparse(endpoint)
62+
resp = json.loads(subprocess.check_output([ssh_command, server, "git-lfs-authenticate", path, "download"]))
63+
endpoint.href = resp.get("href", endpoint)
64+
endpoint.update_headers(resp.get("header", {}))
65+
url = urlparse(endpoint.href)
5866
# this is how actions/checkout persist credentials
5967
# see https://github.com/actions/checkout/blob/44c2b7a8a4ea60a981eaca3cf939b5f4305c123b/src/git-auth-helper.ts#L56-L63
60-
auth = subprocess.run(["git", "config", f"http.{url.scheme}://{url.netloc}/.extraheader"], text=True,
61-
stdout=subprocess.PIPE, cwd=source_dir).stdout.strip()
62-
for l in auth.splitlines():
63-
k, _, v = l.partition(": ")
64-
headers[k.capitalize()] = v
68+
auth = git("config", f"http.{url.scheme}://{url.netloc}/.extraheader")
69+
endpoint.update_headers(get_env(auth, sep=": "))
6570
if "GITHUB_TOKEN" in os.environ:
66-
headers["Authorization"] = f"token {os.environ['GITHUB_TOKEN']}"
67-
if "Authorization" not in headers:
68-
credentials = subprocess.run(["git", "credential", "fill"], cwd=source_dir, stdout=subprocess.PIPE, text=True,
69-
input=f"protocol={url.scheme}\nhost={url.netloc}\npath={url.path[1:]}\n",
70-
check=True).stdout
71-
m = credentials_re.search(credentials)
72-
if m:
73-
headers["Authorization"] = f"token {m[1]}"
74-
else:
75-
print(f"WARNING: no auth credentials found for {endpoint}")
76-
return endpoint, headers
71+
endpoint.headers["Authorization"] = f"token {os.environ['GITHUB_TOKEN']}"
72+
if "Authorization" not in endpoint.headers:
73+
# last chance: use git credentials (possibly backed by a credential helper like the one installed by gh)
74+
# see https://git-scm.com/docs/git-credential
75+
credentials = get_env(git("credential", "fill", check=True,
76+
# drop leading / from url.path
77+
input=f"protocol={url.scheme}\nhost={url.netloc}\npath={url.path[1:]}\n"))
78+
auth = base64.b64encode(f'{credentials["username"]}:{credentials["password"]}'.encode()).decode('ascii')
79+
endpoint.headers["Authorization"] = f"Basic {auth}"
80+
return endpoint
7781

7882

7983
# see https://github.com/git-lfs/git-lfs/blob/310d1b4a7d01e8d9d884447df4635c7a9c7642c2/docs/api/basic-transfers.md
8084
def get_locations(objects):
81-
href, headers = get_endpoint()
85+
endpoint = get_endpoint()
8286
indexes = [i for i, o in enumerate(objects) if o]
8387
ret = ["local" for _ in objects]
8488
req = urllib.request.Request(
85-
f"{href}/objects/batch",
86-
headers=headers,
89+
f"{endpoint.href}/objects/batch",
90+
headers=endpoint.headers,
8791
data=json.dumps({
8892
"operation": "download",
8993
"transfers": ["basic"],
@@ -93,7 +97,7 @@ def get_locations(objects):
9397
)
9498
with urllib.request.urlopen(req) as resp:
9599
data = json.load(resp)
96-
assert len(data["objects"]) == len(indexes), data
100+
assert len(data["objects"]) == len(indexes), f"received {len(data)} objects, expected {len(indexes)}"
97101
for i, resp in zip(indexes, data["objects"]):
98102
ret[i] = f'{resp["oid"]} {resp["actions"]["download"]["href"]}'
99103
return ret
@@ -106,14 +110,10 @@ def get_lfs_object(path):
106110
sha256 = size = None
107111
if lfs_header != actual_header:
108112
return None
109-
for line in fileobj:
110-
line = line.decode('ascii').strip()
111-
if line.startswith("oid sha256:"):
112-
sha256 = line[len("oid sha256:"):]
113-
elif line.startswith("size "):
114-
size = int(line[len("size "):])
115-
if not (sha256 and line):
116-
raise Exception("malformed pointer file")
113+
data = get_env(fileobj.read().decode('ascii'), sep=' ')
114+
assert data['oid'].startswith('sha256:'), f"unknown oid type: {data['oid']}"
115+
_, _, sha256 = data['oid'].partition(':')
116+
size = int(data['size'])
117117
return {"oid": sha256, "size": size}
118118

119119

0 commit comments

Comments
 (0)