Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/build-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ jobs:

# Override cuda redist_json rule version for CI
- run: echo "CUDA_REDIST_VERSION_OVERRIDE=${{ matrix.cases.cuda-version }}" >> $GITHUB_ENV
if: ${{ !startsWith(matrix.cases.os, 'windows') }}
if: ${{ !startsWith(matrix.cases.os, 'windows') && matrix.cases.source == 'nvidia' }}
- run: echo "CUDA_REDIST_VERSION_OVERRIDE=${{ matrix.cases.cuda-version }}" >> $env:GITHUB_ENV
if: ${{ startsWith(matrix.cases.os, 'windows') }}
if: ${{ startsWith(matrix.cases.os, 'windows') && matrix.cases.source == 'nvidia' }}

# Use Bazel with version specified in .bazelversion
- run: echo "USE_BAZEL_VERSION=$(cat .bazelversion)" >> $GITHUB_ENV
Expand Down
40 changes: 40 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,46 @@ use_repo(cuda, "cuda")

`rules_cc` provides the C++ toolchain dependency for `rules_cuda`; in Bzlmod, the compatibility repository is handled by `rules_cc` itself.

#### Multi-version hermetic toolchains (Bzlmod)

To select CUDA versions and platforms at build time, define multiple `cuda.redist_json` entries and a single `cuda.toolkit`.
Then use `@rules_cuda//cuda:version`, `@rules_cuda//cuda:exec_platform`, and `@rules_cuda//cuda:target_platform` flags.

```starlark
cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")

cuda.redist_json(
name = "cuda_13_0_2",
version = "13.0.2",
platforms = [
"linux-x86_64",
"linux-sbsa",
],
)
cuda.redist_json(
name = "cuda_13_0_0",
version = "13.0.0",
platforms = [
"linux-x86_64",
"linux-sbsa",
],
)

cuda.toolkit(name = "cuda")
use_repo(cuda, "cuda")
```

Example `.bazelrc` entries:

```
build --@rules_cuda//cuda:exec_platform=linux-x86_64
build --@rules_cuda//cuda:target_platform=linux-x86_64
build --@rules_cuda//cuda:version=13.0.0
```

Note: In Bzlmod, `platforms` is required for `cuda.redist_json` because module extensions don't have access to host OS/arch
information, so the platforms must be declared explicitly.

<details>
<summary>Traditional WORKSPACE approach</summary>

Expand Down
34 changes: 34 additions & 0 deletions cuda/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ load(
"bool_flag",
"string_flag",
)
load("//cuda/private:platforms.bzl", "SUPPORTED_PLATFORMS")
load("//cuda/private:rules/flags.bzl", "cuda_archs_flag", "repeatable_string_flag")

package(default_visibility = ["//visibility:public"])
Expand Down Expand Up @@ -38,6 +39,39 @@ config_setting(
flag_values = {"@cuda//:valid_toolchain_found": "True"},
)

string_flag(
name = "version",
build_setting_default = "13.0.0",
)

string_flag(
name = "target_platform",
build_setting_default = "linux-x86_64",
values = SUPPORTED_PLATFORMS,
)

[
config_setting(
name = "target_platform_is_{}".format(platform.replace("-", "_")),
flag_values = {":target_platform": platform},
)
for platform in SUPPORTED_PLATFORMS
]

string_flag(
name = "exec_platform",
build_setting_default = "linux-x86_64",
values = SUPPORTED_PLATFORMS,
)

[
config_setting(
name = "exec_platform_is_{}".format(platform.replace("-", "_")),
flag_values = {":exec_platform": platform},
)
for platform in SUPPORTED_PLATFORMS
]

# Command line flag to specify the list of CUDA architectures to compile for.
#
# Provides CudaArchsInfo of the list of archs to build.
Expand Down
3 changes: 3 additions & 0 deletions cuda/defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Core rules for building CUDA projects.
"""

load("//cuda/private:defs.bzl", _requires_cuda = "requires_cuda")
load("//cuda/private:errors.bzl", _unsupported_cuda_platform = "unsupported_cuda_platform", _unsupported_cuda_version = "unsupported_cuda_version")
load("//cuda/private:macros/cuda_binary.bzl", _cuda_binary = "cuda_binary")
load("//cuda/private:macros/cuda_test.bzl", _cuda_test = "cuda_test")
load("//cuda/private:os_helpers.bzl", _cc_import_versioned_sos = "cc_import_versioned_sos", _if_linux = "if_linux", _if_windows = "if_windows")
Expand Down Expand Up @@ -47,3 +48,5 @@ if_windows = _if_windows
cc_import_versioned_sos = _cc_import_versioned_sos

requires_cuda = _requires_cuda
unsupported_cuda_version = _unsupported_cuda_version
unsupported_cuda_platform = _unsupported_cuda_platform
20 changes: 18 additions & 2 deletions cuda/dummy/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
load("@rules_cc//cc:defs.bzl", "cc_binary")
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library")

package(default_visibility = ["//visibility:public"])

Expand All @@ -8,13 +8,22 @@ cc_binary(
defines = ["TOOLNAME=nvcc"],
)

cc_binary(
name = "cicc",
srcs = ["dummy.cpp"],
defines = ["TOOLNAME=cicc"],
)

cc_binary(
name = "nvlink",
srcs = ["dummy.cpp"],
defines = ["TOOLNAME=nvlink"],
)

exports_files(["link.stub"])
exports_files([
"link.stub",
"libdevice.10.bc",
])

cc_binary(
name = "bin2c",
Expand All @@ -28,6 +37,13 @@ cc_binary(
defines = ["TOOLNAME=fatbinary"],
)

# Empty cc_library that provides CcInfo for components not available in this CUDA version.
cc_library(
name = "dummy",
srcs = [],
hdrs = [],
)

cc_binary(
name = "ptxas",
srcs = ["dummy.cpp"],
Expand Down
1 change: 1 addition & 0 deletions cuda/dummy/libdevice.10.bc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#error libdevice.10.bc of cuda toolkit does not exist
118 changes: 97 additions & 21 deletions cuda/extensions.bzl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Entry point for extensions used by bzlmod."""

load("//cuda:platform_alias_extension.bzl", "platform_alias_repo")
load("//cuda/private:redist_json_helper.bzl", "redist_json_helper")
load("//cuda/private:repositories.bzl", "cuda_component", "cuda_toolkit")

Expand Down Expand Up @@ -53,6 +54,9 @@ cuda_redist_json_tag = tag_class(attrs = {
"URLs are tried in order until one succeeds, so you should list local mirrors first. " +
"If all downloads fail, the rule will fail.",
),
"platforms": attr.string_list(
doc = "A list of platforms to generate components for.",
),
"version": attr.string(
doc = "Generate a URL by using the specified version." +
"This URL will be tried after all URLs specified in the `urls` attribute.",
Expand Down Expand Up @@ -92,20 +96,49 @@ def _find_modules(module_ctx):
def _module_tag_to_dict(t):
return {attr: getattr(t, attr) for attr in dir(t)}

def _redist_json_impl(module_ctx, attr):
def _component_attrs_match(existing, current):
for key, value in current.items():
if key == "name":
continue
if key not in existing or existing[key] != value:
return False
for key in existing.keys():
if key == "name":
continue
if key not in current:
return False
return True

def _redist_json_impl(module_ctx, attr, generated_components):
url, json_object = redist_json_helper.get(module_ctx, attr)
redist_ver = redist_json_helper.get_redist_version(module_ctx, attr, json_object)
component_specs = redist_json_helper.collect_specs(module_ctx, attr, json_object, url)

mapping = {}
for spec in component_specs:
repo_name = redist_json_helper.get_repo_name(module_ctx, spec)
mapping[spec["component_name"]] = "@" + repo_name

attr = {key: value for key, value in spec.items()}
attr["name"] = repo_name
cuda_component(**attr)
return redist_ver, mapping
platform_mapping = {}
for platform in attr.platforms:
component_specs = redist_json_helper.collect_specs(module_ctx, attr, platform, json_object, url)
mapping = {}
for spec in component_specs:
repo_name = redist_json_helper.get_repo_name(module_ctx, spec)
mapping[spec["component_name"]] = repo_name

component_attr = {key: value for key, value in spec.items()}
component_repo_name = repo_name + "_" + platform.replace("-", "_") + "_" + redist_ver.replace(".", "_")
component_attr["name"] = component_repo_name

dedupe_key = "{}|{}|{}".format(spec["component_name"], platform, redist_ver)
existing_attr = generated_components.get(dedupe_key)
if existing_attr == None:
cuda_component(**component_attr)
generated_components[dedupe_key] = component_attr
elif not _component_attrs_match(existing_attr, component_attr):
fail(("Conflicting CUDA component definition for {} on {} at version {}. " +
"Use distinct component versions when registries are not identical.").format(
spec["component_name"],
platform,
redist_ver,
))
platform_mapping[platform] = mapping
return redist_ver, platform_mapping

def _impl(module_ctx):
# Toolchain configuration is only allowed in the root module, or in rules_cuda.
Expand All @@ -125,30 +158,73 @@ def _impl(module_ctx):
for component in components:
cuda_component(**_module_tag_to_dict(component))

if len(redist_jsons) > 1:
fail("Using multiple cuda.redist_json is not supported yet.")

redist_version = None
components_mapping = None
redist_versions = []
redist_components_mapping = {}

# Track all versioned repositories for each component and platform.
versioned_repos = {}
generated_components = {}
for redist_json in redist_jsons:
redist_version, components_mapping = _redist_json_impl(module_ctx, redist_json)
components_mapping = {}
redist_version, platform_mapping = _redist_json_impl(module_ctx, redist_json, generated_components)
if redist_version not in redist_versions:
redist_versions.append(redist_version)
for platform in platform_mapping.keys():
for component_name, repo_name in platform_mapping[platform].items():
redist_components_mapping[component_name] = repo_name

# Track the versioned repo name for this component/platform/version.
if component_name not in versioned_repos:
versioned_repos[component_name] = {}
if platform not in versioned_repos[component_name]:
versioned_repos[component_name][platform] = {}
versioned_repos[component_name][platform][redist_version] = repo_name + "_" + platform.replace("-", "_") + "_" + redist_version.replace(".", "_")

for component_name in redist_components_mapping.keys():
# Build dictionaries mapping versions to repo names for each platform.
x86_64_repos = {ver: versioned_repos[component_name]["linux-x86_64"][ver] for ver in redist_versions if "linux-x86_64" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-x86_64"]}
windows_x86_64_repos = {ver: versioned_repos[component_name]["windows-x86_64"][ver] for ver in redist_versions if "windows-x86_64" in versioned_repos[component_name] and ver in versioned_repos[component_name]["windows-x86_64"]}
aarch64_repos = {ver: versioned_repos[component_name]["linux-aarch64"][ver] for ver in redist_versions if "linux-aarch64" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-aarch64"]}
sbsa_repos = {ver: versioned_repos[component_name]["linux-sbsa"][ver] for ver in redist_versions if "linux-sbsa" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-sbsa"]}

platform_alias_repo(
name = redist_components_mapping[component_name],
component_name = component_name,
linux_x86_64_repos = x86_64_repos,
windows_x86_64_repos = windows_x86_64_repos,
linux_aarch64_repos = aarch64_repos,
linux_sbsa_repos = sbsa_repos,
versions = redist_versions,
)
components_mapping[component_name] = "@" + redist_components_mapping[component_name]

registrations = {}
for toolkit in toolkits:
if toolkit.name in registrations.keys():
if toolkit.toolkit_path == registrations[toolkit.name].toolkit_path:
# No problem to register a matching toolkit twice
continue
fail("Multiple conflicting toolkits declared for name {} ({} and {}".format(toolkit.name, toolkit.toolkit_path, registrations[toolkit.name].toolkit_path))
fail("Multiple conflicting toolkits declared for name {} ({} and {}".format(
toolkit.name,
toolkit.toolkit_path,
registrations[toolkit.name].toolkit_path,
))
else:
registrations[toolkit.name] = toolkit

if len(registrations) > 1:
fail("multiple cuda.toolkit is not supported")

for _, toolkit in registrations.items():
if components_mapping != None:
cuda_toolkit(name = toolkit.name, components_mapping = components_mapping, version = redist_version)
# Always use the maximum version so the toolkit includes all components.
# Components that don't exist in older versions will fall back to dummy.
toolkit_version = redist_versions[0]
for ver in redist_versions:
ver_parts = [int(x) for x in ver.split(".")]
tv_parts = [int(x) for x in toolkit_version.split(".")]
if ver_parts > tv_parts:
toolkit_version = ver

cuda_toolkit(name = toolkit.name, components_mapping = components_mapping, version = toolkit_version)
else:
cuda_toolkit(**_module_tag_to_dict(toolkit))

Expand Down
Loading
Loading