@@ -72,6 +72,10 @@ cuda_toolkit_tag = tag_class(attrs = {
7272 "nvcc_version" : attr .string (
7373 doc = "nvcc version. Required for deliverable toolkit only. Fallback to version if omitted." ,
7474 ),
75+ "redist_json_name" : attr .string (
76+ doc = "Name of the redist_json tag whose components this toolkit should use. " +
77+ "If omitted and exactly one redist_json is declared, it is used automatically." ,
78+ ),
7579})
7680
7781def _find_modules (module_ctx ):
@@ -125,32 +129,51 @@ def _impl(module_ctx):
125129 for component in components :
126130 cuda_component (** _module_tag_to_dict (component ))
127131
128- if len (redist_jsons ) > 1 :
129- fail ("Using multiple cuda.redist_json is not supported yet." )
130-
131- redist_version = None
132- components_mapping = None
132+ # Process each redist_json, keyed by name.
133+ redist_results = {}
133134 for redist_json in redist_jsons :
135+ if redist_json .name in redist_results :
136+ fail ("Multiple redist_json tags declared with the same name '{}'" .format (redist_json .name ))
134137 redist_version , components_mapping = _redist_json_impl (module_ctx , redist_json )
138+ redist_results [redist_json .name ] = struct (
139+ version = redist_version ,
140+ components_mapping = components_mapping ,
141+ )
135142
143+ # Deduplicate toolkit registrations by name.
136144 registrations = {}
137145 for toolkit in toolkits :
138146 if toolkit .name in registrations .keys ():
139147 if toolkit .toolkit_path == registrations [toolkit .name ].toolkit_path :
140- # No problem to register a matching toolkit twice
141148 continue
142- fail ("Multiple conflicting toolkits declared for name {} ({} and {}" .format (toolkit .name , toolkit .toolkit_path , registrations [toolkit .name ].toolkit_path ))
149+ fail ("Multiple conflicting toolkits declared for name {} ({} and {}" .format (
150+ toolkit .name ,
151+ toolkit .toolkit_path ,
152+ registrations [toolkit .name ].toolkit_path ,
153+ ))
143154 else :
144155 registrations [toolkit .name ] = toolkit
145156
146- if len (registrations ) > 1 :
147- fail ("multiple cuda.toolkit is not supported" )
148-
157+ # Instantiate each toolkit, wiring it to the appropriate redist_json.
149158 for _ , toolkit in registrations .items ():
150- if components_mapping != None :
151- cuda_toolkit (name = toolkit .name , components_mapping = components_mapping , version = redist_version )
159+ rj_name = toolkit .redist_json_name
160+ if rj_name :
161+ if rj_name not in redist_results :
162+ fail ("Toolkit '{}' references redist_json_name '{}', but no redist_json with that name was declared" .format (
163+ toolkit .name ,
164+ rj_name ,
165+ ))
166+ rj = redist_results [rj_name ]
167+ cuda_toolkit (name = toolkit .name , components_mapping = rj .components_mapping , version = rj .version )
168+ elif len (redist_results ) == 1 :
169+ # Backward compatible: single redist_json, use it implicitly.
170+ rj = redist_results .values ()[0 ]
171+ cuda_toolkit (name = toolkit .name , components_mapping = rj .components_mapping , version = rj .version )
172+ elif redist_results :
173+ fail ("Multiple redist_json tags declared but toolkit '{}' does not specify redist_json_name" .format (toolkit .name ))
152174 else :
153- cuda_toolkit (** _module_tag_to_dict (toolkit ))
175+ attrs = {k : v for k , v in _module_tag_to_dict (toolkit ).items () if k != "redist_json_name" }
176+ cuda_toolkit (** attrs )
154177
155178toolchain = module_extension (
156179 implementation = _impl ,
0 commit comments