Skip to content

Commit 3ba2845

Browse files
author
finn-ball
committed
Refactor the platform and the dedupe logic
1 parent 8d87bb8 commit 3ba2845

File tree

2 files changed

+74
-79
lines changed

2 files changed

+74
-79
lines changed

cuda/extensions.bzl

Lines changed: 56 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Entry point for extensions used by bzlmod."""
22

33
load("//cuda:platform_alias_extension.bzl", "platform_alias_repo")
4+
load("//cuda/private:platforms.bzl", "SUPPORTED_PLATFORMS")
45
load("//cuda/private:redist_json_helper.bzl", "redist_json_helper")
56
load("//cuda/private:repositories.bzl", "cuda_component", "cuda_toolkit")
67

@@ -96,6 +97,9 @@ def _find_modules(module_ctx):
9697
def _module_tag_to_dict(t):
9798
return {attr: getattr(t, attr) for attr in dir(t)}
9899

100+
def _platform_repos_attr(platform):
101+
return platform.replace("-", "_") + "_repos"
102+
99103
def _component_attrs_match(existing, current):
100104
for key, value in current.items():
101105
if key == "name":
@@ -109,36 +113,43 @@ def _component_attrs_match(existing, current):
109113
return False
110114
return True
111115

112-
def _redist_json_impl(module_ctx, attr, generated_components):
116+
def _component_entry_key(component_name, platform, redist_ver):
117+
return "{}|{}|{}".format(component_name, platform, redist_ver)
118+
119+
def _register_redist_components(module_ctx, attr, component_entries):
113120
url, json_object = redist_json_helper.get(module_ctx, attr)
114121
redist_ver = redist_json_helper.get_redist_version(module_ctx, attr, json_object)
115122

116-
platform_mapping = {}
117123
for platform in attr.platforms:
118124
component_specs = redist_json_helper.collect_specs(module_ctx, attr, platform, json_object, url)
119-
mapping = {}
120125
for spec in component_specs:
121126
repo_name = redist_json_helper.get_repo_name(module_ctx, spec)
122-
mapping[spec["component_name"]] = repo_name
123127

124128
component_attr = {key: value for key, value in spec.items()}
125129
component_repo_name = repo_name + "_" + platform.replace("-", "_") + "_" + redist_ver.replace(".", "_")
126130
component_attr["name"] = component_repo_name
127131

128-
dedupe_key = "{}|{}|{}".format(spec["component_name"], platform, redist_ver)
129-
existing_attr = generated_components.get(dedupe_key)
132+
dedupe_key = _component_entry_key(spec["component_name"], platform, redist_ver)
133+
existing_entry = component_entries.get(dedupe_key)
134+
existing_attr = existing_entry["component_attr"] if existing_entry else None
130135
if existing_attr == None:
131136
cuda_component(**component_attr)
132-
generated_components[dedupe_key] = component_attr
137+
component_entries[dedupe_key] = {
138+
"component_name": spec["component_name"],
139+
"platform": platform,
140+
"redist_version": redist_ver,
141+
"repo_name": repo_name,
142+
"generated_repo_name": component_repo_name,
143+
"component_attr": component_attr,
144+
}
133145
elif not _component_attrs_match(existing_attr, component_attr):
134146
fail(("Conflicting CUDA component definition for {} on {} at version {}. " +
135147
"Use distinct component versions when registries are not identical.").format(
136148
spec["component_name"],
137149
platform,
138150
redist_ver,
139151
))
140-
platform_mapping[platform] = mapping
141-
return redist_ver, platform_mapping
152+
return redist_ver
142153

143154
def _impl(module_ctx):
144155
# Toolchain configuration is only allowed in the root module, or in rules_cuda.
@@ -158,47 +169,47 @@ def _impl(module_ctx):
158169
for component in components:
159170
cuda_component(**_module_tag_to_dict(component))
160171

161-
redist_version = None
162172
components_mapping = None
163173
redist_versions = []
164-
redist_components_mapping = {}
165-
166-
# Track all versioned repositories for each component and platform.
167-
versioned_repos = {}
168-
generated_components = {}
174+
component_entries = {}
169175
for redist_json in redist_jsons:
170-
components_mapping = {}
171-
redist_version, platform_mapping = _redist_json_impl(module_ctx, redist_json, generated_components)
176+
redist_version = _register_redist_components(module_ctx, redist_json, component_entries)
172177
if redist_version not in redist_versions:
173178
redist_versions.append(redist_version)
174-
for platform in platform_mapping.keys():
175-
for component_name, repo_name in platform_mapping[platform].items():
176-
redist_components_mapping[component_name] = repo_name
177-
178-
# Track the versioned repo name for this component/platform/version.
179-
if component_name not in versioned_repos:
180-
versioned_repos[component_name] = {}
181-
if platform not in versioned_repos[component_name]:
182-
versioned_repos[component_name][platform] = {}
183-
versioned_repos[component_name][platform][redist_version] = repo_name + "_" + platform.replace("-", "_") + "_" + redist_version.replace(".", "_")
184-
185-
for component_name in redist_components_mapping.keys():
186-
# Build dictionaries mapping versions to repo names for each platform.
187-
x86_64_repos = {ver: versioned_repos[component_name]["linux-x86_64"][ver] for ver in redist_versions if "linux-x86_64" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-x86_64"]}
188-
windows_x86_64_repos = {ver: versioned_repos[component_name]["windows-x86_64"][ver] for ver in redist_versions if "windows-x86_64" in versioned_repos[component_name] and ver in versioned_repos[component_name]["windows-x86_64"]}
189-
aarch64_repos = {ver: versioned_repos[component_name]["linux-aarch64"][ver] for ver in redist_versions if "linux-aarch64" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-aarch64"]}
190-
sbsa_repos = {ver: versioned_repos[component_name]["linux-sbsa"][ver] for ver in redist_versions if "linux-sbsa" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-sbsa"]}
191-
192-
platform_alias_repo(
193-
name = redist_components_mapping[component_name],
194-
component_name = component_name,
195-
linux_x86_64_repos = x86_64_repos,
196-
windows_x86_64_repos = windows_x86_64_repos,
197-
linux_aarch64_repos = aarch64_repos,
198-
linux_sbsa_repos = sbsa_repos,
199-
versions = redist_versions,
200-
)
201-
components_mapping[component_name] = "@" + redist_components_mapping[component_name]
179+
180+
if len(component_entries) > 0:
181+
components_mapping = {}
182+
redist_components_mapping = {}
183+
versioned_repos = {}
184+
for entry in component_entries.values():
185+
component_name = entry["component_name"]
186+
platform = entry["platform"]
187+
redist_version = entry["redist_version"]
188+
189+
redist_components_mapping[component_name] = entry["repo_name"]
190+
if component_name not in versioned_repos:
191+
versioned_repos[component_name] = {}
192+
if platform not in versioned_repos[component_name]:
193+
versioned_repos[component_name][platform] = {}
194+
versioned_repos[component_name][platform][redist_version] = entry["generated_repo_name"]
195+
196+
for component_name in redist_components_mapping.keys():
197+
# Build dictionaries mapping versions to repo names for each platform.
198+
platform_repo_kwargs = {}
199+
for platform in SUPPORTED_PLATFORMS:
200+
platform_repo_kwargs[_platform_repos_attr(platform)] = {
201+
ver: versioned_repos[component_name][platform][ver]
202+
for ver in redist_versions
203+
if platform in versioned_repos[component_name] and ver in versioned_repos[component_name][platform]
204+
}
205+
206+
platform_alias_repo(
207+
name = redist_components_mapping[component_name],
208+
component_name = component_name,
209+
versions = redist_versions,
210+
**platform_repo_kwargs
211+
)
212+
components_mapping[component_name] = "@" + redist_components_mapping[component_name]
202213

203214
registrations = {}
204215
for toolkit in toolkits:

cuda/platform_alias_extension.bzl

Lines changed: 18 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,16 @@ load("//cuda/private:templates/registry.bzl", "REGISTRY")
1010
# Use REGISTRY as the source of truth for component targets
1111
TARGET_MAPPING = REGISTRY
1212

13+
def _platform_repos_attr(platform):
14+
return platform.replace("-", "_") + "_repos"
15+
16+
_PLATFORM_REPO_ATTRS = {
17+
_platform_repos_attr(_platform): attr.string_dict(
18+
default = {},
19+
doc = "Dictionary mapping versions to repository names for {}".format(_platform),
20+
)
21+
for _platform in SUPPORTED_PLATFORMS
22+
}
1323
def _platform_alias_repo_impl(ctx):
1424
"""Implementation of the platform_alias_repo repository rule.
1525
@@ -46,16 +56,13 @@ def _platform_alias_repo_impl(ctx):
4656
# Build a target for the name of the repo (only if at least one platform is available).
4757
platform_type = "exec" if ctx.attr.component_name in ["nvcc", "nvvm"] else "target"
4858

59+
platform_repos_map = {
60+
platform: getattr(ctx.attr, _platform_repos_attr(platform))
61+
for platform in SUPPORTED_PLATFORMS
62+
}
63+
4964
# Check which platforms are available (have at least one version).
50-
platforms_available = []
51-
if len(ctx.attr.linux_x86_64_repos) > 0:
52-
platforms_available.append("linux-x86_64")
53-
if len(ctx.attr.windows_x86_64_repos) > 0:
54-
platforms_available.append("windows-x86_64")
55-
if len(ctx.attr.linux_sbsa_repos) > 0:
56-
platforms_available.append("linux-sbsa")
57-
if len(ctx.attr.linux_aarch64_repos) > 0:
58-
platforms_available.append("linux-aarch64")
65+
platforms_available = [platform for platform in SUPPORTED_PLATFORMS if len(platform_repos_map[platform]) > 0]
5966

6067
# Always create unsupported_cuda_platform target - it's used as the default case
6168
# in select() when no platform condition matches.
@@ -113,13 +120,6 @@ def _platform_alias_repo_impl(ctx):
113120
# Platforms where it doesn't exist get dummy targets for all versions.
114121
# This ensures builds on any platform have matching select conditions.
115122

116-
platform_repos_map = {
117-
"linux-x86_64": ctx.attr.linux_x86_64_repos,
118-
"windows-x86_64": ctx.attr.windows_x86_64_repos,
119-
"linux-sbsa": ctx.attr.linux_sbsa_repos,
120-
"linux-aarch64": ctx.attr.linux_aarch64_repos,
121-
}
122-
123123
for platform in SUPPORTED_PLATFORMS:
124124
platform_suffix = platform.replace("-", "_")
125125
repos_dict = platform_repos_map[platform]
@@ -167,30 +167,14 @@ def _platform_alias_repo_impl(ctx):
167167

168168
platform_alias_repo = repository_rule(
169169
implementation = _platform_alias_repo_impl,
170-
attrs = {
170+
attrs = dict({
171171
"component_name": attr.string(
172172
mandatory = True,
173173
doc = "Name of the component",
174174
),
175-
"linux_x86_64_repos": attr.string_dict(
176-
default = {},
177-
doc = "Dictionary mapping versions to x86_64 repository names",
178-
),
179-
"linux_aarch64_repos": attr.string_dict(
180-
default = {},
181-
doc = "Dictionary mapping versions to ARM64/Jetpack repository names",
182-
),
183-
"windows_x86_64_repos": attr.string_dict(
184-
default = {},
185-
doc = "Dictionary mapping versions to Windows x86_64 repository names",
186-
),
187-
"linux_sbsa_repos": attr.string_dict(
188-
default = {},
189-
doc = "Dictionary mapping versions to SBSA repository names",
190-
),
191175
"versions": attr.string_list(
192176
mandatory = True,
193177
doc = "List of versions to create aliases for",
194178
),
195-
},
179+
}, **_PLATFORM_REPO_ATTRS),
196180
)

0 commit comments

Comments
 (0)