diff --git a/buf/extensions.bzl b/buf/extensions.bzl index 4376d76..8569598 100644 --- a/buf/extensions.bzl +++ b/buf/extensions.bzl @@ -16,8 +16,8 @@ See https://bazel.build/docs/bzlmod#extension-definition """ -load("//buf/internal:toolchain.bzl", "buf_download_releases") load("//buf/internal:repo.bzl", "buf_dependencies") +load("//buf/internal:toolchain.bzl", "buf_download_releases", "buf_register_toolchains") _DEFAULT_VERSION = "v1.47.2" _DEFAULT_SHA256 = "1b37b75dc0a777a0cba17fa2604bc9906e55bb4c578823d8b7a8fe3fc9fe4439" @@ -72,11 +72,26 @@ def _extension_impl(module_ctx): print("NOTE: buf toolchains {} has multiple versions {}, selected {}".format(name, versions, selected)) else: selected = versions[0] - buf_download_releases( - name = name, - version = selected["version"], - sha256 = selected["sha256"], - ) + platforms_for_registration = [] + for platform in ( + struct(os = "linux", arch = "arm64"), + struct(os = "linux", arch = "amd64"), + struct(os = "darwin", arch = "arm64"), + struct(os = "darwin", arch = "amd64"), + struct(os = "windows", arch = "arm64"), + struct(os = "windows", arch = "amd64"), + ): + name_with_platform = "{}_{}_{}".format(name, platform.os, platform.arch) + buf_download_releases( + name = name_with_platform, + os = platform.os, + arch = platform.arch, + version = selected["version"], + sha256 = selected["sha256"], + ) + platforms_for_registration.append("{}-{}".format(platform.os, platform.arch)) + + buf_register_toolchains(name = name, platforms = platforms_for_registration) for name, modules in dependencies.items(): buf_dependencies( diff --git a/buf/internal/repo.bzl b/buf/internal/repo.bzl index d07bfdb..2381841 100644 --- a/buf/internal/repo.bzl +++ b/buf/internal/repo.bzl @@ -23,6 +23,22 @@ _DOC = """ For more info please refer to the [`buf_dependencies` section](https://docs.buf.build/build-systems/bazel#buf-dependencies) of the docs. """ +# Copied from rules_go: https://github.com/bazelbuild/rules_go/blob/19ad920c6869a179d186a365d117ab82f38d0f3a/go/private/sdk.bzl#L517 +def _detect_host_platform(ctx): + goos = ctx.os.name + if goos == "mac os x": + goos = "osx" + elif goos.startswith("windows"): + goos = "windows" + + goarch = ctx.os.arch + if goarch == "aarch64": + goarch = "arm64" + elif goarch == "x86_64": + goarch = "amd64" + + return goos, goarch + def _executable_extension(ctx): extension = "" if ctx.os.name.startswith("windows"): @@ -41,7 +57,18 @@ def _valid_pin(pin): return True def _buf_dependencies_impl(ctx): - buf = ctx.path(Label("@{}//:buf{}".format(ctx.attr.toolchain_repo, _executable_extension(ctx)))) + host_os, host_arch = _detect_host_platform(ctx) + binary_real_repo = "@@{workspace}_{host_os}_{host_arch}".format( + workspace = Label("@{toolchain_repo}".format(toolchain_repo = ctx.attr.toolchain_repo)).workspace_name, + host_os = host_os, + host_arch = host_arch, + ) + buf = ctx.path( + Label("{binary_real_repo}//:buf{executable_extension}".format( + binary_real_repo = binary_real_repo, + executable_extension = _executable_extension(ctx), + )), + ) for pin in ctx.attr.modules: if not _valid_pin(pin): diff --git a/buf/internal/toolchain.bzl b/buf/internal/toolchain.bzl index 26c1ed9..bafbeab 100644 --- a/buf/internal/toolchain.bzl +++ b/buf/internal/toolchain.bzl @@ -19,18 +19,14 @@ load("@bazel_tools//tools/build_defs/repo:utils.bzl", "update_attrs") _TOOLCHAINS_REPO = "rules_buf_toolchains" _BUILD_FILE = """ -load(":toolchain.bzl", "declare_buf_toolchains") - package(default_visibility = ["//visibility:public"]) -declare_buf_toolchains( - os = "{os}", - cpu = "{cpu}", - rules_buf_repo_name = "{rules_buf_repo_name}", - ) +load(":toolchain.bzl", "implement_buf_toolchains") + +implement_buf_toolchains("{cmd_suffix}") """ -_TOOLCHAIN_FILE = """ +_IMPLEMENT_TOOLCHAIN_FILE = """ def _buf_toolchain_impl(ctx): toolchain_info = platform_common.ToolchainInfo( cli = ctx.executable.cli, @@ -50,51 +46,79 @@ _buf_toolchain = rule( }, ) -def declare_buf_toolchains(os, cpu, rules_buf_repo_name): +def implement_buf_toolchains(cmd_suffix): for cmd in ["buf", "protoc-gen-buf-lint", "protoc-gen-buf-breaking"]: - cmd_suffix = "" - if os == "windows": - cmd_suffix = ".exe" toolchain_impl = cmd + "_toolchain_impl" _buf_toolchain( name = toolchain_impl, cli = str(Label("//:"+ cmd + cmd_suffix)), ) + +""" + +_DECLARE_TOOLCHAINS_HEAD = """ +def declare_buf_toolchains(rules_buf_repo_name): + for cmd in ["buf", "protoc-gen-buf-lint", "protoc-gen-buf-breaking"]: + toolchain_impl = cmd + "_toolchain_impl" +""" + +_DECLARE_TOOLCHAINS_CALL = """ native.toolchain( - name = cmd + "_toolchain", - toolchain = ":" + toolchain_impl, - toolchain_type = "@@{}//tools/{}:toolchain_type".format(rules_buf_repo_name, cmd), + name = cmd + "_{os}_{arch}_toolchain", + toolchain = "@@{name}_{os}_{arch}//:" + toolchain_impl, + toolchain_type = "@@{{}}//tools/{{}}:toolchain_type".format(rules_buf_repo_name, cmd), exec_compatible_with = [ - "@platforms//os:" + os, - "@platforms//cpu:" + cpu, + "@platforms//os:{os}", + "@platforms//cpu:{cpu}", ], ) +""" + +_DECLARE_TOOLCHAINS_BUILD_FILE = """ +package(default_visibility = ["//visibility:public"]) +load(":toolchain.bzl", "declare_buf_toolchains") + +declare_buf_toolchains( + rules_buf_repo_name = "{rules_buf_repo_name}", +) """ -def _register_toolchains(repo, cmd): +def _register_toolchains(repo): native.register_toolchains( - "@{repo}//:{cmd}_toolchain".format( + "@{repo}//:all".format( repo = repo, - cmd = cmd, ), ) -# Copied from rules_go: https://github.com/bazelbuild/rules_go/blob/19ad920c6869a179d186a365d117ab82f38d0f3a/go/private/sdk.bzl#L517 -def _detect_host_platform(ctx): - goos = ctx.os.name - if goos == "mac os x": - goos = "darwin" - elif goos.startswith("windows"): - goos = "windows" - - goarch = ctx.os.arch - if goarch == "aarch64": - goarch = "arm64" - elif goarch == "x86_64": - goarch = "amd64" - - return goos, goarch +def _buf_register_toolchains_impl(ctx): + platforms = ctx.attr.platforms # list of "{os}-{arch}" strings + ctx.file( + "BUILD", + _DECLARE_TOOLCHAINS_BUILD_FILE.format( + rules_buf_repo_name = Label("//buf/repositories.bzl").workspace_name, + ), + ) + toolchain_file_text = _DECLARE_TOOLCHAINS_HEAD + for platform in platforms: + os, arch = platform.split("-", 1) + if os == "darwin": + os = "osx" + cpu = arch + if cpu == "amd64": + cpu = "x86_64" + toolchain_file_text += _DECLARE_TOOLCHAINS_CALL.format(name = ctx.attr.name, os = os, arch = arch, cpu = cpu) + ctx.file("toolchain.bzl", toolchain_file_text) + +buf_register_toolchains = repository_rule( + implementation = _buf_register_toolchains_impl, + attrs = { + "platforms": attr.string_list( + doc = "Buf platforms", + mandatory = True, + ), + }, +) def _buf_download_releases_impl(ctx): version = ctx.attr.version @@ -111,15 +135,14 @@ def _buf_download_releases_impl(ctx): version_data = ctx.read("version.json") version = json.decode(version_data)["name"] - os, cpu = _detect_host_platform(ctx) - if os not in ["linux", "darwin", "windows"] or cpu not in ["arm64", "amd64"]: - fail("Unsupported operating system or cpu architecture ") + os = ctx.attr.os + cpu = ctx.attr.arch if os == "linux" and cpu == "arm64": cpu = "aarch64" if cpu == "amd64": cpu = "x86_64" - ctx.report_progress("Downloading buf release hash") + ctx.report_progress("Downloading buf release hash for {}-{}".format(os, cpu)) url = "{}/{}/sha256.txt".format(repository_url, version) sha256 = ctx.download( url = url, @@ -128,7 +151,7 @@ def _buf_download_releases_impl(ctx): output = "sha256.txt", ).sha256 ctx.file("WORKSPACE", "workspace(name = \"{name}\")".format(name = ctx.name)) - ctx.file("toolchain.bzl", _TOOLCHAIN_FILE) + ctx.file("toolchain.bzl", _IMPLEMENT_TOOLCHAIN_FILE) sha_list = ctx.read("sha256.txt").splitlines() for sha_line in sha_list: if sha_line.strip(" ").endswith(".tar.gz"): @@ -153,16 +176,9 @@ def _buf_download_releases_impl(ctx): output = output, ) - if os == "darwin": - os = "osx" - ctx.file( "BUILD", - _BUILD_FILE.format( - os = os, - cpu = cpu, - rules_buf_repo_name = Label("//buf/repositories.bzl").workspace_name, - ), + _BUILD_FILE.format(cmd_suffix = ".exe" if os == "windows" else ""), ) attrs = {"version": version, "repository_url": repository_url, "sha256": sha256} return update_attrs(ctx.attr, attrs.keys(), attrs) @@ -170,6 +186,16 @@ def _buf_download_releases_impl(ctx): buf_download_releases = repository_rule( implementation = _buf_download_releases_impl, attrs = { + "os": attr.string( + doc = "Buf release os", + mandatory = True, + values = ["linux", "darwin", "windows"], + ), + "arch": attr.string( + doc = "Buf release cpu arch", + mandatory = True, + values = ["arm64", "amd64"], + ), "version": attr.string( doc = "Buf release version", ), @@ -194,8 +220,18 @@ def rules_buf_toolchains(name = _TOOLCHAINS_REPO, version = None, sha256 = None, repository_url: The repository url base used for downloads. Defaults to "https://github.com/bufbuild/buf/releases/download" """ - buf_download_releases(name = name, version = version, sha256 = sha256, repository_url = repository_url) - - _register_toolchains(name, "buf") - _register_toolchains(name, "protoc-gen-buf-breaking") - _register_toolchains(name, "protoc-gen-buf-lint") + platforms_for_registration = [] + for platform in ( + struct(os = "linux", arch = "arm64"), + struct(os = "linux", arch = "amd64"), + struct(os = "darwin", arch = "arm64"), + struct(os = "darwin", arch = "amd64"), + struct(os = "windows", arch = "arm64"), + struct(os = "windows", arch = "amd64"), + ): + name_with_platform = "{}_{}_{}".format(name, platform.os, platform.arch) + buf_download_releases(name = name_with_platform, os = platform.os, arch = platform.arch, version = version, sha256 = sha256, repository_url = repository_url) + platforms_for_registration.append("{}-{}".format(platform.os, platform.arch)) + + buf_register_toolchains(name = name, platforms = platforms_for_registration) + _register_toolchains(name)