Skip to content

Commit 4a85605

Browse files
author
finn-ball
committed
Refactor the platform and the dedupe logic
Make no version in the config default to the highest set.
1 parent 8d87bb8 commit 4a85605

File tree

6 files changed

+116
-90
lines changed

6 files changed

+116
-90
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ use_repo(cuda, "cuda")
3636

3737
To select CUDA versions and platforms at build time, define multiple `cuda.redist_json` entries and a single `cuda.toolkit`.
3838
Then use `@rules_cuda//cuda:version`, `@rules_cuda//cuda:exec_platform`, and `@rules_cuda//cuda:target_platform` flags.
39+
If `@rules_cuda//cuda:version` is unset, rules_cuda selects the highest declared redist version.
3940

4041
```starlark
4142
cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
@@ -66,6 +67,7 @@ Example `.bazelrc` entries:
6667
```
6768
build --@rules_cuda//cuda:exec_platform=linux-x86_64
6869
build --@rules_cuda//cuda:target_platform=linux-x86_64
70+
# Optional: if omitted, the highest declared redist version is used.
6971
build --@rules_cuda//cuda:version=13.0.0
7072
```
7173

cuda/BUILD.bazel

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,13 @@ config_setting(
3939
flag_values = {"@cuda//:valid_toolchain_found": "True"},
4040
)
4141

42+
# NOTE: Functional with platform_alias only.
4243
string_flag(
4344
name = "version",
44-
build_setting_default = "13.0.0",
45+
build_setting_default = "",
4546
)
4647

48+
# NOTE: Functional with platform_alias only.
4749
string_flag(
4850
name = "target_platform",
4951
build_setting_default = "linux-x86_64",
@@ -58,6 +60,7 @@ string_flag(
5860
for platform in SUPPORTED_PLATFORMS
5961
]
6062

63+
# NOTE: Functional with platform_alias only.
6164
string_flag(
6265
name = "exec_platform",
6366
build_setting_default = "linux-x86_64",

cuda/extensions.bzl

Lines changed: 79 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Entry point for extensions used by bzlmod."""
22

33
load("//cuda:platform_alias_extension.bzl", "platform_alias_repo")
4+
load("//cuda/private:platforms.bzl", "SUPPORTED_PLATFORMS")
45
load("//cuda/private:redist_json_helper.bzl", "redist_json_helper")
56
load("//cuda/private:repositories.bzl", "cuda_component", "cuda_toolkit")
67

@@ -96,6 +97,16 @@ def _find_modules(module_ctx):
9697
def _module_tag_to_dict(t):
9798
return {attr: getattr(t, attr) for attr in dir(t)}
9899

100+
def _platform_repos_attr(platform):
101+
return platform.replace("-", "_") + "_repos"
102+
103+
def _version_sort_key(version):
104+
prefix = version.split("-", 1)[0]
105+
parts = prefix.split(".")
106+
if all([p.isdigit() for p in parts]):
107+
return (1, [int(p) for p in parts], version)
108+
return (0, [], version)
109+
99110
def _component_attrs_match(existing, current):
100111
for key, value in current.items():
101112
if key == "name":
@@ -109,36 +120,43 @@ def _component_attrs_match(existing, current):
109120
return False
110121
return True
111122

112-
def _redist_json_impl(module_ctx, attr, generated_components):
123+
def _component_entry_key(component_name, platform, redist_ver):
124+
return "{}|{}|{}".format(component_name, platform, redist_ver)
125+
126+
def _register_redist_components(module_ctx, attr, component_entries):
113127
url, json_object = redist_json_helper.get(module_ctx, attr)
114128
redist_ver = redist_json_helper.get_redist_version(module_ctx, attr, json_object)
115129

116-
platform_mapping = {}
117130
for platform in attr.platforms:
118131
component_specs = redist_json_helper.collect_specs(module_ctx, attr, platform, json_object, url)
119-
mapping = {}
120132
for spec in component_specs:
121133
repo_name = redist_json_helper.get_repo_name(module_ctx, spec)
122-
mapping[spec["component_name"]] = repo_name
123134

124135
component_attr = {key: value for key, value in spec.items()}
125136
component_repo_name = repo_name + "_" + platform.replace("-", "_") + "_" + redist_ver.replace(".", "_")
126137
component_attr["name"] = component_repo_name
127138

128-
dedupe_key = "{}|{}|{}".format(spec["component_name"], platform, redist_ver)
129-
existing_attr = generated_components.get(dedupe_key)
139+
dedupe_key = _component_entry_key(spec["component_name"], platform, redist_ver)
140+
existing_entry = component_entries.get(dedupe_key)
141+
existing_attr = existing_entry["component_attr"] if existing_entry else None
130142
if existing_attr == None:
131143
cuda_component(**component_attr)
132-
generated_components[dedupe_key] = component_attr
144+
component_entries[dedupe_key] = {
145+
"component_name": spec["component_name"],
146+
"platform": platform,
147+
"redist_version": redist_ver,
148+
"repo_name": repo_name,
149+
"generated_repo_name": component_repo_name,
150+
"component_attr": component_attr,
151+
}
133152
elif not _component_attrs_match(existing_attr, component_attr):
134153
fail(("Conflicting CUDA component definition for {} on {} at version {}. " +
135154
"Use distinct component versions when registries are not identical.").format(
136155
spec["component_name"],
137156
platform,
138157
redist_ver,
139158
))
140-
platform_mapping[platform] = mapping
141-
return redist_ver, platform_mapping
159+
return redist_ver
142160

143161
def _impl(module_ctx):
144162
# Toolchain configuration is only allowed in the root module, or in rules_cuda.
@@ -158,47 +176,62 @@ def _impl(module_ctx):
158176
for component in components:
159177
cuda_component(**_module_tag_to_dict(component))
160178

161-
redist_version = None
162179
components_mapping = None
163180
redist_versions = []
164-
redist_components_mapping = {}
165-
166-
# Track all versioned repositories for each component and platform.
167-
versioned_repos = {}
168-
generated_components = {}
181+
component_entries = {}
169182
for redist_json in redist_jsons:
170-
components_mapping = {}
171-
redist_version, platform_mapping = _redist_json_impl(module_ctx, redist_json, generated_components)
183+
redist_version = _register_redist_components(module_ctx, redist_json, component_entries)
172184
if redist_version not in redist_versions:
173185
redist_versions.append(redist_version)
174-
for platform in platform_mapping.keys():
175-
for component_name, repo_name in platform_mapping[platform].items():
176-
redist_components_mapping[component_name] = repo_name
177-
178-
# Track the versioned repo name for this component/platform/version.
179-
if component_name not in versioned_repos:
180-
versioned_repos[component_name] = {}
181-
if platform not in versioned_repos[component_name]:
182-
versioned_repos[component_name][platform] = {}
183-
versioned_repos[component_name][platform][redist_version] = repo_name + "_" + platform.replace("-", "_") + "_" + redist_version.replace(".", "_")
184-
185-
for component_name in redist_components_mapping.keys():
186-
# Build dictionaries mapping versions to repo names for each platform.
187-
x86_64_repos = {ver: versioned_repos[component_name]["linux-x86_64"][ver] for ver in redist_versions if "linux-x86_64" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-x86_64"]}
188-
windows_x86_64_repos = {ver: versioned_repos[component_name]["windows-x86_64"][ver] for ver in redist_versions if "windows-x86_64" in versioned_repos[component_name] and ver in versioned_repos[component_name]["windows-x86_64"]}
189-
aarch64_repos = {ver: versioned_repos[component_name]["linux-aarch64"][ver] for ver in redist_versions if "linux-aarch64" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-aarch64"]}
190-
sbsa_repos = {ver: versioned_repos[component_name]["linux-sbsa"][ver] for ver in redist_versions if "linux-sbsa" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-sbsa"]}
191-
192-
platform_alias_repo(
193-
name = redist_components_mapping[component_name],
194-
component_name = component_name,
195-
linux_x86_64_repos = x86_64_repos,
196-
windows_x86_64_repos = windows_x86_64_repos,
197-
linux_aarch64_repos = aarch64_repos,
198-
linux_sbsa_repos = sbsa_repos,
199-
versions = redist_versions,
200-
)
201-
components_mapping[component_name] = "@" + redist_components_mapping[component_name]
186+
187+
if len(component_entries) > 0:
188+
components_mapping = {}
189+
redist_components_mapping = {}
190+
versioned_repos = {}
191+
for entry in component_entries.values():
192+
component_name = entry["component_name"]
193+
platform = entry["platform"]
194+
redist_version = entry["redist_version"]
195+
196+
redist_components_mapping[component_name] = entry["repo_name"]
197+
if component_name not in versioned_repos:
198+
versioned_repos[component_name] = {}
199+
if platform not in versioned_repos[component_name]:
200+
versioned_repos[component_name][platform] = {}
201+
versioned_repos[component_name][platform][redist_version] = entry["generated_repo_name"]
202+
203+
for component_name in redist_components_mapping.keys():
204+
component_platforms = [
205+
platform
206+
for platform in SUPPORTED_PLATFORMS
207+
if platform in versioned_repos[component_name] and len(versioned_repos[component_name][platform]) > 0
208+
]
209+
# Preserve pre-multi-version behavior for the simple case:
210+
# if there is exactly one concrete repo, wire toolkit mapping directly.
211+
if len(redist_versions) == 1 and len(component_platforms) == 1:
212+
only_platform = component_platforms[0]
213+
only_version = redist_versions[0]
214+
only_repo = versioned_repos[component_name][only_platform].get(only_version)
215+
if only_repo:
216+
components_mapping[component_name] = "@" + only_repo
217+
continue
218+
219+
# Build dictionaries mapping versions to repo names for each platform.
220+
platform_repo_kwargs = {}
221+
for platform in SUPPORTED_PLATFORMS:
222+
platform_repo_kwargs[_platform_repos_attr(platform)] = {
223+
ver: versioned_repos[component_name][platform][ver]
224+
for ver in redist_versions
225+
if platform in versioned_repos[component_name] and ver in versioned_repos[component_name][platform]
226+
}
227+
228+
platform_alias_repo(
229+
name = redist_components_mapping[component_name],
230+
component_name = component_name,
231+
versions = redist_versions,
232+
**platform_repo_kwargs
233+
)
234+
components_mapping[component_name] = "@" + redist_components_mapping[component_name]
202235

203236
registrations = {}
204237
for toolkit in toolkits:
@@ -217,13 +250,7 @@ def _impl(module_ctx):
217250
if components_mapping != None:
218251
# Always use the maximum version so the toolkit includes all components.
219252
# Components that don't exist in older versions will fall back to dummy.
220-
toolkit_version = redist_versions[0]
221-
for ver in redist_versions:
222-
ver_parts = [int(x) for x in ver.split(".")]
223-
tv_parts = [int(x) for x in toolkit_version.split(".")]
224-
if ver_parts > tv_parts:
225-
toolkit_version = ver
226-
253+
toolkit_version = sorted(redist_versions, key = _version_sort_key)[-1]
227254
cuda_toolkit(name = toolkit.name, components_mapping = components_mapping, version = toolkit_version)
228255
else:
229256
cuda_toolkit(**_module_tag_to_dict(toolkit))

cuda/platform_alias_extension.bzl

Lines changed: 27 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,24 @@ load("//cuda/private:templates/registry.bzl", "REGISTRY")
1010
# Use REGISTRY as the source of truth for component targets
1111
TARGET_MAPPING = REGISTRY
1212

13+
def _platform_repos_attr(platform):
14+
return platform.replace("-", "_") + "_repos"
15+
16+
_PLATFORM_REPO_ATTRS = {
17+
_platform_repos_attr(_platform): attr.string_dict(
18+
default = {},
19+
doc = "Dictionary mapping versions to repository names for {}".format(_platform),
20+
)
21+
for _platform in SUPPORTED_PLATFORMS
22+
}
23+
24+
def _version_sort_key(version):
25+
prefix = version.split("-", 1)[0]
26+
parts = prefix.split(".")
27+
if all([p.isdigit() for p in parts]):
28+
return (1, [int(p) for p in parts], version)
29+
return (0, [], version)
30+
1331
def _platform_alias_repo_impl(ctx):
1432
"""Implementation of the platform_alias_repo repository rule.
1533
@@ -46,16 +64,13 @@ def _platform_alias_repo_impl(ctx):
4664
# Build a target for the name of the repo (only if at least one platform is available).
4765
platform_type = "exec" if ctx.attr.component_name in ["nvcc", "nvvm"] else "target"
4866

67+
platform_repos_map = {
68+
platform: getattr(ctx.attr, _platform_repos_attr(platform))
69+
for platform in SUPPORTED_PLATFORMS
70+
}
71+
4972
# Check which platforms are available (have at least one version).
50-
platforms_available = []
51-
if len(ctx.attr.linux_x86_64_repos) > 0:
52-
platforms_available.append("linux-x86_64")
53-
if len(ctx.attr.windows_x86_64_repos) > 0:
54-
platforms_available.append("windows-x86_64")
55-
if len(ctx.attr.linux_sbsa_repos) > 0:
56-
platforms_available.append("linux-sbsa")
57-
if len(ctx.attr.linux_aarch64_repos) > 0:
58-
platforms_available.append("linux-aarch64")
73+
platforms_available = [platform for platform in SUPPORTED_PLATFORMS if len(platform_repos_map[platform]) > 0]
5974

6075
# Always create unsupported_cuda_platform target - it's used as the default case
6176
# in select() when no platform condition matches.
@@ -113,18 +128,11 @@ def _platform_alias_repo_impl(ctx):
113128
# Platforms where it doesn't exist get dummy targets for all versions.
114129
# This ensures builds on any platform have matching select conditions.
115130

116-
platform_repos_map = {
117-
"linux-x86_64": ctx.attr.linux_x86_64_repos,
118-
"windows-x86_64": ctx.attr.windows_x86_64_repos,
119-
"linux-sbsa": ctx.attr.linux_sbsa_repos,
120-
"linux-aarch64": ctx.attr.linux_aarch64_repos,
121-
}
122-
123131
for platform in SUPPORTED_PLATFORMS:
124132
platform_suffix = platform.replace("-", "_")
125133
repos_dict = platform_repos_map[platform]
126134
platform_available = platform in platforms_available
127-
default_version = ctx.attr.versions[0] if ctx.attr.versions else None
135+
default_version = sorted(ctx.attr.versions, key = _version_sort_key)[-1] if ctx.attr.versions else None
128136

129137
build_content.append("alias(")
130138
build_content.append(' name = "{}_{}",'.format(platform_suffix, target_name))
@@ -167,30 +175,14 @@ def _platform_alias_repo_impl(ctx):
167175

168176
platform_alias_repo = repository_rule(
169177
implementation = _platform_alias_repo_impl,
170-
attrs = {
178+
attrs = dict({
171179
"component_name": attr.string(
172180
mandatory = True,
173181
doc = "Name of the component",
174182
),
175-
"linux_x86_64_repos": attr.string_dict(
176-
default = {},
177-
doc = "Dictionary mapping versions to x86_64 repository names",
178-
),
179-
"linux_aarch64_repos": attr.string_dict(
180-
default = {},
181-
doc = "Dictionary mapping versions to ARM64/Jetpack repository names",
182-
),
183-
"windows_x86_64_repos": attr.string_dict(
184-
default = {},
185-
doc = "Dictionary mapping versions to Windows x86_64 repository names",
186-
),
187-
"linux_sbsa_repos": attr.string_dict(
188-
default = {},
189-
doc = "Dictionary mapping versions to SBSA repository names",
190-
),
191183
"versions": attr.string_list(
192184
mandatory = True,
193185
doc = "List of versions to create aliases for",
194186
),
195-
},
187+
}, **_PLATFORM_REPO_ATTRS),
196188
)

cuda/private/toolchain_configs/clang.bzl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,9 @@ def _impl(ctx):
111111

112112
nvvm_root = None
113113
if libdevice_dir:
114+
# Clang expects --cuda-path to point at a CTK root containing
115+
# bin/, include/, and nvvm/libdevice/libdevice.10.bc.
114116
nvvm_root = paths.dirname(paths.dirname(libdevice_dir))
115-
if nvvm_root.endswith("/nvvm"):
116-
nvvm_root = paths.dirname(nvvm_root)
117117
cuda_root = ctx.attr.cuda_toolkit[CudaToolkitInfo].path
118118
if (not cuda_root or cuda_root == "cuda-not-found") and ctx.attr.cuda_toolkit[CudaToolkitInfo].nvlink:
119119
# For deliverable toolkits, infer CUDA root from nvlink location.

tests/hermeticity/test.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ hermetic_flags=(\
2929
--@rules_cuda//cuda:copts=-Xcompiler=-fdebug-prefix-map=$(pwd)=. \
3030
--@rules_cuda//cuda:copts=-objtemp)
3131

32+
bazel clean
33+
3234
bazel build "${hermetic_flags[@]}" :cuda_test
3335
build_output1=$(strings ${file})
3436

0 commit comments

Comments
 (0)