Skip to content

Commit d693d46

Browse files
committed
Allow multiple cuda versions in the same toolchain
1 parent 6fa09c3 commit d693d46

File tree

1 file changed

+37
-13
lines changed

1 file changed

+37
-13
lines changed

cuda/extensions.bzl

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ cuda_toolkit_tag = tag_class(attrs = {
7272
"nvcc_version": attr.string(
7373
doc = "nvcc version. Required for deliverable toolkit only. Fallback to version if omitted.",
7474
),
75+
"redist_json_name": attr.string(
76+
doc = "Name of the redist_json tag whose components this toolkit should use. " +
77+
"If omitted and exactly one redist_json is declared, it is used automatically.",
78+
),
7579
})
7680

7781
def _find_modules(module_ctx):
@@ -125,32 +129,52 @@ def _impl(module_ctx):
125129
for component in components:
126130
cuda_component(**_module_tag_to_dict(component))
127131

128-
if len(redist_jsons) > 1:
129-
fail("Using multiple cuda.redist_json is not supported yet.")
130-
131-
redist_version = None
132-
components_mapping = None
132+
# Process each redist_json, keyed by name.
133+
redist_results = {}
133134
for redist_json in redist_jsons:
135+
if redist_json.name in redist_results:
136+
fail("Multiple redist_json tags declared with the same name '{}'".format(redist_json.name))
134137
redist_version, components_mapping = _redist_json_impl(module_ctx, redist_json)
138+
redist_results[redist_json.name] = struct(
139+
version = redist_version,
140+
components_mapping = components_mapping,
141+
)
135142

143+
# Deduplicate toolkit registrations by name.
136144
registrations = {}
137145
for toolkit in toolkits:
138146
if toolkit.name in registrations.keys():
139147
if toolkit.toolkit_path == registrations[toolkit.name].toolkit_path:
140-
# No problem to register a matching toolkit twice
141148
continue
142-
fail("Multiple conflicting toolkits declared for name {} ({} and {}".format(toolkit.name, toolkit.toolkit_path, registrations[toolkit.name].toolkit_path))
149+
fail("Multiple conflicting toolkits declared for name {} ({} and {}".format(
150+
toolkit.name,
151+
toolkit.toolkit_path,
152+
registrations[toolkit.name].toolkit_path,
153+
))
143154
else:
144155
registrations[toolkit.name] = toolkit
145156

146-
if len(registrations) > 1:
147-
fail("multiple cuda.toolkit is not supported")
148-
157+
# Instantiate each toolkit, wiring it to the appropriate redist_json.
149158
for _, toolkit in registrations.items():
150-
if components_mapping != None:
151-
cuda_toolkit(name = toolkit.name, components_mapping = components_mapping, version = redist_version)
159+
rj_name = toolkit.redist_json_name
160+
if rj_name:
161+
if rj_name not in redist_results:
162+
fail("Toolkit '{}' references redist_json_name '{}', but no redist_json with that name was declared".format(
163+
toolkit.name,
164+
rj_name,
165+
))
166+
rj = redist_results[rj_name]
167+
cuda_toolkit(name = toolkit.name, components_mapping = rj.components_mapping, version = rj.version)
168+
elif len(redist_results) == 1:
169+
# Backward compatible: single redist_json, use it implicitly.
170+
rj = redist_results.values()[0]
171+
cuda_toolkit(name = toolkit.name, components_mapping = rj.components_mapping, version = rj.version)
172+
elif redist_results:
173+
fail("Multiple redist_json tags declared but toolkit '{}' does not specify redist_json_name".format(toolkit.name))
152174
else:
153-
cuda_toolkit(**_module_tag_to_dict(toolkit))
175+
# No redist_json at all — local toolkit path.
176+
attrs = {k: v for k, v in _module_tag_to_dict(toolkit).items() if k != "redist_json_name"}
177+
cuda_toolkit(**attrs)
154178

155179
toolchain = module_extension(
156180
implementation = _impl,

0 commit comments

Comments
 (0)