Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/build-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ jobs:

# Override cuda redist_json rule version for CI
- run: echo "CUDA_REDIST_VERSION_OVERRIDE=${{ matrix.cases.cuda-version }}" >> $GITHUB_ENV
if: ${{ !startsWith(matrix.cases.os, 'windows') }}
if: ${{ !startsWith(matrix.cases.os, 'windows') && matrix.cases.source == 'nvidia' }}
- run: echo "CUDA_REDIST_VERSION_OVERRIDE=${{ matrix.cases.cuda-version }}" >> $env:GITHUB_ENV
if: ${{ startsWith(matrix.cases.os, 'windows') }}
if: ${{ startsWith(matrix.cases.os, 'windows') && matrix.cases.source == 'nvidia' }}

# Use Bazel with version specified in .bazelversion
- run: echo "USE_BAZEL_VERSION=$(cat .bazelversion)" >> $GITHUB_ENV
Expand Down
42 changes: 42 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,48 @@ use_repo(cuda, "cuda")

`rules_cc` provides the C++ toolchain dependency for `rules_cuda`; in Bzlmod, the compatibility repository is handled by `rules_cc` itself.

#### Multi-version hermetic toolchains (Bzlmod)

To select CUDA versions and platforms at build time, define multiple `cuda.redist_json` entries and a single `cuda.toolkit`.
Then use `@rules_cuda//cuda:version`, `@rules_cuda//cuda:exec_platform`, and `@rules_cuda//cuda:target_platform` flags.
If `@rules_cuda//cuda:version` is unset, rules_cuda selects the highest declared redist version.

```starlark
cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")

cuda.redist_json(
name = "cuda_13_0_2",
version = "13.0.2",
platforms = [
"linux-x86_64",
"linux-sbsa",
],
)
cuda.redist_json(
name = "cuda_13_0_0",
version = "13.0.0",
platforms = [
"linux-x86_64",
"linux-sbsa",
],
)

cuda.toolkit(name = "cuda")
use_repo(cuda, "cuda")
```

Example `.bazelrc` entries:

```
build --@rules_cuda//cuda:exec_platform=linux-x86_64
build --@rules_cuda//cuda:target_platform=linux-x86_64
# Optional: if omitted, the highest declared redist version is used.
build --@rules_cuda//cuda:version=13.0.0
```

Note: In Bzlmod, `platforms` is required for `cuda.redist_json` because module extensions don't have access to host OS/arch
information, so the platforms must be declared explicitly.

<details>
<summary>Traditional WORKSPACE approach</summary>

Expand Down
37 changes: 37 additions & 0 deletions cuda/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ load(
"bool_flag",
"string_flag",
)
load("//cuda/private:platforms.bzl", "SUPPORTED_PLATFORMS")
load("//cuda/private:rules/flags.bzl", "cuda_archs_flag", "repeatable_string_flag")

package(default_visibility = ["//visibility:public"])
Expand Down Expand Up @@ -38,6 +39,42 @@ config_setting(
flag_values = {"@cuda//:valid_toolchain_found": "True"},
)

# NOTE: Functional with platform_alias only.
string_flag(
name = "version",
build_setting_default = "",
)

# NOTE: Functional with platform_alias only.
string_flag(
name = "target_platform",
build_setting_default = "linux-x86_64",
values = SUPPORTED_PLATFORMS,
)

[
config_setting(
name = "target_platform_is_{}".format(platform.replace("-", "_")),
flag_values = {":target_platform": platform},
)
for platform in SUPPORTED_PLATFORMS
]

# NOTE: Functional with platform_alias only.
string_flag(
name = "exec_platform",
build_setting_default = "linux-x86_64",
values = SUPPORTED_PLATFORMS,
)

[
config_setting(
name = "exec_platform_is_{}".format(platform.replace("-", "_")),
flag_values = {":exec_platform": platform},
)
for platform in SUPPORTED_PLATFORMS
]

# Command line flag to specify the list of CUDA architectures to compile for.
#
# Provides CudaArchsInfo of the list of archs to build.
Expand Down
3 changes: 3 additions & 0 deletions cuda/defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Core rules for building CUDA projects.
"""

load("//cuda/private:defs.bzl", _requires_cuda = "requires_cuda")
load("//cuda/private:errors.bzl", _unsupported_cuda_platform = "unsupported_cuda_platform", _unsupported_cuda_version = "unsupported_cuda_version")
load("//cuda/private:macros/cuda_binary.bzl", _cuda_binary = "cuda_binary")
load("//cuda/private:macros/cuda_test.bzl", _cuda_test = "cuda_test")
load("//cuda/private:os_helpers.bzl", _cc_import_versioned_sos = "cc_import_versioned_sos", _if_linux = "if_linux", _if_windows = "if_windows")
Expand Down Expand Up @@ -47,3 +48,5 @@ if_windows = _if_windows
cc_import_versioned_sos = _cc_import_versioned_sos

requires_cuda = _requires_cuda
unsupported_cuda_version = _unsupported_cuda_version
unsupported_cuda_platform = _unsupported_cuda_platform
20 changes: 18 additions & 2 deletions cuda/dummy/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
load("@rules_cc//cc:defs.bzl", "cc_binary")
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library")

package(default_visibility = ["//visibility:public"])

Expand All @@ -8,13 +8,22 @@ cc_binary(
defines = ["TOOLNAME=nvcc"],
)

cc_binary(
name = "cicc",
srcs = ["dummy.cpp"],
defines = ["TOOLNAME=cicc"],
)

cc_binary(
name = "nvlink",
srcs = ["dummy.cpp"],
defines = ["TOOLNAME=nvlink"],
)

exports_files(["link.stub"])
exports_files([
"link.stub",
"libdevice.10.bc",
])

cc_binary(
name = "bin2c",
Expand All @@ -28,6 +37,13 @@ cc_binary(
defines = ["TOOLNAME=fatbinary"],
)

# Empty cc_library that provides CcInfo for components not available in this CUDA version.
cc_library(
name = "dummy",
srcs = [],
hdrs = [],
)

cc_binary(
name = "ptxas",
srcs = ["dummy.cpp"],
Expand Down
1 change: 1 addition & 0 deletions cuda/dummy/libdevice.10.bc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#error libdevice.10.bc of cuda toolkit does not exist
143 changes: 122 additions & 21 deletions cuda/extensions.bzl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Entry point for extensions used by bzlmod."""

load("//cuda:platform_alias_extension.bzl", "platform_alias_repo")
load("//cuda/private:platforms.bzl", "SUPPORTED_PLATFORMS")
load("//cuda/private:redist_json_helper.bzl", "redist_json_helper")
load("//cuda/private:repositories.bzl", "cuda_component", "cuda_toolkit")

Expand Down Expand Up @@ -53,6 +55,9 @@ cuda_redist_json_tag = tag_class(attrs = {
"URLs are tried in order until one succeeds, so you should list local mirrors first. " +
"If all downloads fail, the rule will fail.",
),
"platforms": attr.string_list(
doc = "A list of platforms to generate components for.",
),
"version": attr.string(
doc = "Generate a URL by using the specified version." +
"This URL will be tried after all URLs specified in the `urls` attribute.",
Expand Down Expand Up @@ -92,20 +97,63 @@ def _find_modules(module_ctx):
def _module_tag_to_dict(t):
return {attr: getattr(t, attr) for attr in dir(t)}

def _redist_json_impl(module_ctx, attr):
def _platform_repos_attr(platform):
return platform.replace("-", "_") + "_repos"

def _version_sort_key(version):
prefix = version.split("-", 1)[0]
parts = prefix.split(".")
if all([p.isdigit() for p in parts]):
return (1, [int(p) for p in parts], version)
return (0, [], version)

def _component_attrs_match(existing, current):
for key, value in current.items():
if key == "name":
continue
if key not in existing or existing[key] != value:
return False
for key in existing.keys():
if key == "name":
continue
if key not in current:
return False
return True

def _register_redist_components(module_ctx, attr, component_entries):
url, json_object = redist_json_helper.get(module_ctx, attr)
redist_ver = redist_json_helper.get_redist_version(module_ctx, attr, json_object)
component_specs = redist_json_helper.collect_specs(module_ctx, attr, json_object, url)

mapping = {}
for spec in component_specs:
repo_name = redist_json_helper.get_repo_name(module_ctx, spec)
mapping[spec["component_name"]] = "@" + repo_name
for platform in attr.platforms:
component_specs = redist_json_helper.collect_specs(module_ctx, attr, platform, json_object, url)
for spec in component_specs:
repo_name = redist_json_helper.get_repo_name(module_ctx, spec)

attr = {key: value for key, value in spec.items()}
attr["name"] = repo_name
cuda_component(**attr)
return redist_ver, mapping
component_attr = {key: value for key, value in spec.items()}
component_repo_name = repo_name + "_" + platform.replace("-", "_") + "_" + redist_ver.replace(".", "_")
component_attr["name"] = component_repo_name

dedupe_key = (spec["component_name"], platform, redist_ver)
existing_entry = component_entries.get(dedupe_key)
existing_attr = existing_entry["component_attr"] if existing_entry else None
if existing_attr == None:
cuda_component(**component_attr)
component_entries[dedupe_key] = {
"component_name": spec["component_name"],
"platform": platform,
"redist_version": redist_ver,
"repo_name": repo_name,
"generated_repo_name": component_repo_name,
"component_attr": component_attr,
}
elif not _component_attrs_match(existing_attr, component_attr):
fail(("Conflicting CUDA component definition for {} on {} at version {}. " +
"Use distinct component versions when registries are not identical.").format(
spec["component_name"],
platform,
redist_ver,
))
return redist_ver

def _impl(module_ctx):
# Toolchain configuration is only allowed in the root module, or in rules_cuda.
Expand All @@ -125,30 +173,83 @@ def _impl(module_ctx):
for component in components:
cuda_component(**_module_tag_to_dict(component))

if len(redist_jsons) > 1:
fail("Using multiple cuda.redist_json is not supported yet.")

redist_version = None
components_mapping = None
redist_versions = []
component_entries = {}
for redist_json in redist_jsons:
redist_version, components_mapping = _redist_json_impl(module_ctx, redist_json)
redist_version = _register_redist_components(module_ctx, redist_json, component_entries)
if redist_version not in redist_versions:
redist_versions.append(redist_version)

if len(component_entries) > 0:
components_mapping = {}
redist_components_mapping = {}
versioned_repos = {}
for entry in component_entries.values():
component_name = entry["component_name"]
platform = entry["platform"]
redist_version = entry["redist_version"]

redist_components_mapping[component_name] = entry["repo_name"]
if component_name not in versioned_repos:
versioned_repos[component_name] = {}
if platform not in versioned_repos[component_name]:
versioned_repos[component_name][platform] = {}
versioned_repos[component_name][platform][redist_version] = entry["generated_repo_name"]

for component_name in redist_components_mapping.keys():
component_platforms = [
platform
for platform in SUPPORTED_PLATFORMS
if platform in versioned_repos[component_name] and len(versioned_repos[component_name][platform]) > 0
]

# Preserve pre-multi-version behavior for the simple case:
# if there is exactly one concrete repo, wire toolkit mapping directly.
if len(redist_versions) == 1 and len(component_platforms) == 1:
only_platform = component_platforms[0]
only_version = redist_versions[0]
only_repo = versioned_repos[component_name][only_platform].get(only_version)
if only_repo:
components_mapping[component_name] = "@" + only_repo
continue

# Build dictionaries mapping versions to repo names for each platform.
platform_repo_kwargs = {}
for platform in SUPPORTED_PLATFORMS:
platform_repo_kwargs[_platform_repos_attr(platform)] = {
ver: versioned_repos[component_name][platform][ver]
for ver in redist_versions
if platform in versioned_repos[component_name] and ver in versioned_repos[component_name][platform]
}

platform_alias_repo(
name = redist_components_mapping[component_name],
component_name = component_name,
versions = redist_versions,
**platform_repo_kwargs
)
components_mapping[component_name] = "@" + redist_components_mapping[component_name]

registrations = {}
for toolkit in toolkits:
if toolkit.name in registrations.keys():
if toolkit.toolkit_path == registrations[toolkit.name].toolkit_path:
# No problem to register a matching toolkit twice
continue
fail("Multiple conflicting toolkits declared for name {} ({} and {}".format(toolkit.name, toolkit.toolkit_path, registrations[toolkit.name].toolkit_path))
fail("Multiple conflicting toolkits declared for name {} ({} and {}".format(
toolkit.name,
toolkit.toolkit_path,
registrations[toolkit.name].toolkit_path,
))
else:
registrations[toolkit.name] = toolkit

if len(registrations) > 1:
fail("multiple cuda.toolkit is not supported")

for _, toolkit in registrations.items():
if components_mapping != None:
cuda_toolkit(name = toolkit.name, components_mapping = components_mapping, version = redist_version)
# Always use the maximum version so the toolkit includes all components.
# Components that don't exist in older versions will fall back to dummy.
toolkit_version = sorted(redist_versions, key = _version_sort_key)[-1]
cuda_toolkit(name = toolkit.name, components_mapping = components_mapping, version = toolkit_version)
else:
cuda_toolkit(**_module_tag_to_dict(toolkit))

Expand Down
Loading
Loading