Skip to content

Commit 323d04f

Browse files
finn-ballFinn Ball
authored andcommitted
Allow multiple cuda versions in the same toolchain
1 parent 6fa09c3 commit 323d04f

File tree

24 files changed

+648
-64
lines changed

24 files changed

+648
-64
lines changed

.github/workflows/build-tests.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,9 @@ jobs:
9898

9999
# Override cuda redist_json rule version for CI
100100
- run: echo "CUDA_REDIST_VERSION_OVERRIDE=${{ matrix.cases.cuda-version }}" >> $GITHUB_ENV
101-
if: ${{ !startsWith(matrix.cases.os, 'windows') }}
101+
if: ${{ !startsWith(matrix.cases.os, 'windows') && matrix.cases.source == 'nvidia' }}
102102
- run: echo "CUDA_REDIST_VERSION_OVERRIDE=${{ matrix.cases.cuda-version }}" >> $env:GITHUB_ENV
103-
if: ${{ startsWith(matrix.cases.os, 'windows') }}
103+
if: ${{ startsWith(matrix.cases.os, 'windows') && matrix.cases.source == 'nvidia' }}
104104

105105
# Use Bazel with version specified in .bazelversion
106106
- run: echo "USE_BAZEL_VERSION=$(cat .bazelversion)" >> $GITHUB_ENV

README.md

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,46 @@ use_repo(cuda, "cuda")
3232

3333
`rules_cc` provides the C++ toolchain dependency for `rules_cuda`; in Bzlmod, the compatibility repository is handled by `rules_cc` itself.
3434

35+
#### Multi-version hermetic toolchains (Bzlmod)
36+
37+
To select CUDA versions and platforms at build time, define multiple `cuda.redist_json` entries and a single `cuda.toolkit`.
38+
Then use `@rules_cuda//cuda:version`, `@rules_cuda//cuda:exec_platform`, and `@rules_cuda//cuda:target_platform` flags.
39+
40+
```starlark
41+
cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
42+
43+
cuda.redist_json(
44+
name = "cuda_13_0_2",
45+
version = "13.0.2",
46+
platforms = [
47+
"linux-x86_64",
48+
"linux-sbsa",
49+
],
50+
)
51+
cuda.redist_json(
52+
name = "cuda_13_0_0",
53+
version = "13.0.0",
54+
platforms = [
55+
"linux-x86_64",
56+
"linux-sbsa",
57+
],
58+
)
59+
60+
cuda.toolkit(name = "cuda")
61+
use_repo(cuda, "cuda")
62+
```
63+
64+
Example `.bazelrc` entries:
65+
66+
```
67+
build --@rules_cuda//cuda:exec_platform=linux-x86_64
68+
build --@rules_cuda//cuda:target_platform=linux-x86_64
69+
build --@rules_cuda//cuda:version=13.0.0
70+
```
71+
72+
Note: In Bzlmod, `platforms` is required for `cuda.redist_json` because module extensions don't have access to host OS/arch
73+
information, so the platforms must be declared explicitly.
74+
3575
<details>
3676
<summary>Traditional WORKSPACE approach</summary>
3777

cuda/BUILD.bazel

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ load(
44
"bool_flag",
55
"string_flag",
66
)
7+
load("//cuda/private:platforms.bzl", "SUPPORTED_PLATFORMS")
78
load("//cuda/private:rules/flags.bzl", "cuda_archs_flag", "repeatable_string_flag")
89

910
package(default_visibility = ["//visibility:public"])
@@ -38,6 +39,39 @@ config_setting(
3839
flag_values = {"@cuda//:valid_toolchain_found": "True"},
3940
)
4041

42+
string_flag(
43+
name = "version",
44+
build_setting_default = "13.0.0",
45+
)
46+
47+
string_flag(
48+
name = "target_platform",
49+
build_setting_default = "linux-x86_64",
50+
values = SUPPORTED_PLATFORMS,
51+
)
52+
53+
[
54+
config_setting(
55+
name = "target_platform_is_{}".format(platform.replace("-", "_")),
56+
flag_values = {":target_platform": platform},
57+
)
58+
for platform in SUPPORTED_PLATFORMS
59+
]
60+
61+
string_flag(
62+
name = "exec_platform",
63+
build_setting_default = "linux-x86_64",
64+
values = SUPPORTED_PLATFORMS,
65+
)
66+
67+
[
68+
config_setting(
69+
name = "exec_platform_is_{}".format(platform.replace("-", "_")),
70+
flag_values = {":exec_platform": platform},
71+
)
72+
for platform in SUPPORTED_PLATFORMS
73+
]
74+
4175
# Command line flag to specify the list of CUDA architectures to compile for.
4276
#
4377
# Provides CudaArchsInfo of the list of archs to build.

cuda/defs.bzl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ Core rules for building CUDA projects.
33
"""
44

55
load("//cuda/private:defs.bzl", _requires_cuda = "requires_cuda")
6+
load("//cuda/private:errors.bzl", _unsupported_cuda_platform = "unsupported_cuda_platform", _unsupported_cuda_version = "unsupported_cuda_version")
67
load("//cuda/private:macros/cuda_binary.bzl", _cuda_binary = "cuda_binary")
78
load("//cuda/private:macros/cuda_test.bzl", _cuda_test = "cuda_test")
89
load("//cuda/private:os_helpers.bzl", _cc_import_versioned_sos = "cc_import_versioned_sos", _if_linux = "if_linux", _if_windows = "if_windows")
@@ -47,3 +48,5 @@ if_windows = _if_windows
4748
cc_import_versioned_sos = _cc_import_versioned_sos
4849

4950
requires_cuda = _requires_cuda
51+
unsupported_cuda_version = _unsupported_cuda_version
52+
unsupported_cuda_platform = _unsupported_cuda_platform

cuda/dummy/BUILD.bazel

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
load("@rules_cc//cc:defs.bzl", "cc_binary")
1+
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library")
22

33
package(default_visibility = ["//visibility:public"])
44

@@ -8,13 +8,22 @@ cc_binary(
88
defines = ["TOOLNAME=nvcc"],
99
)
1010

11+
cc_binary(
12+
name = "cicc",
13+
srcs = ["dummy.cpp"],
14+
defines = ["TOOLNAME=cicc"],
15+
)
16+
1117
cc_binary(
1218
name = "nvlink",
1319
srcs = ["dummy.cpp"],
1420
defines = ["TOOLNAME=nvlink"],
1521
)
1622

17-
exports_files(["link.stub"])
23+
exports_files([
24+
"link.stub",
25+
"libdevice.10.bc",
26+
])
1827

1928
cc_binary(
2029
name = "bin2c",
@@ -28,6 +37,13 @@ cc_binary(
2837
defines = ["TOOLNAME=fatbinary"],
2938
)
3039

40+
# Empty cc_library that provides CcInfo for components not available in this CUDA version.
41+
cc_library(
42+
name = "dummy",
43+
srcs = [],
44+
hdrs = [],
45+
)
46+
3147
cc_binary(
3248
name = "ptxas",
3349
srcs = ["dummy.cpp"],

cuda/dummy/libdevice.10.bc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
#error libdevice.10.bc of cuda toolkit does not exist

cuda/extensions.bzl

Lines changed: 75 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Entry point for extensions used by bzlmod."""
22

3+
load("//cuda:platform_alias_extension.bzl", "platform_alias_repo")
34
load("//cuda/private:redist_json_helper.bzl", "redist_json_helper")
45
load("//cuda/private:repositories.bzl", "cuda_component", "cuda_toolkit")
56

@@ -53,6 +54,9 @@ cuda_redist_json_tag = tag_class(attrs = {
5354
"URLs are tried in order until one succeeds, so you should list local mirrors first. " +
5455
"If all downloads fail, the rule will fail.",
5556
),
57+
"platforms": attr.string_list(
58+
doc = "A list of platforms to generate components for.",
59+
),
5660
"version": attr.string(
5761
doc = "Generate a URL by using the specified version." +
5862
"This URL will be tried after all URLs specified in the `urls` attribute.",
@@ -72,6 +76,10 @@ cuda_toolkit_tag = tag_class(attrs = {
7276
"nvcc_version": attr.string(
7377
doc = "nvcc version. Required for deliverable toolkit only. Fallback to version if omitted.",
7478
),
79+
"redist_json_name": attr.string(
80+
doc = "Name of the redist_json tag whose components this toolkit should use. " +
81+
"If omitted and exactly one redist_json is declared, it is used automatically.",
82+
),
7583
})
7684

7785
def _find_modules(module_ctx):
@@ -95,17 +103,21 @@ def _module_tag_to_dict(t):
95103
def _redist_json_impl(module_ctx, attr):
96104
url, json_object = redist_json_helper.get(module_ctx, attr)
97105
redist_ver = redist_json_helper.get_redist_version(module_ctx, attr, json_object)
98-
component_specs = redist_json_helper.collect_specs(module_ctx, attr, json_object, url)
99-
100-
mapping = {}
101-
for spec in component_specs:
102-
repo_name = redist_json_helper.get_repo_name(module_ctx, spec)
103-
mapping[spec["component_name"]] = "@" + repo_name
104106

105-
attr = {key: value for key, value in spec.items()}
106-
attr["name"] = repo_name
107-
cuda_component(**attr)
108-
return redist_ver, mapping
107+
platform_mapping = {}
108+
for platform in attr.platforms:
109+
component_specs = redist_json_helper.collect_specs(module_ctx, attr, platform, json_object, url)
110+
mapping = {}
111+
for spec in component_specs:
112+
repo_name = redist_json_helper.get_repo_name(module_ctx, spec)
113+
mapping[spec["component_name"]] = repo_name
114+
115+
component_attr = {key: value for key, value in spec.items()}
116+
component_repo_name = repo_name + "_" + attr.name + "_" + platform.replace("-", "_") + "_" + redist_ver.replace(".", "_")
117+
component_attr["name"] = component_repo_name
118+
cuda_component(**component_attr)
119+
platform_mapping[platform] = mapping
120+
return redist_ver, platform_mapping
109121

110122
def _impl(module_ctx):
111123
# Toolchain configuration is only allowed in the root module, or in rules_cuda.
@@ -125,32 +137,74 @@ def _impl(module_ctx):
125137
for component in components:
126138
cuda_component(**_module_tag_to_dict(component))
127139

128-
if len(redist_jsons) > 1:
129-
fail("Using multiple cuda.redist_json is not supported yet.")
130-
131140
redist_version = None
132141
components_mapping = None
142+
redist_versions = []
143+
redist_components_mapping = {}
144+
145+
# Track all versioned repositories for each component and platform.
146+
versioned_repos = {}
133147
for redist_json in redist_jsons:
134-
redist_version, components_mapping = _redist_json_impl(module_ctx, redist_json)
148+
components_mapping = {}
149+
redist_version, platform_mapping = _redist_json_impl(module_ctx, redist_json)
150+
if redist_version not in redist_versions:
151+
redist_versions.append(redist_version)
152+
for platform in platform_mapping.keys():
153+
for component_name, repo_name in platform_mapping[platform].items():
154+
redist_components_mapping[component_name] = repo_name
155+
156+
# Track the versioned repo name for this component/platform/version.
157+
if component_name not in versioned_repos:
158+
versioned_repos[component_name] = {}
159+
if platform not in versioned_repos[component_name]:
160+
versioned_repos[component_name][platform] = {}
161+
versioned_repos[component_name][platform][redist_version] = repo_name + "_" + redist_json.name + "_" + platform.replace("-", "_") + "_" + redist_version.replace(".", "_")
162+
163+
for component_name in redist_components_mapping.keys():
164+
# Build dictionaries mapping versions to repo names for each platform.
165+
x86_64_repos = {ver: versioned_repos[component_name]["linux-x86_64"][ver] for ver in redist_versions if "linux-x86_64" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-x86_64"]}
166+
aarch64_repos = {ver: versioned_repos[component_name]["linux-aarch64"][ver] for ver in redist_versions if "linux-aarch64" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-aarch64"]}
167+
sbsa_repos = {ver: versioned_repos[component_name]["linux-sbsa"][ver] for ver in redist_versions if "linux-sbsa" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-sbsa"]}
168+
169+
platform_alias_repo(
170+
name = redist_components_mapping[component_name],
171+
repo_name = redist_components_mapping[component_name],
172+
component_name = component_name,
173+
linux_x86_64_repos = x86_64_repos,
174+
linux_aarch64_repos = aarch64_repos,
175+
linux_sbsa_repos = sbsa_repos,
176+
versions = redist_versions,
177+
)
178+
components_mapping[component_name] = "@" + redist_components_mapping[component_name]
135179

136180
registrations = {}
137181
for toolkit in toolkits:
138182
if toolkit.name in registrations.keys():
139183
if toolkit.toolkit_path == registrations[toolkit.name].toolkit_path:
140-
# No problem to register a matching toolkit twice
141184
continue
142-
fail("Multiple conflicting toolkits declared for name {} ({} and {}".format(toolkit.name, toolkit.toolkit_path, registrations[toolkit.name].toolkit_path))
185+
fail("Multiple conflicting toolkits declared for name {} ({} and {}".format(
186+
toolkit.name,
187+
toolkit.toolkit_path,
188+
registrations[toolkit.name].toolkit_path,
189+
))
143190
else:
144191
registrations[toolkit.name] = toolkit
145192

146-
if len(registrations) > 1:
147-
fail("multiple cuda.toolkit is not supported")
148-
149193
for _, toolkit in registrations.items():
150194
if components_mapping != None:
151-
cuda_toolkit(name = toolkit.name, components_mapping = components_mapping, version = redist_version)
195+
# Always use the maximum version so the toolkit includes all components.
196+
# Components that don't exist in older versions will fall back to dummy.
197+
toolkit_version = redist_versions[0]
198+
for ver in redist_versions:
199+
ver_parts = [int(x) for x in ver.split(".")]
200+
tv_parts = [int(x) for x in toolkit_version.split(".")]
201+
if ver_parts > tv_parts:
202+
toolkit_version = ver
203+
204+
cuda_toolkit(name = toolkit.name, components_mapping = components_mapping, version = toolkit_version)
152205
else:
153-
cuda_toolkit(**_module_tag_to_dict(toolkit))
206+
attrs = {k: v for k, v in _module_tag_to_dict(toolkit).items() if k != "redist_json_name"}
207+
cuda_toolkit(**attrs)
154208

155209
toolchain = module_extension(
156210
implementation = _impl,

0 commit comments

Comments
 (0)