11"""Entry point for extensions used by bzlmod."""
22
3+ load ("//cuda:platform_alias_extension.bzl" , "platform_alias_repo" )
34load ("//cuda/private:redist_json_helper.bzl" , "redist_json_helper" )
45load ("//cuda/private:repositories.bzl" , "cuda_component" , "cuda_toolkit" )
56
@@ -53,6 +54,9 @@ cuda_redist_json_tag = tag_class(attrs = {
5354 "URLs are tried in order until one succeeds, so you should list local mirrors first. " +
5455 "If all downloads fail, the rule will fail." ,
5556 ),
57+ "platforms" : attr .string_list (
58+ doc = "A list of platforms to generate components for." ,
59+ ),
5660 "version" : attr .string (
5761 doc = "Generate a URL by using the specified version." +
5862 "This URL will be tried after all URLs specified in the `urls` attribute." ,
@@ -72,6 +76,10 @@ cuda_toolkit_tag = tag_class(attrs = {
7276 "nvcc_version" : attr .string (
7377 doc = "nvcc version. Required for deliverable toolkit only. Fallback to version if omitted." ,
7478 ),
79+ "redist_json_name" : attr .string (
80+ doc = "Name of the redist_json tag whose components this toolkit should use. " +
81+ "If omitted and exactly one redist_json is declared, it is used automatically." ,
82+ ),
7583})
7684
7785def _find_modules (module_ctx ):
@@ -95,17 +103,21 @@ def _module_tag_to_dict(t):
95103def _redist_json_impl (module_ctx , attr ):
96104 url , json_object = redist_json_helper .get (module_ctx , attr )
97105 redist_ver = redist_json_helper .get_redist_version (module_ctx , attr , json_object )
98- component_specs = redist_json_helper .collect_specs (module_ctx , attr , json_object , url )
99-
100- mapping = {}
101- for spec in component_specs :
102- repo_name = redist_json_helper .get_repo_name (module_ctx , spec )
103- mapping [spec ["component_name" ]] = "@" + repo_name
104106
105- attr = {key : value for key , value in spec .items ()}
106- attr ["name" ] = repo_name
107- cuda_component (** attr )
108- return redist_ver , mapping
107+ platform_mapping = {}
108+ for platform in attr .platforms :
109+ component_specs = redist_json_helper .collect_specs (module_ctx , attr , platform , json_object , url )
110+ mapping = {}
111+ for spec in component_specs :
112+ repo_name = redist_json_helper .get_repo_name (module_ctx , spec )
113+ mapping [spec ["component_name" ]] = repo_name
114+
115+ component_attr = {key : value for key , value in spec .items ()}
116+ component_repo_name = repo_name + "_" + attr .name + "_" + platform .replace ("-" , "_" ) + "_" + redist_ver .replace ("." , "_" )
117+ component_attr ["name" ] = component_repo_name
118+ cuda_component (** component_attr )
119+ platform_mapping [platform ] = mapping
120+ return redist_ver , platform_mapping
109121
110122def _impl (module_ctx ):
111123 # Toolchain configuration is only allowed in the root module, or in rules_cuda.
@@ -125,32 +137,74 @@ def _impl(module_ctx):
125137 for component in components :
126138 cuda_component (** _module_tag_to_dict (component ))
127139
128- if len (redist_jsons ) > 1 :
129- fail ("Using multiple cuda.redist_json is not supported yet." )
130-
131140 redist_version = None
132141 components_mapping = None
142+ redist_versions = []
143+ redist_components_mapping = {}
144+
145+ # Track all versioned repositories for each component and platform.
146+ versioned_repos = {}
133147 for redist_json in redist_jsons :
134- redist_version , components_mapping = _redist_json_impl (module_ctx , redist_json )
148+ components_mapping = {}
149+ redist_version , platform_mapping = _redist_json_impl (module_ctx , redist_json )
150+ if redist_version not in redist_versions :
151+ redist_versions .append (redist_version )
152+ for platform in platform_mapping .keys ():
153+ for component_name , repo_name in platform_mapping [platform ].items ():
154+ redist_components_mapping [component_name ] = repo_name
155+
156+ # Track the versioned repo name for this component/platform/version.
157+ if component_name not in versioned_repos :
158+ versioned_repos [component_name ] = {}
159+ if platform not in versioned_repos [component_name ]:
160+ versioned_repos [component_name ][platform ] = {}
161+ versioned_repos [component_name ][platform ][redist_version ] = repo_name + "_" + redist_json .name + "_" + platform .replace ("-" , "_" ) + "_" + redist_version .replace ("." , "_" )
162+
163+ for component_name in redist_components_mapping .keys ():
164+ # Build dictionaries mapping versions to repo names for each platform.
165+ x86_64_repos = {ver : versioned_repos [component_name ]["linux-x86_64" ][ver ] for ver in redist_versions if "linux-x86_64" in versioned_repos [component_name ] and ver in versioned_repos [component_name ]["linux-x86_64" ]}
166+ aarch64_repos = {ver : versioned_repos [component_name ]["linux-aarch64" ][ver ] for ver in redist_versions if "linux-aarch64" in versioned_repos [component_name ] and ver in versioned_repos [component_name ]["linux-aarch64" ]}
167+ sbsa_repos = {ver : versioned_repos [component_name ]["linux-sbsa" ][ver ] for ver in redist_versions if "linux-sbsa" in versioned_repos [component_name ] and ver in versioned_repos [component_name ]["linux-sbsa" ]}
168+
169+ platform_alias_repo (
170+ name = redist_components_mapping [component_name ],
171+ repo_name = redist_components_mapping [component_name ],
172+ component_name = component_name ,
173+ linux_x86_64_repos = x86_64_repos ,
174+ linux_aarch64_repos = aarch64_repos ,
175+ linux_sbsa_repos = sbsa_repos ,
176+ versions = redist_versions ,
177+ )
178+ components_mapping [component_name ] = "@" + redist_components_mapping [component_name ]
135179
136180 registrations = {}
137181 for toolkit in toolkits :
138182 if toolkit .name in registrations .keys ():
139183 if toolkit .toolkit_path == registrations [toolkit .name ].toolkit_path :
140- # No problem to register a matching toolkit twice
141184 continue
142- fail ("Multiple conflicting toolkits declared for name {} ({} and {}" .format (toolkit .name , toolkit .toolkit_path , registrations [toolkit .name ].toolkit_path ))
185+ fail ("Multiple conflicting toolkits declared for name {} ({} and {}" .format (
186+ toolkit .name ,
187+ toolkit .toolkit_path ,
188+ registrations [toolkit .name ].toolkit_path ,
189+ ))
143190 else :
144191 registrations [toolkit .name ] = toolkit
145192
146- if len (registrations ) > 1 :
147- fail ("multiple cuda.toolkit is not supported" )
148-
149193 for _ , toolkit in registrations .items ():
150194 if components_mapping != None :
151- cuda_toolkit (name = toolkit .name , components_mapping = components_mapping , version = redist_version )
195+ # Always use the maximum version so the toolkit includes all components.
196+ # Components that don't exist in older versions will fall back to dummy.
197+ toolkit_version = redist_versions [0 ]
198+ for ver in redist_versions :
199+ ver_parts = [int (x ) for x in ver .split ("." )]
200+ tv_parts = [int (x ) for x in toolkit_version .split ("." )]
201+ if ver_parts > tv_parts :
202+ toolkit_version = ver
203+
204+ cuda_toolkit (name = toolkit .name , components_mapping = components_mapping , version = toolkit_version )
152205 else :
153- cuda_toolkit (** _module_tag_to_dict (toolkit ))
206+ attrs = {k : v for k , v in _module_tag_to_dict (toolkit ).items () if k != "redist_json_name" }
207+ cuda_toolkit (** attrs )
154208
155209toolchain = module_extension (
156210 implementation = _impl ,
0 commit comments