Skip to content

Commit ce54e00

Browse files
authored
Allow user to provide scalac source jar for "ast-plus" dependency tracking feature (#1493)
* Allow scalac srcjar to scala_config * Let user handle jar selection * Add test cases for error handling and verify jar is used
1 parent 4e903af commit ce54e00

File tree

6 files changed

+234
-16
lines changed

6 files changed

+234
-16
lines changed

dt_patches/dt_patch_test.sh

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ run_test_local() {
2424

2525
run_in_test_repo() {
2626
local test_command=$1
27+
local test_repo=$2
2728

28-
cd "${dir}"/test_dt_patches
29+
cd "${dir}/${test_repo}" || exit 1
2930
${test_command}
3031
RESPONSE_CODE=$?
3132
cd ../..
@@ -36,9 +37,33 @@ run_in_test_repo() {
3637
test_compiler_patch() {
3738
local SCALA_VERSION="$1"
3839

39-
run_in_test_repo "bazel build //... --repo_env=SCALA_VERSION=${SCALA_VERSION} //..."
40+
run_in_test_repo "bazel build //... --repo_env=SCALA_VERSION=${SCALA_VERSION} //..." "test_dt_patches"
4041
}
4142

43+
test_compiler_srcjar() {
44+
set -o pipefail
45+
local SCALA_VERSION="$1"
46+
47+
run_in_test_repo "bazel build //... --repo_env=SCALA_VERSION=${SCALA_VERSION} //..." "test_dt_patches_user_srcjar" 2>&1 | (! grep "canonical reproducible")
48+
}
49+
50+
test_compiler_srcjar_nonhermetic() {
51+
set -o pipefail
52+
local SCALA_VERSION="$1"
53+
54+
run_in_test_repo "bazel build //... --repo_env=SCALA_VERSION=${SCALA_VERSION} //..." "test_dt_patches_user_srcjar" 2>&1 | grep "canonical reproducible"
55+
}
56+
57+
test_compiler_srcjar_error() {
58+
local SCALA_VERSION="$1"
59+
local EXPECTED_ERROR="scala_compiler_srcjar invalid"
60+
61+
run_in_test_repo "bazel build //... --repo_env=SCALA_VERSION=${SCALA_VERSION} //..." "test_dt_patches_user_srcjar" 2>&1 | grep "$EXPECTED_ERROR"
62+
}
63+
64+
run_test_local test_compiler_patch 2.12.1
65+
66+
4267
#run_test_local test_compiler_patch 2.11.0
4368
#run_test_local test_compiler_patch 2.11.1
4469
#run_test_local test_compiler_patch 2.11.2
@@ -80,3 +105,14 @@ run_test_local test_compiler_patch 2.13.5
80105
run_test_local test_compiler_patch 2.13.6
81106
run_test_local test_compiler_patch 2.13.7
82107
run_test_local test_compiler_patch 2.13.8
108+
109+
run_test_local test_compiler_srcjar_error 2.12.11
110+
run_test_local test_compiler_srcjar_error 2.12.12
111+
run_test_local test_compiler_srcjar_error 2.12.13
112+
# These tests are semi-stateful, if two tests are run sequentially with the
113+
# same Scala version, the DEBUG message about a canonical reproducible form
114+
# that we grep for will only be outputted the first time (on Bazel >= 6).
115+
run_test_local test_compiler_srcjar 2.12.14
116+
run_test_local test_compiler_srcjar 2.12.15
117+
run_test_local test_compiler_srcjar 2.12.16
118+
run_test_local test_compiler_srcjar_nonhermetic 2.12.17
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
load(
2+
"@io_bazel_rules_scala//scala:scala.bzl",
3+
"setup_scala_toolchain",
4+
)
5+
6+
SCALA_LIBS = [
7+
"@scala_library",
8+
"@scala_reflect",
9+
]
10+
11+
setup_scala_toolchain(
12+
name = "dt_scala_toolchain",
13+
scala_compile_classpath = ["@scala_compiler"] + SCALA_LIBS,
14+
scala_library_classpath = SCALA_LIBS,
15+
scala_macro_classpath = SCALA_LIBS,
16+
)
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
workspace(name = "test_dt_patches")
2+
3+
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_jar")
4+
5+
http_archive(
6+
name = "bazel_skylib",
7+
sha256 = "b8a1527901774180afc798aeb28c4634bdccf19c4d98e7bdd1ce79d1fe9aaad7",
8+
urls = [
9+
"https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.4.1/bazel-skylib-1.4.1.tar.gz",
10+
"https://github.com/bazelbuild/bazel-skylib/releases/download/1.4.1/bazel-skylib-1.4.1.tar.gz",
11+
],
12+
)
13+
14+
local_repository(
15+
name = "io_bazel_rules_scala",
16+
path = "../..",
17+
)
18+
19+
load("@io_bazel_rules_scala//:scala_config.bzl", "scala_config")
20+
21+
scala_config(enable_compiler_dependency_tracking = True)
22+
23+
load(
24+
"@io_bazel_rules_scala//scala:scala.bzl",
25+
"rules_scala_setup",
26+
"rules_scala_toolchain_deps_repositories",
27+
)
28+
load(
29+
"@io_bazel_rules_scala//scala:scala_cross_version.bzl",
30+
"default_maven_server_urls",
31+
)
32+
load(
33+
"@io_bazel_rules_scala//scala:scala_maven_import_external.bzl",
34+
"scala_maven_import_external",
35+
)
36+
load("@io_bazel_rules_scala_config//:config.bzl", "SCALA_VERSION")
37+
38+
http_jar(
39+
name = "scala_compiler_srcjar",
40+
sha256 = "95c217cc87ee846b39990e0a9c273824a384dffbac57df84d466f866df4a91ea",
41+
url = "https://repo1.maven.org/maven2/org/scala-lang/scala-compiler/2.12.16/scala-compiler-2.12.16-sources.jar",
42+
)
43+
44+
scala_maven_import_external(
45+
name = "scala_library",
46+
artifact = "org.scala-lang:scala-library:%s" % SCALA_VERSION,
47+
licenses = ["notice"],
48+
server_urls = default_maven_server_urls(),
49+
)
50+
51+
scala_maven_import_external(
52+
name = "scala_compiler",
53+
artifact = "org.scala-lang:scala-compiler:%s" % SCALA_VERSION,
54+
licenses = ["notice"],
55+
server_urls = default_maven_server_urls(),
56+
)
57+
58+
scala_maven_import_external(
59+
name = "scala_reflect",
60+
artifact = "org.scala-lang:scala-reflect:%s" % SCALA_VERSION,
61+
licenses = ["notice"],
62+
server_urls = default_maven_server_urls(),
63+
)
64+
65+
srcjars_by_version = {
66+
# Invalid
67+
"2.12.11": [],
68+
# Invalid
69+
"2.12.12": {
70+
"lable": "foo",
71+
},
72+
# Invalid
73+
"2.12.13": {
74+
"url": "https://repo1.maven.org/maven2/org/scala-lang/scala-compiler/2.12.13/scala-compiler-2.12.13-sources.jar",
75+
"label": "foo",
76+
},
77+
"2.12.14": {
78+
"urls": ["https://repo1.maven.org/maven2/org/scala-lang/scala-compiler/2.12.14/scala-compiler-2.12.14-sources.jar"],
79+
"integrity": "sha384-yKJTudaHM2dA+VM//elLxhEfOmyCYRHzbLlQcf5jlrR+G5FEW+fBW/b794mQLMOX",
80+
},
81+
"2.12.15": {
82+
"url": "https://repo1.maven.org/maven2/org/scala-lang/scala-compiler/2.12.15/scala-compiler-2.12.15-sources.jar",
83+
"sha256": "65f783f1fbef7de661224f607ac07ca03c5d19acfdb7f2234ff8def1e79b5cd8",
84+
},
85+
"2.12.16": {
86+
"label": "@scala_compiler_srcjar//jar:downloaded.jar",
87+
},
88+
"2.12.17": {
89+
"url": "https://repo1.maven.org/maven2/org/scala-lang/scala-compiler/2.12.17/scala-compiler-2.12.17-sources.jar?foo",
90+
},
91+
}
92+
93+
rules_scala_setup(scala_compiler_srcjar = srcjars_by_version[SCALA_VERSION])
94+
95+
rules_scala_toolchain_deps_repositories(
96+
fetch_sources = True,
97+
validate_scala_version = False,
98+
)
99+
100+
register_toolchains(":dt_scala_toolchain")
101+
102+
load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_proto_toolchains")
103+
104+
rules_proto_dependencies()
105+
106+
rules_proto_toolchains()
107+
108+
load("@io_bazel_rules_scala//scala:toolchains.bzl", "scala_register_toolchains")
109+
110+
scala_register_toolchains()
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
load("@io_bazel_rules_scala//scala:scala.bzl", "scala_library")
2+
3+
scala_library(
4+
name = "dummy",
5+
srcs = ["Dummy.scala"],
6+
visibility = ["//visibility:public"],
7+
)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package dummy
2+
3+
class Dummy

scala/private/macros/scala_repositories.bzl

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,33 @@ load(
1111
"SCALA_VERSION",
1212
)
1313

14-
def dt_patched_compiler_setup():
14+
def _dt_patched_compiler_impl(rctx):
15+
# Need to give the file a .zip extension so rctx.extract knows what type of archive it is
16+
rctx.symlink(rctx.attr.srcjar, "file.zip")
17+
rctx.extract(archive = "file.zip")
18+
rctx.patch(rctx.attr.patch)
19+
rctx.file("BUILD", content = rctx.attr.build_file_content)
20+
21+
dt_patched_compiler = repository_rule(
22+
attrs = {
23+
"patch": attr.label(),
24+
"srcjar": attr.label(),
25+
"build_file_content": attr.string(),
26+
},
27+
implementation = _dt_patched_compiler_impl,
28+
)
29+
30+
def _validate_scalac_srcjar(srcjar):
31+
if type(srcjar) != "dict":
32+
return False
33+
oneof = ["url", "urls", "label"]
34+
count = 0
35+
for key in oneof:
36+
if key in srcjar:
37+
count += 1
38+
return count == 1
39+
40+
def dt_patched_compiler_setup(scala_compiler_srcjar = None):
1541
patch = "@io_bazel_rules_scala//dt_patches:dt_compiler_%s.patch" % SCALA_MAJOR_VERSION
1642

1743
minor_version = int(SCALA_MINOR_VERSION)
@@ -22,20 +48,40 @@ def dt_patched_compiler_setup():
2248
elif minor_version <= 11:
2349
patch = "@io_bazel_rules_scala//dt_patches:dt_compiler_%s.8.patch" % SCALA_MAJOR_VERSION
2450

25-
http_archive(
26-
name = "scala_compiler_source",
27-
build_file_content = "\n".join([
28-
"package(default_visibility = [\"//visibility:public\"])",
29-
"filegroup(",
30-
" name = \"src\",",
31-
" srcs=[\"scala/tools/nsc/symtab/SymbolLoaders.scala\"],",
32-
")",
33-
]),
34-
patches = [patch],
35-
url = "https://repo1.maven.org/maven2/org/scala-lang/scala-compiler/%s/scala-compiler-%s-sources.jar" % (SCALA_VERSION, SCALA_VERSION),
51+
build_file_content = "\n".join([
52+
"package(default_visibility = [\"//visibility:public\"])",
53+
"filegroup(",
54+
" name = \"src\",",
55+
" srcs=[\"scala/tools/nsc/symtab/SymbolLoaders.scala\"],",
56+
")",
57+
])
58+
default_scalac_srcjar = {
59+
"url": "https://repo1.maven.org/maven2/org/scala-lang/scala-compiler/%s/scala-compiler-%s-sources.jar" % (SCALA_VERSION, SCALA_VERSION),
60+
}
61+
srcjar = scala_compiler_srcjar if scala_compiler_srcjar != None else default_scalac_srcjar
62+
_validate_scalac_srcjar(srcjar) or fail(
63+
("scala_compiler_srcjar invalid, must be a dict with exactly one of \"label\", \"url\"" +
64+
" or \"urls\" keys, got: ") + repr(srcjar),
3665
)
66+
if "label" in srcjar:
67+
dt_patched_compiler(
68+
name = "scala_compiler_source",
69+
build_file_content = build_file_content,
70+
patch = patch,
71+
srcjar = srcjar["label"],
72+
)
73+
else:
74+
http_archive(
75+
name = "scala_compiler_source",
76+
build_file_content = build_file_content,
77+
patches = [patch],
78+
url = srcjar.get("url"),
79+
urls = srcjar.get("urls"),
80+
sha256 = srcjar.get("sha256"),
81+
integrity = srcjar.get("integrity"),
82+
)
3783

38-
def rules_scala_setup():
84+
def rules_scala_setup(scala_compiler_srcjar = None):
3985
if not native.existing_rule("bazel_skylib"):
4086
http_archive(
4187
name = "bazel_skylib",
@@ -74,7 +120,7 @@ def rules_scala_setup():
74120
],
75121
)
76122

77-
dt_patched_compiler_setup()
123+
dt_patched_compiler_setup(scala_compiler_srcjar)
78124

79125
ARTIFACT_IDS = [
80126
"io_bazel_rules_scala_scala_library",

0 commit comments

Comments
 (0)