Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 70 additions & 33 deletions python/private/python.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ load(":full_version.bzl", "full_version")
load(":python_register_toolchains.bzl", "python_register_toolchains")
load(":pythons_hub.bzl", "hub_repo")
load(":repo_utils.bzl", "repo_utils")
load(":toolchains_repo.bzl", "multi_toolchain_aliases")
load(":toolchains_repo.bzl", "host_toolchain", "multi_toolchain_aliases")
load(":util.bzl", "IS_BAZEL_6_4_OR_HIGHER")
load(":version.bzl", "version")

Expand Down Expand Up @@ -267,11 +267,18 @@ def parse_modules(*, module_ctx, _fail = fail):
def _python_impl(module_ctx):
py = parse_modules(module_ctx = module_ctx)

# dict[str version, list[str] platforms]; where version is full
# python version string ("3.4.5"), and platforms are keys from
# the PLATFORMS global.
loaded_platforms = {}
for toolchain_info in py.toolchains:
# list of structs; see inline struct call within the loop below.
toolchain_impls = []

# list[str] of the base names of toolchain repos
base_toolchain_repo_names = []

# Create the underlying python_repository repos that contain the
# python runtimes and their toolchain implementation definitions.
for i, toolchain_info in enumerate(py.toolchains):
is_last = (i + 1) == len(py.toolchains)
base_toolchain_repo_names.append(toolchain_info.name)

# Ensure that we pass the full version here.
full_python_version = full_version(
version = toolchain_info.python_version,
Expand All @@ -286,12 +293,45 @@ def _python_impl(module_ctx):
kwargs.update(py.config.kwargs.get(toolchain_info.python_version, {}))
kwargs.update(py.config.kwargs.get(full_python_version, {}))
kwargs.update(py.config.default)
toolchain_registered_platforms = python_register_toolchains(
register_result = python_register_toolchains(
name = toolchain_info.name,
_internal_bzlmod_toolchain_call = True,
**kwargs
)
loaded_platforms[full_python_version] = toolchain_registered_platforms
host_platforms = []
host_os_names = {}
host_archs = {}
for repo_name, (platform_name, platform_info) in register_result.impl_repos.items():
toolchain_impls.append(struct(
# str: The base name to use for the toolchain() target
name = repo_name,
# str: The repo name the toolchain() target points to.
impl_repo_name = repo_name,
# str: platform key in the passed-in platforms dict
platform_name = platform_name,
# struct: platform_info() struct
platform = platform_info,
# str: Major.Minor.Micro python version
full_python_version = full_python_version,
# bool: whether to implicitly add the python version constraint
# to the toolchain's target_settings.
# The last toolchain is the default; it can't have version constraints
set_python_version_constraint = is_last,
))
if _is_compatible_with_host(module_ctx, platform_info):
host_key = str(len(host_platforms))
host_platforms.append(platform_name)
host_os_names[host_key] = platform_info.os_name
host_archs[host_key] = platform_info.arch

host_toolchain(
name = toolchain_info.name + "_host",
# NOTE: Order matters. The first found to be compatible is (usually) used.
platforms = host_platforms,
os_names = host_os_names,
archs = host_archs,
python_version = full_python_version,
)

# List of the base names ("python_3_10") for the toolchain repos
base_toolchain_repo_names = []
Expand Down Expand Up @@ -329,31 +369,23 @@ def _python_impl(module_ctx):

# Split the toolchain info into separate objects so they can be passed onto
# the repository rule.
for i, t in enumerate(py.toolchains):
is_last = (i + 1) == len(py.toolchains)
base_name = t.name
base_toolchain_repo_names.append(base_name)
fv = full_version(version = t.python_version, minor_mapping = py.config.minor_mapping)
platforms = loaded_platforms[fv]
for platform_name, platform_info in platforms.items():
key = str(len(toolchain_names))

full_name = "{}_{}".format(base_name, platform_name)
toolchain_names.append(full_name)
toolchain_repo_names[key] = full_name
toolchain_tcw_map[key] = platform_info.compatible_with

# The target_settings attribute may not be present for users
# patching python/versions.bzl.
toolchain_ts_map[key] = getattr(platform_info, "target_settings", [])
toolchain_platform_keys[key] = platform_name
toolchain_python_versions[key] = fv

# The last toolchain is the default; it can't have version constraints
# Despite the implication of the arg name, the values are strs, not bools
toolchain_set_python_version_constraints[key] = (
"True" if not is_last else "False"
)
for entry in toolchain_impls:
key = str(len(toolchain_names))

toolchain_names.append(entry.name)
toolchain_repo_names[key] = entry.impl_repo_name
toolchain_tcw_map[key] = entry.platform.compatible_with

# The target_settings attribute may not be present for users
# patching python/versions.bzl.
toolchain_ts_map[key] = getattr(entry.platform, "target_settings", [])
toolchain_platform_keys[key] = entry.platform_name
toolchain_python_versions[key] = entry.full_python_version

# Repo rules can't accept dict[str, bool], so encode them as a string value.
toolchain_set_python_version_constraints[key] = (
"True" if entry.set_python_version_constraint else "False"
)

hub_repo(
name = "pythons_hub",
Expand Down Expand Up @@ -391,6 +423,11 @@ def _python_impl(module_ctx):
else:
return None

def _is_compatible_with_host(mctx, platform_info):
os_name = repo_utils.get_platforms_os_name(mctx)
cpu_name = repo_utils.get_platforms_cpu_name(mctx)
return platform_info.os_name == os_name and platform_info.arch == cpu_name

def _one_or_the_same(first, second, *, onerror = None):
if not first:
return second
Expand Down
40 changes: 23 additions & 17 deletions python/private/python_register_toolchains.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,17 @@ def python_register_toolchains(
))
register_coverage_tool = False

loaded_platforms = {}
for platform in platforms.keys():
# list[str] of the platform names that were used
loaded_platforms = []

# dict[str repo name, tuple[str, platform_info]]
impl_repos = {}
for platform, platform_info in platforms.items():
sha256 = tool_versions[python_version]["sha256"].get(platform, None)
if not sha256:
continue

loaded_platforms[platform] = platforms[platform]
loaded_platforms.append(platform)
(release_filename, urls, strip_prefix, patches, patch_strip) = get_release_info(platform, python_version, base_url, tool_versions)

# allow passing in a tool version
Expand All @@ -137,11 +141,10 @@ def python_register_toolchains(
)],
)

impl_repo_name = "{}_{}".format(name, platform)
impl_repos[impl_repo_name] = (platform, platform_info)
python_repository(
name = "{name}_{platform}".format(
name = name,
platform = platform,
),
name = impl_repo_name,
sha256 = sha256,
patches = patches,
patch_strip = patch_strip,
Expand All @@ -167,28 +170,31 @@ def python_register_toolchains(
platform = platform,
))

host_toolchain(
name = name + "_host",
platforms = loaded_platforms.keys(),
python_version = python_version,
)

toolchain_aliases(
name = name,
python_version = python_version,
user_repository_name = name,
platforms = loaded_platforms.keys(),
platforms = loaded_platforms,
)

# in bzlmod we write out our own toolchain repos
# in bzlmod we write out our own toolchain repos and host repos
if bzlmod_toolchain_call:
return loaded_platforms
return struct(
# dict[str name, tuple[str platform_name, platform_info]]
impl_repos = impl_repos,
)

host_toolchain(
name = name + "_host",
platforms = loaded_platforms,
python_version = python_version,
)

toolchains_repo(
name = toolchain_repo_name,
python_version = python_version,
set_python_version_constraint = set_python_version_constraint,
user_repository_name = name,
platforms = loaded_platforms.keys(),
platforms = loaded_platforms,
)
return None
25 changes: 24 additions & 1 deletion python/private/toolchains_repo.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,8 @@ def _host_toolchain_impl(rctx):
if not rctx.delete(python_tester):
fail("Failed to delete the python tester")

# NOTE: The term "toolchain" is a misnomer for this rule. This doesn't define
# a repo with toolchains or toolchain implementations.
host_toolchain = repository_rule(
_host_toolchain_impl,
doc = """\
Expand All @@ -384,6 +386,16 @@ toolchain_aliases repo because referencing the `python` interpreter target from
this repo causes an eager fetch of the toolchain for the host platform.
""",
attrs = {
"archs": attr.string_dict(
doc = """
If set, overrides the platform metadata. Keyed by index in `platforms`
""",
),
"os_names": attr.string_dict(
doc = """
If set, overrides the platform metadata. Keyed by index in `platforms`
""",
),
"platforms": attr.string_list(mandatory = True),
"python_version": attr.string(mandatory = True),
"_rule_name": attr.string(default = "host_toolchain"),
Expand Down Expand Up @@ -434,9 +446,20 @@ def _get_host_platform(*, rctx, logger, python_version, os_name, cpu_name, platf
Returns:
The host platform.
"""
if rctx.attr.os_names:
platform_map = {}
for i, platform_name in enumerate(platforms):
key = str(i)
platform_map[platform_name] = struct(
os_name = rctx.attr.os_names[key],
arch = rctx.attr.archs[key],
)
else:
platform_map = PLATFORMS

candidates = []
for platform in platforms:
meta = PLATFORMS[platform]
meta = platform_map[platform]

if meta.os_name == os_name and meta.arch == cpu_name:
candidates.append(platform)
Expand Down