Skip to content

Commit a511e65

Browse files
committed
checkpoint: got host-compatible searching working, but now have 3 viable choices and unclear way to pick one
1 parent 9bc930d commit a511e65

File tree

4 files changed

+226
-28
lines changed

4 files changed

+226
-28
lines changed

MODULE.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ dev_python.override(
127127
)
128128

129129
# Necessary so single_platform_override with a new version works
130-
dev_python.toolchain(python_version = "3.13.3")
130+
dev_python.toolchain(python_version = "3.13")
131131

132132
# For testing an arbitrary runtime triggered by a custom flag.
133133
# See //tests/toolchains:custom_platform_toolchain_test

python/private/python.bzl

Lines changed: 173 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ load(":python_register_toolchains.bzl", "python_register_toolchains")
2323
load(":pythons_hub.bzl", "hub_repo")
2424
load(":repo_utils.bzl", "repo_utils")
2525
load(":semver.bzl", "semver")
26-
load(":toolchains_repo.bzl", "multi_toolchain_aliases")
26+
load(":toolchains_repo.bzl", "host_toolchain", "multi_toolchain_aliases")
2727
load(":util.bzl", "IS_BAZEL_6_4_OR_HIGHER")
2828

2929
def parse_modules(*, module_ctx, _fail = fail):
@@ -274,15 +274,29 @@ def parse_modules(*, module_ctx, _fail = fail):
274274
def _python_impl(module_ctx):
275275
py = parse_modules(module_ctx = module_ctx)
276276

277+
host_os = repo_utils.get_platforms_os_name(module_ctx)
278+
host_cpu = repo_utils.get_platforms_cpu_name(module_ctx)
279+
277280
# dict[str version, list[str] platforms]; where version is full
278281
# python version string ("3.4.5"), and platforms are platform names.
279282
loaded_platforms = {}
283+
284+
host_compatible = []
285+
286+
# dict[str toolchain repo name, struct(version = requested python version)]
287+
host_repos_to_create = {}
280288
for toolchain_info in py.toolchains:
281289
# Ensure that we pass the full version here.
282290
full_python_version = full_version(
283291
version = toolchain_info.python_version,
284292
minor_mapping = py.config.minor_mapping,
285293
)
294+
if toolchain_info.name == "python_3_13":
295+
print("{} -> {} using minor_mapping={}".format(
296+
toolchain_info.python_version,
297+
full_python_version,
298+
py.config.minor_mapping,
299+
))
286300
kwargs = {
287301
"python_version": full_python_version,
288302
"register_coverage_tool": toolchain_info.register_coverage_tool,
@@ -292,12 +306,164 @@ def _python_impl(module_ctx):
292306
kwargs.update(py.config.kwargs.get(toolchain_info.python_version, {}))
293307
kwargs.update(py.config.kwargs.get(full_python_version, {}))
294308
kwargs.update(py.config.default)
295-
loaded_platforms[full_python_version] = python_register_toolchains(
309+
platforms = PLATFORMS | py.config.platform_overrides.get(full_python_version, {})
310+
tc_created_platforms = python_register_toolchains(
296311
name = toolchain_info.name,
297312
_internal_bzlmod_toolchain_call = True,
298-
platforms = PLATFORMS | py.config.platform_overrides.get(full_python_version, {}),
313+
platforms = platforms,
299314
**kwargs
300315
)
316+
print("{} registered platforms: {}".format(toolchain_info.name, tc_created_platforms))
317+
loaded_platforms[full_python_version] = tc_created_platforms
318+
for loaded in tc_created_platforms:
319+
info = platforms[loaded]
320+
if info.os_name == host_os and info.arch == host_cpu:
321+
host_compatible.append(struct(
322+
##repo_prefix = toolchain_info.name,
323+
repo_name = toolchain_info.name + "_" + loaded,
324+
full_python_version = full_python_version,
325+
##base_version = toolchain_info.python_version,
326+
##platform_name = loaded,
327+
##platform_info = info,
328+
))
329+
print("create host for:", toolchain_info.name, toolchain_info.python_version)
330+
331+
# "create python_3_10, using something compatible with 3.10"
332+
host_repos_to_create[toolchain_info.name] = struct(
333+
version = toolchain_info.python_version,
334+
)
335+
336+
"""
337+
goal:
338+
if name == python_3_13
339+
create host_repo() pointing to 3.13.2 <host>
340+
341+
to do this, we look through all the available versions in descending
342+
order until we find one compatible with host os/cpu.
343+
if version is major.minor:
344+
start at minor_mapping[version]
345+
full = minor_mapping[version]
346+
get all versions <= full
347+
for v in all_versions_desc:
348+
plats = PLATFORMS | overrides
349+
for plat in plats[v]:
350+
if plat compatible with host:
351+
host()
352+
return
353+
354+
if version is major.minor.micro:
355+
plats = PLATFORMS | overrides
356+
for plat in plats[v]:
357+
if plat compatible with host:
358+
host()
359+
return
360+
361+
Let's start over.
362+
Let's assume we have a list of all the created toolchains that are
363+
compatible with the host.
364+
We also have a list of all the host repos we need to create.
365+
366+
So what we have to do is, for each host repo we need to create, find
367+
the best match among the ones that were created.
368+
"best match" means:
369+
* for major.minor, the highest version <= the minor_mapping bound
370+
* for major.minor.patch: whatever is available, but maybe nothing.
371+
"""
372+
minor_mapping = py.config.minor_mapping
373+
374+
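# major_minor_lte: dict[str major_minor, list[tuple[tuple version_key, str full_version]]]
# major_minor_upper: dict[str major_minor, tuple version_key of the minor_mapping bound]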
375+
major_minor_lte = {}
376+
major_minor_upper = {}
377+
378+
def version_tuple(v):
379+
return tuple([int(x) for x in v.split(".")])
380+
381+
for major_minor, upper_full in minor_mapping.items():
382+
upper_v = version_tuple(upper_full)
383+
major_minor_upper[major_minor] = upper_v
384+
385+
for v in py.config.default["tool_versions"].keys():
386+
major_minor, _, _ = v.rpartition(".")
387+
major_minor_lte.setdefault(major_minor, [])
388+
vk = version_tuple(v)
389+
if vk <= major_minor_upper[major_minor]:
390+
major_minor_lte[major_minor].append((vk, v))
391+
392+
for major_minor, entry in major_minor_lte.items():
393+
major_minor_lte[major_minor] = sorted(entry, reverse = True)
394+
395+
print("major_minor_lte:")
396+
for key, values in major_minor_lte.items():
397+
print(key, ":", values)
398+
399+
# Sort in descending order so we go highest to lowest version
400+
host_compatible = sorted(
401+
host_compatible,
402+
reverse = True,
403+
key = lambda e: version_tuple(e.full_python_version),
404+
)
405+
406+
# todo:
407+
# host_compatible contains linux-regular and linux-freethreaded
408+
# Both are compatible with host_os/host_cpu
409+
# Ah, I see. host_toolchain has some logic to pick: look for an env var.
410+
# if set, use that candidate. Otherwise use candidates[0]
411+
412+
# At this point, we have:
413+
# List of host-compatible repos
414+
# Map of major_minor -> versions to attempt.
415+
# So next is iterating over the list of repos we have to create and find
416+
# something that works with it.
417+
for host_needed_base_name, host_needed in host_repos_to_create.items():
418+
print("find backend for:", host_needed_base_name, "v:", host_needed.version)
419+
needed_version_str = host_needed.version
420+
backing_repo_name = None
421+
422+
# Major.Minor case: look for a full version <= the minor_mapping bound
423+
# that is compatible with our host
424+
if needed_version_str.count(".") == 1:
425+
print("case: major.minor")
426+
try_versions = major_minor_lte[needed_version_str]
427+
try_versions = [x[1] for x in try_versions]
428+
candidates = []
429+
for entry in host_compatible:
430+
if entry.full_python_version in try_versions:
431+
#print("add candidate:", entry)
432+
candidates.append(entry)
433+
434+
def keyer(e):
435+
return (
436+
version_tuple(e.full_python_version),
437+
)
438+
439+
candidates = sorted(
440+
candidates,
441+
reverse = True,
442+
key = lambda e: (version_tuple(e.full_python_version), e.repo_name),
443+
)
444+
print("sorted candidates:")
445+
for x in candidates:
446+
print(" ", x)
447+
if candidates:
448+
backing_repo_name = candidates[0].repo_name
449+
else:
450+
fail("not implemented: fv", needed_version_str)
451+
452+
if not backing_repo_name:
453+
fail("no host-compatible repo found", host_needed)
454+
print("{} using {}".format(host_needed_base_name, backing_repo_name))
455+
host_toolchain(
456+
name = host_needed_base_name + "_host",
457+
backing_repo_name = backing_repo_name,
458+
)
459+
460+
# Now major_minor_lte maps major_minor to a list of descending full
461+
# versions that are <= the minor_mapping value.
462+
463+
print("versions:", py.config.default["tool_versions"].keys())
464+
465+
# Register a host toolchain for every major_minor and major_minor_micro
466+
# that there is a valid host platform for.
301467

302468
# List of the base names ("python_3_10") for the toolchain repos
303469
base_toolchain_repo_names = []
@@ -694,6 +860,10 @@ def _get_toolchain_config(*, modules, _fail = fail):
694860
v = semver(version_string)
695861
versions.setdefault("{}.{}".format(v.major, v.minor), []).append((int(v.patch), version_string))
696862

863+
print("version map:", versions)
864+
865+
# Ah, this is causing the minor_mapping to automatically upgrade to the
866+
# latest registered version. >.<
697867
minor_mapping = {
698868
major_minor: max(subset)[1]
699869
for major_minor, subset in versions.items()

python/private/python_register_toolchains.bzl

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ def python_register_toolchains(
8585
bzlmod_toolchain_call = kwargs.pop("_internal_bzlmod_toolchain_call", False)
8686
if bzlmod_toolchain_call:
8787
register_toolchains = False
88+
register_host = False
89+
else:
90+
register_host = True
8891

8992
base_url = kwargs.pop("base_url", DEFAULT_RELEASE_BASE_URL)
9093
tool_versions = tool_versions or TOOL_VERSIONS
@@ -165,11 +168,27 @@ def python_register_toolchains(
165168
platform = platform,
166169
))
167170

168-
host_toolchain(
169-
name = name + "_host",
170-
platforms = loaded_platforms,
171-
python_version = python_version,
172-
)
171+
if register_host:
172+
host_platforms = [
173+
platform
174+
for platform in loaded_platforms
175+
if platforms[platform].os_name and platforms[platform].arch
176+
]
177+
host_toolchain(
178+
name = name + "_host",
179+
platforms = host_platforms,
180+
python_version = python_version,
181+
os_names = {
182+
p: platforms[p].os_name or ""
183+
for p in host_platforms
184+
if p in platforms
185+
},
186+
archs = {
187+
p: platforms[p].arch or ""
188+
for p in host_platforms
189+
if p in platforms
190+
},
191+
)
173192

174193
toolchain_aliases(
175194
name = name,

python/private/toolchains_repo.bzl

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -234,20 +234,26 @@ def _host_toolchain_impl(rctx):
234234
235235
exports_files(["python"], visibility = ["//visibility:public"])
236236
""")
237-
238-
os_name = repo_utils.get_platforms_os_name(rctx)
239-
host_platform = _get_host_platform(
240-
rctx = rctx,
241-
logger = repo_utils.logger(rctx),
242-
python_version = rctx.attr.python_version,
243-
os_name = os_name,
244-
cpu_name = repo_utils.get_platforms_cpu_name(rctx),
245-
platforms = rctx.attr.platforms,
246-
)
247-
repo = "@@{py_repository}_{host_platform}".format(
248-
py_repository = rctx.attr.name[:-len("_host")],
249-
host_platform = host_platform,
250-
)
237+
if not rctx.attr.backing_repo_name:
238+
platforms = {
239+
p: struct(os_name = rctx.attr.os_names[p], arch = rctx.attr.archs[p])
240+
for p in rctx.attr.platforms
241+
}
242+
os_name = repo_utils.get_platforms_os_name(rctx)
243+
host_platform = _get_host_platform(
244+
rctx = rctx,
245+
logger = repo_utils.logger(rctx),
246+
python_version = rctx.attr.python_version,
247+
os_name = os_name,
248+
cpu_name = repo_utils.get_platforms_cpu_name(rctx),
249+
platforms = platforms,
250+
)
251+
repo = "@@{py_repository}_{host_platform}".format(
252+
py_repository = rctx.attr.name[:-len("_host")],
253+
host_platform = host_platform,
254+
)
255+
else:
256+
repo = rctx.attr.backing_repo_name
251257

252258
rctx.report_progress("Symlinking interpreter files to the target platform")
253259
host_python_repo = rctx.path(Label("{repo}//:BUILD.bazel".format(repo = repo)))
@@ -319,10 +325,13 @@ toolchain_aliases repo because referencing the `python` interpreter target from
319325
this repo causes an eager fetch of the toolchain for the host platform.
320326
""",
321327
attrs = {
322-
"platforms": attr.string_list(mandatory = True),
323-
"python_version": attr.string(mandatory = True),
328+
"backing_repo_name": attr.string(),
329+
"platforms": attr.string_list(mandatory = False),
330+
"python_version": attr.string(mandatory = False),
324331
"_rule_name": attr.string(default = "host_toolchain"),
325332
"_rules_python_workspace": attr.label(default = Label("//:WORKSPACE")),
333+
"os_names": attr.string_dict(),
334+
"archs": attr.string_dict(),
326335
},
327336
)
328337

@@ -420,10 +429,10 @@ def _get_host_platform(*, rctx, logger, python_version, os_name, cpu_name, platf
420429
Returns:
421430
The host platform.
422431
"""
432+
if "3_13" in rctx.name:
433+
print(rctx.name, platforms)
423434
candidates = []
424-
for platform in platforms:
425-
meta = PLATFORMS[platform]
426-
435+
for platform, meta in platforms.items():
427436
if meta.os_name == os_name and meta.arch == cpu_name:
428437
candidates.append(platform)
429438

0 commit comments

Comments
 (0)