Skip to content

Commit 5816f08

Browse files
finn-ball (Finn Ball)
authored and committed
Harden multi-version CUDA toolchain selection on top of rebased #422
- Make redist component repo names unique per redist tag to avoid collisions when multiple `cuda.redist_json` entries resolve to the same version
- Deduplicate generated version selectors and namespace version config_settings per component to prevent `config_setting` conflicts
- Add platform-aware redist selection in repo rules and align deliverable tool labels/aliases (`nvcc`, `nvlink`, `bin2c`, `fatbinary`, `ptxas`, `cicc`)
- Improve clang CUDA path handling for deliverable layouts to avoid invalid `--cuda-path` resolution
- Extend dummy/tool alias coverage and templates for missing components/files
- Add multi-version integration test coverage and skip linux-redist tests on Windows runners; guard CI redist override to NVIDIA-source cases
- Fix linting errors
1 parent b1b546d commit 5816f08

File tree

20 files changed

+264
-50
lines changed

20 files changed

+264
-50
lines changed

.github/workflows/build-tests.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,9 @@ jobs:
9898

9999
# Override cuda redist_json rule version for CI
100100
- run: echo "CUDA_REDIST_VERSION_OVERRIDE=${{ matrix.cases.cuda-version }}" >> $GITHUB_ENV
101-
if: ${{ !startsWith(matrix.cases.os, 'windows') }}
101+
if: ${{ !startsWith(matrix.cases.os, 'windows') && matrix.cases.source == 'nvidia' }}
102102
- run: echo "CUDA_REDIST_VERSION_OVERRIDE=${{ matrix.cases.cuda-version }}" >> $env:GITHUB_ENV
103-
if: ${{ startsWith(matrix.cases.os, 'windows') }}
103+
if: ${{ startsWith(matrix.cases.os, 'windows') && matrix.cases.source == 'nvidia' }}
104104

105105
# Use Bazel with version specified in .bazelversion
106106
- run: echo "USE_BAZEL_VERSION=$(cat .bazelversion)" >> $GITHUB_ENV

README.md

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,46 @@ use_repo(cuda, "cuda")
3232

3333
`rules_cc` provides the C++ toolchain dependency for `rules_cuda`; in Bzlmod, the compatibility repository is handled by `rules_cc` itself.
3434

35+
#### Multi-version hermetic toolchains (Bzlmod)
36+
37+
To select CUDA versions and platforms at build time, define multiple `cuda.redist_json` entries and a single `cuda.toolkit`.
38+
Then set the `@rules_cuda//cuda:version`, `@rules_cuda//cuda:exec_platform`, and `@rules_cuda//cuda:target_platform` flags to choose which declared CUDA version and which execution/target platforms a given build uses.
39+
40+
```starlark
41+
cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
42+
43+
cuda.redist_json(
44+
name = "cuda_13_0_2",
45+
version = "13.0.2",
46+
platforms = [
47+
"linux-x86_64",
48+
"linux-sbsa",
49+
],
50+
)
51+
cuda.redist_json(
52+
name = "cuda_13_0_0",
53+
version = "13.0.0",
54+
platforms = [
55+
"linux-x86_64",
56+
"linux-sbsa",
57+
],
58+
)
59+
60+
cuda.toolkit(name = "cuda")
61+
use_repo(cuda, "cuda")
62+
```
63+
64+
Example `.bazelrc` entries:
65+
66+
```
67+
build --@rules_cuda//cuda:exec_platform=linux-x86_64
68+
build --@rules_cuda//cuda:target_platform=linux-x86_64
69+
build --@rules_cuda//cuda:version=13.0.0
70+
```
71+
72+
Note: In Bzlmod, `platforms` is required for `cuda.redist_json` because module extensions don't have access to host OS/arch
73+
information, so the platforms must be declared explicitly.
74+
3575
<details>
3676
<summary>Traditional WORKSPACE approach</summary>
3777

cuda/BUILD.bazel

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ string_flag(
5454
config_setting(
5555
name = "target_platform_is_{}".format(platform.replace("-", "_")),
5656
flag_values = {":target_platform": platform},
57-
) for platform in SUPPORTED_PLATFORMS
57+
)
58+
for platform in SUPPORTED_PLATFORMS
5859
]
5960

6061
string_flag(
@@ -67,7 +68,8 @@ string_flag(
6768
config_setting(
6869
name = "exec_platform_is_{}".format(platform.replace("-", "_")),
6970
flag_values = {":exec_platform": platform},
70-
) for platform in SUPPORTED_PLATFORMS
71+
)
72+
for platform in SUPPORTED_PLATFORMS
7173
]
7274

7375
# Command line flag to specify the list of CUDA architectures to compile for.

cuda/defs.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Core rules for building CUDA projects.
33
"""
44

55
load("//cuda/private:defs.bzl", _requires_cuda = "requires_cuda")
6-
load("//cuda/private:errors.bzl", _unsupported_cuda_version = "unsupported_cuda_version", _unsupported_cuda_platform = "unsupported_cuda_platform")
6+
load("//cuda/private:errors.bzl", _unsupported_cuda_platform = "unsupported_cuda_platform", _unsupported_cuda_version = "unsupported_cuda_version")
77
load("//cuda/private:macros/cuda_binary.bzl", _cuda_binary = "cuda_binary")
88
load("//cuda/private:macros/cuda_test.bzl", _cuda_test = "cuda_test")
99
load("//cuda/private:os_helpers.bzl", _cc_import_versioned_sos = "cc_import_versioned_sos", _if_linux = "if_linux", _if_windows = "if_windows")

cuda/dummy/BUILD.bazel

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
load("@rules_cc//cc:defs.bzl", "cc_binary")
1+
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library")
22

33
package(default_visibility = ["//visibility:public"])
44

@@ -20,7 +20,10 @@ cc_binary(
2020
defines = ["TOOLNAME=nvlink"],
2121
)
2222

23-
exports_files(["link.stub", "libdevice.10.bc"])
23+
exports_files([
24+
"link.stub",
25+
"libdevice.10.bc",
26+
])
2427

2528
cc_binary(
2629
name = "bin2c",
@@ -34,15 +37,15 @@ cc_binary(
3437
defines = ["TOOLNAME=fatbinary"],
3538
)
3639

37-
cc_binary(
38-
name = "ptxas",
39-
srcs = ["dummy.cpp"],
40-
defines = ["TOOLNAME=ptxas"],
41-
)
42-
4340
# Empty cc_library that provides CcInfo for components not available in this CUDA version.
4441
cc_library(
4542
name = "dummy",
4643
srcs = [],
4744
hdrs = [],
4845
)
46+
47+
cc_binary(
48+
name = "ptxas",
49+
srcs = ["dummy.cpp"],
50+
defines = ["TOOLNAME=ptxas"],
51+
)

cuda/extensions.bzl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Entry point for extensions used by bzlmod."""
22

3+
load("//cuda:platform_alias_extension.bzl", "platform_alias_repo")
34
load("//cuda/private:redist_json_helper.bzl", "redist_json_helper")
45
load("//cuda/private:repositories.bzl", "cuda_component", "cuda_toolkit")
5-
load("//cuda:platform_alias_extension.bzl", "platform_alias_repo")
66

77
cuda_component_tag = tag_class(attrs = {
88
"name": attr.string(mandatory = True, doc = "Repo name for the deliverable cuda_component"),
@@ -76,6 +76,10 @@ cuda_toolkit_tag = tag_class(attrs = {
7676
"nvcc_version": attr.string(
7777
doc = "nvcc version. Required for deliverable toolkit only. Fallback to version if omitted.",
7878
),
79+
"redist_json_name": attr.string(
80+
doc = "Name of the redist_json tag whose components this toolkit should use. " +
81+
"If omitted and exactly one redist_json is declared, it is used automatically.",
82+
),
7983
})
8084

8185
def _find_modules(module_ctx):
@@ -109,7 +113,8 @@ def _redist_json_impl(module_ctx, attr):
109113
mapping[spec["component_name"]] = repo_name
110114

111115
component_attr = {key: value for key, value in spec.items()}
112-
component_attr["name"] = repo_name + "_" + platform.replace("-", "_") + "_" + redist_ver.replace(".", "_")
116+
component_repo_name = repo_name + "_" + attr.name + "_" + platform.replace("-", "_") + "_" + redist_ver.replace(".", "_")
117+
component_attr["name"] = component_repo_name
113118
cuda_component(**component_attr)
114119
platform_mapping[platform] = mapping
115120
return redist_ver, platform_mapping
@@ -128,6 +133,7 @@ def _impl(module_ctx):
128133
components = rules_cuda.tags.component
129134
redist_jsons = rules_cuda.tags.redist_json
130135
toolkits = rules_cuda.tags.toolkit
136+
131137
for component in components:
132138
cuda_component(**_module_tag_to_dict(component))
133139

@@ -141,7 +147,8 @@ def _impl(module_ctx):
141147
for redist_json in redist_jsons:
142148
components_mapping = {}
143149
redist_version, platform_mapping = _redist_json_impl(module_ctx, redist_json)
144-
redist_versions.append(redist_version)
150+
if redist_version not in redist_versions:
151+
redist_versions.append(redist_version)
145152
for platform in platform_mapping.keys():
146153
for component_name, repo_name in platform_mapping[platform].items():
147154
redist_components_mapping[component_name] = repo_name
@@ -151,7 +158,7 @@ def _impl(module_ctx):
151158
versioned_repos[component_name] = {}
152159
if platform not in versioned_repos[component_name]:
153160
versioned_repos[component_name][platform] = {}
154-
versioned_repos[component_name][platform][redist_version] = repo_name + "_" + platform.replace("-", "_") + "_" + redist_version.replace(".", "_")
161+
versioned_repos[component_name][platform][redist_version] = repo_name + "_" + redist_json.name + "_" + platform.replace("-", "_") + "_" + redist_version.replace(".", "_")
155162

156163
for component_name in redist_components_mapping.keys():
157164
# Build dictionaries mapping versions to repo names for each platform.
@@ -169,19 +176,20 @@ def _impl(module_ctx):
169176
versions = redist_versions,
170177
)
171178
components_mapping[component_name] = "@" + redist_components_mapping[component_name]
179+
172180
registrations = {}
173181
for toolkit in toolkits:
174182
if toolkit.name in registrations.keys():
175183
if toolkit.toolkit_path == registrations[toolkit.name].toolkit_path:
176-
# No problem to register a matching toolkit twice
177184
continue
178-
fail("Multiple conflicting toolkits declared for name {} ({} and {}".format(toolkit.name, toolkit.toolkit_path, registrations[toolkit.name].toolkit_path))
185+
fail("Multiple conflicting toolkits declared for name {} ({} and {}".format(
186+
toolkit.name,
187+
toolkit.toolkit_path,
188+
registrations[toolkit.name].toolkit_path,
189+
))
179190
else:
180191
registrations[toolkit.name] = toolkit
181192

182-
if len(registrations) > 1:
183-
fail("multiple cuda.toolkit is not supported")
184-
185193
for _, toolkit in registrations.items():
186194
if components_mapping != None:
187195
# Always use the maximum version so the toolkit includes all components.
@@ -195,7 +203,8 @@ def _impl(module_ctx):
195203

196204
cuda_toolkit(name = toolkit.name, components_mapping = components_mapping, version = toolkit_version)
197205
else:
198-
cuda_toolkit(**_module_tag_to_dict(toolkit))
206+
attrs = {k: v for k, v in _module_tag_to_dict(toolkit).items() if k != "redist_json_name"}
207+
cuda_toolkit(**attrs)
199208

200209
toolchain = module_extension(
201210
implementation = _impl,

cuda/platform_alias_extension.bzl

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,21 @@ def _platform_alias_repo_impl(ctx):
2626

2727
build_content.append("[")
2828
build_content.append(" config_setting(")
29-
build_content.append(' name = "version_is_{}".format(version.replace(".", "_")),')
29+
build_content.append(' name = "version_is_{}_" + version.replace(".", "_"),'.format(
30+
ctx.attr.component_name,
31+
))
3032
build_content.append(' flag_values = {"@rules_cuda//cuda:version": "{}".format(version)},')
3133
build_content.append(" )")
3234
build_content.append(" for version in {}".format(ctx.attr.versions))
3335
build_content.append("]")
3436
build_content.append("")
3537

36-
build_content.append('unsupported_cuda_version(name = "unsupported_cuda_version", component = "{}", available_versions = {})'.format(ctx.attr.component_name, ctx.attr.versions))
38+
build_content.append(
39+
'unsupported_cuda_version(name = "unsupported_cuda_version", component = "{}", available_versions = {})'.format(
40+
ctx.attr.component_name,
41+
ctx.attr.versions,
42+
),
43+
)
3744
build_content.append("")
3845

3946
# Build a target for the name of the repo (only if at least one platform is available).
@@ -50,7 +57,12 @@ def _platform_alias_repo_impl(ctx):
5057

5158
# Always create unsupported_cuda_platform target - it's used as the default case
5259
# in select() when no platform condition matches.
53-
build_content.append('unsupported_cuda_platform(name = "unsupported_cuda_platform", component = "{}", available_platforms = {})'.format(ctx.attr.component_name, platforms_available))
60+
build_content.append(
61+
'unsupported_cuda_platform(name = "unsupported_cuda_platform", component = "{}", available_platforms = {})'.format(
62+
ctx.attr.component_name,
63+
platforms_available,
64+
),
65+
)
5466
build_content.append("")
5567

5668
# Only generate target aliases if this component is in TARGET_MAPPING.
@@ -76,10 +88,13 @@ def _platform_alias_repo_impl(ctx):
7688
build_content.append("alias(")
7789
build_content.append(' name = "{}",'.format(target_name))
7890
build_content.append(" actual = select({")
91+
7992
# Add conditions for ALL platforms, using dummy for unavailable ones.
8093
for platform in SUPPORTED_PLATFORMS:
8194
platform_suffix = platform.replace("-", "_")
82-
build_content.append(' "@rules_cuda//cuda:{}_platform_is_{}":'.format(platform_type, platform_suffix))
95+
build_content.append(
96+
' "@rules_cuda//cuda:{}_platform_is_{}":'.format(platform_type, platform_suffix),
97+
)
8398
if platform in platforms_available:
8499
build_content.append(' ":{}_{}",'.format(platform_suffix, target_name))
85100
else:
@@ -106,20 +121,38 @@ def _platform_alias_repo_impl(ctx):
106121
platform_suffix = platform.replace("-", "_")
107122
repos_dict = platform_repos_map[platform]
108123
platform_available = platform in platforms_available
124+
default_version = ctx.attr.versions[0] if ctx.attr.versions else None
109125

110126
build_content.append("alias(")
111127
build_content.append(' name = "{}_{}",'.format(platform_suffix, target_name))
112128
build_content.append(" actual = select({")
113129

114130
for version in ctx.attr.versions:
115-
build_content.append(' ":version_is_{}": '.format(version.replace(".", "_")))
131+
build_content.append(' ":version_is_{}_{}": '.format(
132+
ctx.attr.component_name,
133+
version.replace(".", "_"),
134+
))
116135
if platform_available and version in repos_dict:
117136
repo_name = repos_dict[version]
118-
build_content.append(' "@{}//{}",'.format(repo_name, target if target.find(":") != -1 else ":" + target))
137+
build_content.append(
138+
' "@{}//{}",'.format(
139+
repo_name,
140+
target if target.find(":") != -1 else ":" + target,
141+
),
142+
)
119143
else:
120144
# Platform doesn't have this component for this version, use dummy.
121145
build_content.append(' "{}",'.format(dummy_target))
122-
build_content.append(' "//conditions:default": ":unsupported_cuda_version",')
146+
if platform_available and default_version and default_version in repos_dict:
147+
repo_name = repos_dict[default_version]
148+
build_content.append(
149+
' "//conditions:default": "@{}//{}",'.format(
150+
repo_name,
151+
target if target.find(":") != -1 else ":" + target,
152+
),
153+
)
154+
else:
155+
build_content.append(' "//conditions:default": ":unsupported_cuda_version",')
123156

124157
build_content.append(" }),")
125158
build_content.append(' visibility = ["//visibility:public"],')

cuda/private/errors.bzl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
def _unsupported_cuda_version_impl(ctx):
22
fail("CUDA component '{}' is not available for the selected CUDA version. Available versions: {}".format(
33
ctx.attr.component,
4-
", ".join(ctx.attr.available_versions)
4+
", ".join(ctx.attr.available_versions),
55
))
66

77
unsupported_cuda_version = rule(
@@ -15,7 +15,7 @@ unsupported_cuda_version = rule(
1515
def _unsupported_cuda_platform_impl(ctx):
1616
fail("CUDA component '{}' is not available for the selected platform. Available platforms: {}".format(
1717
ctx.attr.component,
18-
", ".join(ctx.attr.available_platforms)
18+
", ".join(ctx.attr.available_platforms),
1919
))
2020

2121
unsupported_cuda_platform = rule(

cuda/private/repositories.bzl

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,25 @@ def _is_linux(ctx):
1313
def _is_windows(ctx):
1414
return ctx.os.name.lower().startswith("windows")
1515

16+
def _detect_platform(ctx):
17+
os = None
18+
if _is_linux(ctx):
19+
os = "linux"
20+
elif _is_windows(ctx):
21+
os = "windows"
22+
else:
23+
fail("Unsupported OS '{}' for CUDA redist.".format(ctx.os.name))
24+
25+
arch = ctx.os.arch
26+
if arch in ["x86_64", "amd64"]:
27+
arch = "x86_64"
28+
elif arch in ["aarch64", "arm64"]:
29+
arch = "aarch64"
30+
else:
31+
fail("Unsupported arch '{}' for CUDA redist.".format(ctx.os.arch))
32+
33+
return "{}-{}".format(os, arch)
34+
1635
def _get_nvcc_version(repository_ctx, nvcc_root):
1736
result = repository_ctx.execute([nvcc_root + "/bin/nvcc", "--version"])
1837
if result.return_code != 0:
@@ -114,19 +133,18 @@ def _detect_deliverable_cuda_toolkit(repository_ctx):
114133

115134
nvcc_repo = repository_ctx.attr.components_mapping["nvcc"]
116135

117-
bin_ext = ".exe" if _is_windows(repository_ctx) else ""
118-
nvcc = "{}//:nvcc{}".format(nvcc_repo, bin_ext)
119-
nvlink = "{}//:nvlink{}".format(nvcc_repo, bin_ext)
136+
nvcc = "{}//:nvcc".format(nvcc_repo)
137+
nvlink = "{}//:nvlink".format(nvcc_repo)
120138
link_stub = "{}//:link.stub".format(nvcc_repo)
121-
bin2c = "{}//:bin2c{}".format(nvcc_repo, bin_ext)
122-
fatbinary = "{}//:fatbinary{}".format(nvcc_repo, bin_ext)
123-
ptxas = "{}//:ptxas{}".format(ptxas, bin_ext)
139+
bin2c = "{}//:bin2c".format(nvcc_repo)
140+
fatbinary = "{}//:fatbinary".format(nvcc_repo)
141+
ptxas = "{}//:ptxas".format(nvcc_repo)
124142

125143
cicc = None
126144
libdevice = None
127-
if int(cuda_version_major) >= 13:
145+
if "nvvm" in repository_ctx.attr.components_mapping:
128146
nvvm_repo = repository_ctx.attr.components_mapping["nvvm"]
129-
cicc = "{}//:cicc{}".format(nvvm_repo, bin_ext)
147+
cicc = "{}//:cicc".format(nvvm_repo)
130148
libdevice = "{}//:libdevice.10.bc".format(nvvm_repo)
131149

132150
return struct(
@@ -464,7 +482,8 @@ def _cuda_redist_json_impl(repository_ctx):
464482
attr = repository_ctx.attr
465483
url, json_object = redist_json_helper.get(repository_ctx, attr)
466484
redist_ver = redist_json_helper.get_redist_version(repository_ctx, attr, json_object)
467-
specs = redist_json_helper.collect_specs(repository_ctx, attr, json_object, url)
485+
platform = _detect_platform(repository_ctx)
486+
specs = redist_json_helper.collect_specs(repository_ctx, attr, platform, json_object, url)
468487

469488
template_helper.generate_redist_bzl(repository_ctx, specs, redist_ver)
470489
repository_ctx.symlink(Label("//cuda/private:templates/BUILD.redist_json"), "BUILD")

0 commit comments

Comments
 (0)