11"""Entry point for extensions used by bzlmod."""
22
33load ("//cuda:platform_alias_extension.bzl" , "platform_alias_repo" )
4+ load ("//cuda/private:platforms.bzl" , "SUPPORTED_PLATFORMS" )
45load ("//cuda/private:redist_json_helper.bzl" , "redist_json_helper" )
56load ("//cuda/private:repositories.bzl" , "cuda_component" , "cuda_toolkit" )
67
@@ -96,6 +97,9 @@ def _find_modules(module_ctx):
9697def _module_tag_to_dict (t ):
9798 return {attr : getattr (t , attr ) for attr in dir (t )}
9899
100+ def _platform_repos_attr (platform ):
101+ return platform .replace ("-" , "_" ) + "_repos"
102+
99103def _component_attrs_match (existing , current ):
100104 for key , value in current .items ():
101105 if key == "name" :
@@ -109,36 +113,43 @@ def _component_attrs_match(existing, current):
109113 return False
110114 return True
111115
112- def _redist_json_impl (module_ctx , attr , generated_components ):
116+ def _component_entry_key (component_name , platform , redist_ver ):
117+ return "{}|{}|{}" .format (component_name , platform , redist_ver )
118+
119+ def _register_redist_components (module_ctx , attr , component_entries ):
113120 url , json_object = redist_json_helper .get (module_ctx , attr )
114121 redist_ver = redist_json_helper .get_redist_version (module_ctx , attr , json_object )
115122
116- platform_mapping = {}
117123 for platform in attr .platforms :
118124 component_specs = redist_json_helper .collect_specs (module_ctx , attr , platform , json_object , url )
119- mapping = {}
120125 for spec in component_specs :
121126 repo_name = redist_json_helper .get_repo_name (module_ctx , spec )
122- mapping [spec ["component_name" ]] = repo_name
123127
124128 component_attr = {key : value for key , value in spec .items ()}
125129 component_repo_name = repo_name + "_" + platform .replace ("-" , "_" ) + "_" + redist_ver .replace ("." , "_" )
126130 component_attr ["name" ] = component_repo_name
127131
128- dedupe_key = "{}|{}|{}" .format (spec ["component_name" ], platform , redist_ver )
129- existing_attr = generated_components .get (dedupe_key )
132+ dedupe_key = _component_entry_key (spec ["component_name" ], platform , redist_ver )
133+ existing_entry = component_entries .get (dedupe_key )
134+ existing_attr = existing_entry ["component_attr" ] if existing_entry else None
130135 if existing_attr == None :
131136 cuda_component (** component_attr )
132- generated_components [dedupe_key ] = component_attr
137+ component_entries [dedupe_key ] = {
138+ "component_name" : spec ["component_name" ],
139+ "platform" : platform ,
140+ "redist_version" : redist_ver ,
141+ "repo_name" : repo_name ,
142+ "generated_repo_name" : component_repo_name ,
143+ "component_attr" : component_attr ,
144+ }
133145 elif not _component_attrs_match (existing_attr , component_attr ):
134146 fail (("Conflicting CUDA component definition for {} on {} at version {}. " +
135147 "Use distinct component versions when registries are not identical." ).format (
136148 spec ["component_name" ],
137149 platform ,
138150 redist_ver ,
139151 ))
140- platform_mapping [platform ] = mapping
141- return redist_ver , platform_mapping
152+ return redist_ver
142153
143154def _impl (module_ctx ):
144155 # Toolchain configuration is only allowed in the root module, or in rules_cuda.
@@ -158,47 +169,47 @@ def _impl(module_ctx):
158169 for component in components :
159170 cuda_component (** _module_tag_to_dict (component ))
160171
161- redist_version = None
162172 components_mapping = None
163173 redist_versions = []
164- redist_components_mapping = {}
165-
166- # Track all versioned repositories for each component and platform.
167- versioned_repos = {}
168- generated_components = {}
174+ component_entries = {}
169175 for redist_json in redist_jsons :
170- components_mapping = {}
171- redist_version , platform_mapping = _redist_json_impl (module_ctx , redist_json , generated_components )
176+ redist_version = _register_redist_components (module_ctx , redist_json , component_entries )
172177 if redist_version not in redist_versions :
173178 redist_versions .append (redist_version )
174- for platform in platform_mapping .keys ():
175- for component_name , repo_name in platform_mapping [platform ].items ():
176- redist_components_mapping [component_name ] = repo_name
177-
178- # Track the versioned repo name for this component/platform/version.
179- if component_name not in versioned_repos :
180- versioned_repos [component_name ] = {}
181- if platform not in versioned_repos [component_name ]:
182- versioned_repos [component_name ][platform ] = {}
183- versioned_repos [component_name ][platform ][redist_version ] = repo_name + "_" + platform .replace ("-" , "_" ) + "_" + redist_version .replace ("." , "_" )
184-
185- for component_name in redist_components_mapping .keys ():
186- # Build dictionaries mapping versions to repo names for each platform.
187- x86_64_repos = {ver : versioned_repos [component_name ]["linux-x86_64" ][ver ] for ver in redist_versions if "linux-x86_64" in versioned_repos [component_name ] and ver in versioned_repos [component_name ]["linux-x86_64" ]}
188- windows_x86_64_repos = {ver : versioned_repos [component_name ]["windows-x86_64" ][ver ] for ver in redist_versions if "windows-x86_64" in versioned_repos [component_name ] and ver in versioned_repos [component_name ]["windows-x86_64" ]}
189- aarch64_repos = {ver : versioned_repos [component_name ]["linux-aarch64" ][ver ] for ver in redist_versions if "linux-aarch64" in versioned_repos [component_name ] and ver in versioned_repos [component_name ]["linux-aarch64" ]}
190- sbsa_repos = {ver : versioned_repos [component_name ]["linux-sbsa" ][ver ] for ver in redist_versions if "linux-sbsa" in versioned_repos [component_name ] and ver in versioned_repos [component_name ]["linux-sbsa" ]}
191-
192- platform_alias_repo (
193- name = redist_components_mapping [component_name ],
194- component_name = component_name ,
195- linux_x86_64_repos = x86_64_repos ,
196- windows_x86_64_repos = windows_x86_64_repos ,
197- linux_aarch64_repos = aarch64_repos ,
198- linux_sbsa_repos = sbsa_repos ,
199- versions = redist_versions ,
200- )
201- components_mapping [component_name ] = "@" + redist_components_mapping [component_name ]
179+
180+ if len (component_entries ) > 0 :
181+ components_mapping = {}
182+ redist_components_mapping = {}
183+ versioned_repos = {}
184+ for entry in component_entries .values ():
185+ component_name = entry ["component_name" ]
186+ platform = entry ["platform" ]
187+ redist_version = entry ["redist_version" ]
188+
189+ redist_components_mapping [component_name ] = entry ["repo_name" ]
190+ if component_name not in versioned_repos :
191+ versioned_repos [component_name ] = {}
192+ if platform not in versioned_repos [component_name ]:
193+ versioned_repos [component_name ][platform ] = {}
194+ versioned_repos [component_name ][platform ][redist_version ] = entry ["generated_repo_name" ]
195+
196+ for component_name in redist_components_mapping .keys ():
197+ # Build dictionaries mapping versions to repo names for each platform.
198+ platform_repo_kwargs = {}
199+ for platform in SUPPORTED_PLATFORMS :
200+ platform_repo_kwargs [_platform_repos_attr (platform )] = {
201+ ver : versioned_repos [component_name ][platform ][ver ]
202+ for ver in redist_versions
203+ if platform in versioned_repos [component_name ] and ver in versioned_repos [component_name ][platform ]
204+ }
205+
206+ platform_alias_repo (
207+ name = redist_components_mapping [component_name ],
208+ component_name = component_name ,
209+ versions = redist_versions ,
210+ ** platform_repo_kwargs
211+ )
212+ components_mapping [component_name ] = "@" + redist_components_mapping [component_name ]
202213
203214 registrations = {}
204215 for toolkit in toolkits :
0 commit comments