Commit 170e223

Bazel: allow LFS rules to use cached downloads without internet
Even when the cache was prefilled, LFS rules would still query LFS URLs. The new strategy is to first try to fetch the files from the repository cache (possible by passing an empty URL list and `allow_fail` to `repository_ctx.download`), and to run the LFS protocol only if that fails. This is made possible by adding a `--hash-only` flag to `git_lfs_probe.py`. It is also an optimization: no unneeded network access (including the slightly slow SSH call) is performed when the repository cache is warm.
1 parent a50584c commit 170e223
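The cache-first pattern described in the commit message can be sketched as a minimal Starlark snippet (hypothetical helper names, not the exact code of this commit; the real implementation is in misc/bazel/lfs.bzl below). Passing an empty URL list together with the sha256 to `repository_ctx.download` means the call can only be satisfied from Bazel's repository cache, and `allow_fail = True` turns a cache miss into a failed result instead of a hard error:

    # Minimal sketch (hypothetical names) of the cache-first fetch.
    def _fetch_cache_first(repository_ctx, basename, sha256, lfs_url_for):
        # With no URLs, download() can only succeed from the repository cache
        # entry keyed by sha256; allow_fail reports a miss via res.success.
        res = repository_ctx.download([], basename, sha256 = sha256, allow_fail = True)
        if not res.success:
            # Cache miss: run the LFS protocol (hypothetical helper) to obtain
            # a transient URL, then download it, which also warms the cache.
            url = lfs_url_for(basename)
            repository_ctx.download(url, basename, sha256 = sha256)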

File tree

2 files changed (+50, -29 lines)


misc/bazel/internal/git_lfs_probe.py

Lines changed: 16 additions & 3 deletions
@@ -2,9 +2,10 @@
 
 """
 Probe lfs files.
-For each source file provided as output, this will print:
+For each source file provided as input, this will print:
 * "local", if the source file is not an LFS pointer
 * the sha256 hash, a space character and a transient download link obtained via the LFS protocol otherwise
+If --hash-only is provided, the transient URL will not be fetched and printed
 """
 
 import sys
@@ -19,6 +20,13 @@
 import base64
 from dataclasses import dataclass
 from typing import Dict
+import argparse
+
+def options():
+    p = argparse.ArgumentParser(description=__doc__)
+    p.add_argument("--hash-only", action="store_true")
+    p.add_argument("sources", type=pathlib.Path, nargs="+")
+    return p.parse_args()
 
 
 @dataclass
@@ -30,7 +38,8 @@ def update_headers(self, d: Dict[str, str]):
         self.headers.update((k.capitalize(), v) for k, v in d.items())
 
 
-sources = [pathlib.Path(arg).resolve() for arg in sys.argv[1:]]
+opts = options()
+sources = [p.resolve() for p in opts.sources]
 source_dir = pathlib.Path(os.path.commonpath(src.parent for src in sources))
 source_dir = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], cwd=source_dir, text=True).strip()
 
@@ -84,11 +93,15 @@ def get_endpoint():
 # see https://github.com/git-lfs/git-lfs/blob/310d1b4a7d01e8d9d884447df4635c7a9c7642c2/docs/api/basic-transfers.md
 def get_locations(objects):
     ret = ["local" for _ in objects]
-    endpoint = get_endpoint()
     indexes = [i for i, o in enumerate(objects) if o]
     if not indexes:
         # all objects are local, do not send an empty request as that would be an error
         return ret
+    if opts.hash_only:
+        for i in indexes:
+            ret[i] = objects[i]["oid"]
+        return ret
+    endpoint = get_endpoint()
     data = {
         "operation": "download",
         "transfers": ["basic"],

misc/bazel/lfs.bzl

Lines changed: 34 additions & 26 deletions
@@ -1,36 +1,44 @@
 def lfs_smudge(repository_ctx, srcs, extract = False, stripPrefix = None):
-    for src in srcs:
-        repository_ctx.watch(src)
-    script = Label("//misc/bazel/internal:git_lfs_probe.py")
     python = repository_ctx.which("python3") or repository_ctx.which("python")
     if not python:
         fail("Neither python3 nor python executables found")
-    repository_ctx.report_progress("querying LFS url(s) for: %s" % ", ".join([src.basename for src in srcs]))
-    res = repository_ctx.execute([python, script] + srcs, quiet = True)
-    if res.return_code != 0:
-        fail("git LFS probing failed while instantiating @%s:\n%s" % (repository_ctx.name, res.stderr))
-    promises = []
-    for src, loc in zip(srcs, res.stdout.splitlines()):
-        if loc == "local":
-            if extract:
-                repository_ctx.report_progress("extracting local %s" % src.basename)
-                repository_ctx.extract(src, stripPrefix = stripPrefix)
-            else:
-                repository_ctx.report_progress("symlinking local %s" % src.basename)
-                repository_ctx.symlink(src, src.basename)
+    script = Label("//misc/bazel/internal:git_lfs_probe.py")
+
+    def probe(srcs, hash_only = False):
+        repository_ctx.report_progress("querying LFS url(s) for: %s" % ", ".join([src.basename for src in srcs]))
+        cmd = [python, script]
+        if hash_only:
+            cmd.append("--hash-only")
+        cmd.extend(srcs)
+        res = repository_ctx.execute(cmd, quiet = True)
+        if res.return_code != 0:
+            fail("git LFS probing failed while instantiating @%s:\n%s" % (repository_ctx.name, res.stderr))
+        return res.stdout.splitlines()
+
+    for src in srcs:
+        repository_ctx.watch(src)
+    infos = probe(srcs, hash_only = True)
+    remote = []
+    for src, info in zip(srcs, infos):
+        if info == "local":
+            repository_ctx.report_progress("symlinking local %s" % src.basename)
+            repository_ctx.symlink(src, src.basename)
         else:
-            sha256, _, url = loc.partition(" ")
-            if extract:
-                # we can't use skylib's `paths.split_extension`, as that only gets the last extension, so `.tar.gz`
-                # or similar wouldn't work
-                # it doesn't matter if file is something like some.name.zip and possible_extension == "name.zip",
-                # download_and_extract will just append ".name.zip" its internal temporary name, so extraction works
-                possible_extension = ".".join(src.basename.rsplit(".", 2)[-2:])
-                repository_ctx.report_progress("downloading and extracting remote %s" % src.basename)
-                repository_ctx.download_and_extract(url, sha256 = sha256, stripPrefix = stripPrefix, type = possible_extension)
-            else:
+            repository_ctx.report_progress("trying cache for remote %s" % src.basename)
+            res = repository_ctx.download([], src.basename, sha256 = info, allow_fail = True)
+            if not res.success:
+                remote.append(src)
+    if remote:
+        infos = probe(remote)
+        for src, info in zip(remote, infos):
+            sha256, _, url = info.partition(" ")
             repository_ctx.report_progress("downloading remote %s" % src.basename)
             repository_ctx.download(url, src.basename, sha256 = sha256)
+    if extract:
+        for src in srcs:
+            repository_ctx.report_progress("extracting %s" % src.basename)
+            repository_ctx.extract(src.basename, stripPrefix = stripPrefix)
+            repository_ctx.delete(src.basename)
 
 def _download_and_extract_lfs(repository_ctx):
     attr = repository_ctx.attr
