Skip to content

Commit 9789438

Browse files
[iree.build] Make the fetch_http action more robust. (#19330)
* Downloads to a staging file and then atomically renames into place, avoiding potential for partial downloads. * Reports completion percent as part of the console updates. * Persists metadata for the source URL and will refetch if changed. * Fixes an error handling test for the onnx mnist_builder that missed the prior update. More sophistication is possible but this brings it up to min-viable from a usability perspective. Signed-off-by: Stella Laurenzo <[email protected]>
1 parent d182e57 commit 9789438

File tree

5 files changed

+201
-6
lines changed

5 files changed

+201
-6
lines changed

compiler/bindings/python/iree/build/executor.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import concurrent.futures
1010
import enum
11+
import json
1112
import math
1213
import multiprocessing
1314
import os
@@ -128,6 +129,7 @@ def __init__(self, output_dir: Path, stderr: IO, reporter: ProgressReporter):
128129
self.failed_deps: set["BuildDependency"] = set()
129130
self.stderr = stderr
130131
self.reporter = reporter
132+
self.metadata_lock = threading.RLock()
131133
BuildContext("", self)
132134

133135
def check_path_not_exists(self, path: str, for_entity):
@@ -160,6 +162,7 @@ def get_file(self, path: str) -> "BuildFile":
160162
return existing
161163

162164
def write_status(self, message: str):
165+
self.reporter.reset_display()
163166
print(message, file=self.stderr)
164167

165168
def get_root(self, namespace: FileNamespace) -> Path:
@@ -294,6 +297,9 @@ def finish(self):
294297
self.future.set_result(self)
295298

296299

300+
BuildFileMetadata = dict[str, str | int | bool | float]
301+
302+
297303
class BuildFile(BuildDependency):
298304
"""Generated file in the build tree."""
299305

@@ -322,6 +328,35 @@ def get_fs_path(self) -> Path:
322328
path.parent.mkdir(parents=True, exist_ok=True)
323329
return path
324330

331+
def access_metadata(
332+
self,
333+
mutation_callback: Callable[[BuildFileMetadata], bool] | None = None,
334+
) -> BuildFileMetadata:
335+
"""Accesses persistent metadata about the build file.
336+
337+
This is intended for the storage of small amounts of metadata relevant to the
338+
build system for performing up-to-date checks and the like.
339+
340+
If a `mutation_callback=` is provided, then any modifications it makes will be
341+
persisted prior to returning. Using a callback in this fashion holds a lock
342+
and avoids data races. If the callback returns True, it is persisted.
343+
"""
344+
with self.executor.metadata_lock:
345+
metadata = _load_metadata(self.executor)
346+
path_metadata = metadata.get("paths")
347+
if path_metadata is None:
348+
path_metadata = {}
349+
metadata["paths"] = path_metadata
350+
file_key = f"{self.namespace}/{self.path}"
351+
file_metadata = path_metadata.get(file_key)
352+
if file_metadata is None:
353+
file_metadata = {}
354+
path_metadata[file_key] = file_metadata
355+
if mutation_callback:
356+
if mutation_callback(file_metadata):
357+
_save_metadata(self.executor, metadata)
358+
return file_metadata
359+
325360
def __repr__(self):
326361
return f"BuildFile[{self.namespace}]({self.path})"
327362

@@ -658,3 +693,20 @@ def invoke():
658693

659694
# Type aliases.
660695
BuildFileLike = BuildFile | str
696+
697+
# Private utilities.
698+
_METADATA_FILENAME = ".metadata.json"
699+
700+
701+
def _load_metadata(executor: Executor) -> dict:
702+
path = executor.output_dir / _METADATA_FILENAME
703+
if not path.exists():
704+
return {}
705+
with open(path, "rb") as f:
706+
return json.load(f)
707+
708+
709+
def _save_metadata(executor: Executor, metadata: dict):
710+
path = executor.output_dir / _METADATA_FILENAME
711+
with open(path, "wt") as f:
712+
json.dump(metadata, f, sort_keys=True, indent=2)

compiler/bindings/python/iree/build/net_actions.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import urllib.error
88
import urllib.request
99

10-
from iree.build.executor import BuildAction, BuildContext, BuildFile
10+
from iree.build.executor import BuildAction, BuildContext, BuildFile, BuildFileMetadata
1111

1212
__all__ = [
1313
"fetch_http",
@@ -29,11 +29,49 @@ def __init__(self, url: str, output_file: BuildFile, **kwargs):
2929
super().__init__(**kwargs)
3030
self.url = url
3131
self.output_file = output_file
32+
self.original_desc = self.desc
3233

3334
def _invoke(self):
35+
# Determine whether metadata indicates that fetch is needed.
3436
path = self.output_file.get_fs_path()
37+
needs_fetch = False
38+
existing_metadata = self.output_file.access_metadata()
39+
existing_url = existing_metadata.get("fetch_http.url")
40+
if existing_url != self.url:
41+
needs_fetch = True
42+
43+
# Always fetch if empty or absent.
44+
if not path.exists() or path.stat().st_size == 0:
45+
needs_fetch = True
46+
47+
# Bail if already obtained.
48+
if not needs_fetch:
49+
return
50+
51+
# Download to a staging file.
52+
stage_path = path.with_name(f".{path.name}.download")
3553
self.executor.write_status(f"Fetching URL: {self.url} -> {path}")
54+
55+
def reporthook(received_blocks: int, block_size: int, total_size: int):
56+
received_size = received_blocks * block_size
57+
if total_size == 0:
58+
self.desc = f"{self.original_desc} ({received_size} bytes received)"
59+
else:
60+
complete_percent = round(100 * received_size / total_size)
61+
self.desc = f"{self.original_desc} ({complete_percent}% complete)"
62+
3663
try:
37-
urllib.request.urlretrieve(self.url, str(path))
64+
urllib.request.urlretrieve(self.url, str(stage_path), reporthook=reporthook)
3865
except urllib.error.HTTPError as e:
3966
raise IOError(f"Failed to fetch URL '{self.url}': {e}") from None
67+
finally:
68+
self.desc = self.original_desc
69+
70+
# Commit the download.
71+
def commit(metadata: BuildFileMetadata) -> bool:
72+
metadata["fetch_http.url"] = self.url
73+
path.unlink(missing_ok=True)
74+
stage_path.rename(path)
75+
return True
76+
77+
self.output_file.access_metadata(commit)

compiler/bindings/python/test/build_api/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,10 @@ iree_py_test(
2020
SRCS
2121
"basic_test.py"
2222
)
23+
24+
iree_py_test(
25+
NAME
26+
net_test
27+
SRCS
28+
"net_test.py"
29+
)

compiler/bindings/python/test/build_api/mnist_builder_test.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,7 @@ def testActionCLArg(self):
9090
mod = load_build_module(THIS_DIR / "mnist_builder.py")
9191
out_file = io.StringIO()
9292
err_file = io.StringIO()
93-
with self.assertRaisesRegex(
94-
IOError,
95-
re.escape("Failed to fetch URL 'https://github.com/iree-org/doesnotexist'"),
96-
):
93+
with self.assertRaises(SystemExit):
9794
iree_build_main(
9895
mod,
9996
args=[
@@ -104,6 +101,7 @@ def testActionCLArg(self):
104101
stdout=out_file,
105102
stderr=err_file,
106103
)
104+
self.assertIn("ERROR:", err_file.getvalue())
107105

108106
def testBuildNonDefaultSubTarget(self):
109107
mod = load_build_module(THIS_DIR / "mnist_builder.py")
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2024 The IREE Authors
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
import io
8+
import os
9+
from pathlib import Path
10+
import tempfile
11+
import unittest
12+
13+
from iree.build import *
14+
from iree.build.executor import BuildContext
15+
from iree.build.test_actions import ExecuteOutOfProcessThunkAction
16+
17+
18+
TEST_URL = None
19+
TEST_URL_1 = "https://huggingface.co/google-bert/bert-base-cased/resolve/cd5ef92a9fb2f889e972770a36d4ed042daf221e/tokenizer.json"
20+
TEST_URL_2 = "https://huggingface.co/google-bert/bert-base-cased/resolve/cd5ef92a9fb2f889e972770a36d4ed042daf221e/tokenizer_config.json"
21+
22+
23+
@entrypoint
24+
def tokenizer_via_http():
25+
return fetch_http(
26+
name="tokenizer.json",
27+
url=TEST_URL,
28+
)
29+
30+
31+
class BasicTest(unittest.TestCase):
32+
def setUp(self):
33+
self._temp_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
34+
self._temp_dir.__enter__()
35+
self.output_path = Path(self._temp_dir.name)
36+
37+
def tearDown(self) -> None:
38+
self._temp_dir.__exit__(None, None, None)
39+
40+
def test_fetch_http(self):
41+
# This just does a sanity check that rich console mode does not crash. Actual
42+
# behavior can really only be completely verified visually.
43+
out = None
44+
err = None
45+
global TEST_URL
46+
path = self.output_path / "genfiles" / "tokenizer_via_http" / "tokenizer.json"
47+
48+
def run():
49+
nonlocal out
50+
nonlocal err
51+
try:
52+
out_io = io.StringIO()
53+
err_io = io.StringIO()
54+
iree_build_main(
55+
args=[
56+
"tokenizer_via_http",
57+
"--output-dir",
58+
str(self.output_path),
59+
"--test-force-console",
60+
],
61+
stderr=err_io,
62+
stdout=out_io,
63+
)
64+
finally:
65+
out = out_io.getvalue()
66+
err = err_io.getvalue()
67+
print(f"::test_fetch_http err: {err!r}")
68+
print(f"::test_fetch_http out: {out!r}")
69+
70+
def assertExists():
71+
self.assertTrue(path.exists(), msg=f"Path {path} exists")
72+
73+
# First run should fetch.
74+
TEST_URL = TEST_URL_1
75+
run()
76+
self.assertIn("Fetching URL: https://", err)
77+
assertExists()
78+
79+
# Second run should not fetch.
80+
TEST_URL = TEST_URL_1
81+
run()
82+
self.assertNotIn("Fetching URL: https://", err)
83+
assertExists()
84+
85+
# Fetching a different URL should download again.
86+
TEST_URL = TEST_URL_2
87+
run()
88+
self.assertIn("Fetching URL: https://", err)
89+
assertExists()
90+
91+
# Removing the file should fetch again.
92+
TEST_URL = TEST_URL_2
93+
path.unlink()
94+
run()
95+
self.assertIn("Fetching URL: https://", err)
96+
assertExists()
97+
98+
99+
if __name__ == "__main__":
100+
unittest.main()

0 commit comments

Comments
 (0)