11"""Entry point for extensions used by bzlmod."""
22
33load ("//cuda:platform_alias_extension.bzl" , "platform_alias_repo" )
4+ load ("//cuda/private:platforms.bzl" , "SUPPORTED_PLATFORMS" )
45load ("//cuda/private:redist_json_helper.bzl" , "redist_json_helper" )
56load ("//cuda/private:repositories.bzl" , "cuda_component" , "cuda_toolkit" )
67
@@ -96,6 +97,16 @@ def _find_modules(module_ctx):
9697def _module_tag_to_dict (t ):
9798 return {attr : getattr (t , attr ) for attr in dir (t )}
9899
100+ def _platform_repos_attr (platform ):
101+ return platform .replace ("-" , "_" ) + "_repos"
102+
103+ def _version_sort_key (version ):
104+ prefix = version .split ("-" , 1 )[0 ]
105+ parts = prefix .split ("." )
106+ if all ([p .isdigit () for p in parts ]):
107+ return (1 , [int (p ) for p in parts ], version )
108+ return (0 , [], version )
109+
99110def _component_attrs_match (existing , current ):
100111 for key , value in current .items ():
101112 if key == "name" :
@@ -109,36 +120,43 @@ def _component_attrs_match(existing, current):
109120 return False
110121 return True
111122
112- def _redist_json_impl (module_ctx , attr , generated_components ):
123+ def _component_entry_key (component_name , platform , redist_ver ):
124+ return "{}|{}|{}" .format (component_name , platform , redist_ver )
125+
126+ def _register_redist_components (module_ctx , attr , component_entries ):
113127 url , json_object = redist_json_helper .get (module_ctx , attr )
114128 redist_ver = redist_json_helper .get_redist_version (module_ctx , attr , json_object )
115129
116- platform_mapping = {}
117130 for platform in attr .platforms :
118131 component_specs = redist_json_helper .collect_specs (module_ctx , attr , platform , json_object , url )
119- mapping = {}
120132 for spec in component_specs :
121133 repo_name = redist_json_helper .get_repo_name (module_ctx , spec )
122- mapping [spec ["component_name" ]] = repo_name
123134
124135 component_attr = {key : value for key , value in spec .items ()}
125136 component_repo_name = repo_name + "_" + platform .replace ("-" , "_" ) + "_" + redist_ver .replace ("." , "_" )
126137 component_attr ["name" ] = component_repo_name
127138
128- dedupe_key = "{}|{}|{}" .format (spec ["component_name" ], platform , redist_ver )
129- existing_attr = generated_components .get (dedupe_key )
139+ dedupe_key = _component_entry_key (spec ["component_name" ], platform , redist_ver )
140+ existing_entry = component_entries .get (dedupe_key )
141+ existing_attr = existing_entry ["component_attr" ] if existing_entry else None
130142 if existing_attr == None :
131143 cuda_component (** component_attr )
132- generated_components [dedupe_key ] = component_attr
144+ component_entries [dedupe_key ] = {
145+ "component_name" : spec ["component_name" ],
146+ "platform" : platform ,
147+ "redist_version" : redist_ver ,
148+ "repo_name" : repo_name ,
149+ "generated_repo_name" : component_repo_name ,
150+ "component_attr" : component_attr ,
151+ }
133152 elif not _component_attrs_match (existing_attr , component_attr ):
134153 fail (("Conflicting CUDA component definition for {} on {} at version {}. " +
135154 "Use distinct component versions when registries are not identical." ).format (
136155 spec ["component_name" ],
137156 platform ,
138157 redist_ver ,
139158 ))
140- platform_mapping [platform ] = mapping
141- return redist_ver , platform_mapping
159+ return redist_ver
142160
143161def _impl (module_ctx ):
144162 # Toolchain configuration is only allowed in the root module, or in rules_cuda.
@@ -158,47 +176,62 @@ def _impl(module_ctx):
158176 for component in components :
159177 cuda_component (** _module_tag_to_dict (component ))
160178
161- redist_version = None
162179 components_mapping = None
163180 redist_versions = []
164- redist_components_mapping = {}
165-
166- # Track all versioned repositories for each component and platform.
167- versioned_repos = {}
168- generated_components = {}
181+ component_entries = {}
169182 for redist_json in redist_jsons :
170- components_mapping = {}
171- redist_version , platform_mapping = _redist_json_impl (module_ctx , redist_json , generated_components )
183+ redist_version = _register_redist_components (module_ctx , redist_json , component_entries )
172184 if redist_version not in redist_versions :
173185 redist_versions .append (redist_version )
174- for platform in platform_mapping .keys ():
175- for component_name , repo_name in platform_mapping [platform ].items ():
176- redist_components_mapping [component_name ] = repo_name
177-
178- # Track the versioned repo name for this component/platform/version.
179- if component_name not in versioned_repos :
180- versioned_repos [component_name ] = {}
181- if platform not in versioned_repos [component_name ]:
182- versioned_repos [component_name ][platform ] = {}
183- versioned_repos [component_name ][platform ][redist_version ] = repo_name + "_" + platform .replace ("-" , "_" ) + "_" + redist_version .replace ("." , "_" )
184-
185- for component_name in redist_components_mapping .keys ():
186- # Build dictionaries mapping versions to repo names for each platform.
187- x86_64_repos = {ver : versioned_repos [component_name ]["linux-x86_64" ][ver ] for ver in redist_versions if "linux-x86_64" in versioned_repos [component_name ] and ver in versioned_repos [component_name ]["linux-x86_64" ]}
188- windows_x86_64_repos = {ver : versioned_repos [component_name ]["windows-x86_64" ][ver ] for ver in redist_versions if "windows-x86_64" in versioned_repos [component_name ] and ver in versioned_repos [component_name ]["windows-x86_64" ]}
189- aarch64_repos = {ver : versioned_repos [component_name ]["linux-aarch64" ][ver ] for ver in redist_versions if "linux-aarch64" in versioned_repos [component_name ] and ver in versioned_repos [component_name ]["linux-aarch64" ]}
190- sbsa_repos = {ver : versioned_repos [component_name ]["linux-sbsa" ][ver ] for ver in redist_versions if "linux-sbsa" in versioned_repos [component_name ] and ver in versioned_repos [component_name ]["linux-sbsa" ]}
191-
192- platform_alias_repo (
193- name = redist_components_mapping [component_name ],
194- component_name = component_name ,
195- linux_x86_64_repos = x86_64_repos ,
196- windows_x86_64_repos = windows_x86_64_repos ,
197- linux_aarch64_repos = aarch64_repos ,
198- linux_sbsa_repos = sbsa_repos ,
199- versions = redist_versions ,
200- )
201- components_mapping [component_name ] = "@" + redist_components_mapping [component_name ]
186+
187+ if len (component_entries ) > 0 :
188+ components_mapping = {}
189+ redist_components_mapping = {}
190+ versioned_repos = {}
191+ for entry in component_entries .values ():
192+ component_name = entry ["component_name" ]
193+ platform = entry ["platform" ]
194+ redist_version = entry ["redist_version" ]
195+
196+ redist_components_mapping [component_name ] = entry ["repo_name" ]
197+ if component_name not in versioned_repos :
198+ versioned_repos [component_name ] = {}
199+ if platform not in versioned_repos [component_name ]:
200+ versioned_repos [component_name ][platform ] = {}
201+ versioned_repos [component_name ][platform ][redist_version ] = entry ["generated_repo_name" ]
202+
203+ for component_name in redist_components_mapping .keys ():
204+ component_platforms = [
205+ platform
206+ for platform in SUPPORTED_PLATFORMS
207+ if platform in versioned_repos [component_name ] and len (versioned_repos [component_name ][platform ]) > 0
208+ ]
209+ # Preserve pre-multi-version behavior for the simple case:
210+ # if there is exactly one concrete repo, wire toolkit mapping directly.
211+ if len (redist_versions ) == 1 and len (component_platforms ) == 1 :
212+ only_platform = component_platforms [0 ]
213+ only_version = redist_versions [0 ]
214+ only_repo = versioned_repos [component_name ][only_platform ].get (only_version )
215+ if only_repo :
216+ components_mapping [component_name ] = "@" + only_repo
217+ continue
218+
219+ # Build dictionaries mapping versions to repo names for each platform.
220+ platform_repo_kwargs = {}
221+ for platform in SUPPORTED_PLATFORMS :
222+ platform_repo_kwargs [_platform_repos_attr (platform )] = {
223+ ver : versioned_repos [component_name ][platform ][ver ]
224+ for ver in redist_versions
225+ if platform in versioned_repos [component_name ] and ver in versioned_repos [component_name ][platform ]
226+ }
227+
228+ platform_alias_repo (
229+ name = redist_components_mapping [component_name ],
230+ component_name = component_name ,
231+ versions = redist_versions ,
232+ ** platform_repo_kwargs
233+ )
234+ components_mapping [component_name ] = "@" + redist_components_mapping [component_name ]
202235
203236 registrations = {}
204237 for toolkit in toolkits :
@@ -217,13 +250,7 @@ def _impl(module_ctx):
217250 if components_mapping != None :
218251 # Always use the maximum version so the toolkit includes all components.
219252 # Components that don't exist in older versions will fall back to dummy.
220- toolkit_version = redist_versions [0 ]
221- for ver in redist_versions :
222- ver_parts = [int (x ) for x in ver .split ("." )]
223- tv_parts = [int (x ) for x in toolkit_version .split ("." )]
224- if ver_parts > tv_parts :
225- toolkit_version = ver
226-
253+ toolkit_version = sorted (redist_versions , key = _version_sort_key )[- 1 ]
227254 cuda_toolkit (name = toolkit .name , components_mapping = components_mapping , version = toolkit_version )
228255 else :
229256 cuda_toolkit (** _module_tag_to_dict (toolkit ))
0 commit comments