From d7c49ee944ab6d3bb509895f139a9d69de2bf21b Mon Sep 17 00:00:00 2001 From: Andy Scott Date: Wed, 12 Nov 2025 20:46:55 -0500 Subject: [PATCH 1/4] pyo3: support module prefix + naming --- extensions/pyo3/private/pyo3.bzl | 33 +++++++++++++++---- .../pyo3/test/module_prefix/BUILD.bazel | 17 ++++++++++ extensions/pyo3/test/module_prefix/bar.rs | 12 +++++++ .../module_prefix_import_test.py | 15 +++++++++ 4 files changed, 71 insertions(+), 6 deletions(-) create mode 100644 extensions/pyo3/test/module_prefix/BUILD.bazel create mode 100644 extensions/pyo3/test/module_prefix/bar.rs create mode 100644 extensions/pyo3/test/module_prefix/module_prefix_import_test.py diff --git a/extensions/pyo3/private/pyo3.bzl b/extensions/pyo3/private/pyo3.bzl index 0e087b55c8..0c936ee7a9 100644 --- a/extensions/pyo3/private/pyo3.bzl +++ b/extensions/pyo3/private/pyo3.bzl @@ -87,10 +87,19 @@ def _py_pyo3_library_impl(ctx): is_windows = extension.basename.endswith(".dll") # https://pyo3.rs/v0.26.0/building-and-distribution#manual-builds - ext = ctx.actions.declare_file("{}{}".format( - ctx.label.name, - ".pyd" if is_windows else ".so", - )) + # Determine the on-disk and logical Python module layout. + module_name = ctx.attr.module if ctx.attr.module else ctx.label.name + + # Convert a dotted prefix (e.g. "foo.bar") into a path ("foo/bar"). + if ctx.attr.module_prefix: + module_prefix_path = ctx.attr.module_prefix.replace(".", "/") + module_relpath = "{}/{}.{}".format(module_prefix_path, module_name, "pyd" if is_windows else "so") + stub_relpath = "{}/{}.pyi".format(module_prefix_path, module_name) + else: + module_relpath = "{}.{}".format(module_name, "pyd" if is_windows else "so") + stub_relpath = "{}.pyi".format(module_name) + + ext = ctx.actions.declare_file(module_relpath) ctx.actions.symlink( output = ext, target_file = extension, @@ -99,10 +108,10 @@ def _py_pyo3_library_impl(ctx): stub = None if _stubs_enabled(ctx.attr.stubs, toolchain): - stub = ctx.actions.declare_file("{}.pyi".format(ctx.label.name)) + stub = ctx.actions.declare_file(stub_relpath) args = ctx.actions.args() - args.add(ctx.label.name, format = "--module_name=%s") + args.add(module_name, format = "--module_name=%s") args.add(ext, format = "--module_path=%s") args.add(stub, format = "--output=%s") ctx.actions.run( @@ -180,6 +189,12 @@ py_pyo3_library = rule( "imports": attr.string_list( doc = "List of import directories to be added to the `PYTHONPATH`.", ), + "module": attr.string( + doc = "The Python module name implemented by this extension.", + ), + "module_prefix": attr.string( + doc = "A dotted Python package prefix for the module.", + ), "stubs": attr.int( doc = "Whether or not to generate stubs. `-1` will default to the global config, `0` will never generate, and `1` will always generate stubs.", default = -1, @@ -218,6 +233,8 @@ def pyo3_extension( stubs = None, version = None, compilation_mode = "opt", + module = None, + module_prefix = None, **kwargs): """Define a PyO3 python extension module. @@ -259,6 +276,8 @@ def pyo3_extension( For more details see [rust_shared_library][rsl]. compilation_mode (str, optional): The [compilation_mode](https://bazel.build/reference/command-line-reference#flag--compilation_mode) value to build the extension for. If set to `"current"`, the current configuration will be used. + module (str, optional): The Python module name implemented by this extension. + module_prefix (str, optional): A dotted Python package prefix for the module. **kwargs (dict): Additional keyword arguments. """ tags = kwargs.pop("tags", []) @@ -318,6 +337,8 @@ def pyo3_extension( compilation_mode = compilation_mode, stubs = stubs_int, imports = imports, + module = module, + module_prefix = module_prefix, tags = tags, visibility = visibility, **kwargs diff --git a/extensions/pyo3/test/module_prefix/BUILD.bazel b/extensions/pyo3/test/module_prefix/BUILD.bazel new file mode 100644 index 0000000000..75dfc8f1a0 --- /dev/null +++ b/extensions/pyo3/test/module_prefix/BUILD.bazel @@ -0,0 +1,17 @@ +load("@rules_python//python:defs.bzl", "py_test") +load("//:defs.bzl", "pyo3_extension") + +pyo3_extension( + name = "module_prefix", + srcs = ["bar.rs"], + edition = "2021", + imports = ["."], + module = "bar", + module_prefix = "foo", +) + +py_test( + name = "module_prefix_import_test", + srcs = ["module_prefix_import_test.py"], + deps = [":module_prefix"], +) diff --git a/extensions/pyo3/test/module_prefix/bar.rs b/extensions/pyo3/test/module_prefix/bar.rs new file mode 100644 index 0000000000..cfddb13a17 --- /dev/null +++ b/extensions/pyo3/test/module_prefix/bar.rs @@ -0,0 +1,12 @@ +use pyo3::prelude::*; + +#[pyfunction] +fn thing() -> PyResult<&'static str> { + Ok("hello from rust") +} + +#[pymodule] +fn bar(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(thing, m)?)?; + Ok(()) +} diff --git a/extensions/pyo3/test/module_prefix/module_prefix_import_test.py b/extensions/pyo3/test/module_prefix/module_prefix_import_test.py new file mode 100644 index 0000000000..0a9d84622c --- /dev/null +++ b/extensions/pyo3/test/module_prefix/module_prefix_import_test.py @@ -0,0 +1,15 @@ +"""Tests that a pyo3 extension can be imported via a module prefix.""" + +import unittest + +import foo.bar + + +class ModulePrefixImportTest(unittest.TestCase): + def test_import_and_call(self) -> None: + result = foo.bar.thing() + self.assertEqual("hello from rust", result) + + +if __name__ == "__main__": + unittest.main() From e19bd9cdf3789eccbed21731625e8bb98df70a27 Mon Sep 17 00:00:00 2001 From: Andy Scott Date: Fri, 14 Nov 2025 21:42:53 -0800 Subject: [PATCH 2/4] updates for linters --- .../pyo3/test/module_prefix/module_prefix_import_test.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/extensions/pyo3/test/module_prefix/module_prefix_import_test.py b/extensions/pyo3/test/module_prefix/module_prefix_import_test.py index 0a9d84622c..389c51d4a3 100644 --- a/extensions/pyo3/test/module_prefix/module_prefix_import_test.py +++ b/extensions/pyo3/test/module_prefix/module_prefix_import_test.py @@ -2,11 +2,15 @@ import unittest -import foo.bar +import foo.bar # type: ignore class ModulePrefixImportTest(unittest.TestCase): + """Test Class.""" + def test_import_and_call(self) -> None: + """Test that a pyo3 extension can be imported via a module prefix.""" + result = foo.bar.thing() self.assertEqual("hello from rust", result) From 63c4c1d86f2295a04132fcaa3e7ea42a501f9646 Mon Sep 17 00:00:00 2001 From: Andy Scott Date: Fri, 28 Nov 2025 08:41:05 -0800 Subject: [PATCH 3/4] move to a single module arg --- extensions/pyo3/private/pyo3.bzl | 42 +++++++++++++------ .../pyo3/test/module_prefix/BUILD.bazel | 3 +- 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/extensions/pyo3/private/pyo3.bzl b/extensions/pyo3/private/pyo3.bzl index 0c936ee7a9..110d040ed6 100644 --- a/extensions/pyo3/private/pyo3.bzl +++ b/extensions/pyo3/private/pyo3.bzl @@ -87,12 +87,36 @@ def _py_pyo3_library_impl(ctx): is_windows = extension.basename.endswith(".dll") # https://pyo3.rs/v0.26.0/building-and-distribution#manual-builds + # # Determine the on-disk and logical Python module layout. - module_name = ctx.attr.module if ctx.attr.module else ctx.label.name + # + # `module` is a full dotted module path (e.g. "foo.bar"). We split on the + # last "." such that: + # - module_prefix == "foo" + # - module_name == "bar" + # + # `module_name` must match the `#[pymodule] fn (...)` in the Rust code + # and is also what we pass to the stub generator. + module_path = ctx.attr.module if ctx.attr.module else ctx.label.name.replace("/", ".") + + if module_path.startswith(".") or module_path.endswith(".") or ".." in module_path: + fail("Invalid `module` value '{}': expected a dotted module path like 'foo.bar'.".format(module_path)) + + last_dot = module_path.rfind(".") + if last_dot == -1: + module_prefix = None + module_name = module_path + else: + module_prefix = module_path[:last_dot] + module_name = module_path[last_dot + 1:] + + if not module_name: + fail("Invalid `module` value '{}': module name may not be empty.".format(module_path)) - # Convert a dotted prefix (e.g. "foo.bar") into a path ("foo/bar"). - if ctx.attr.module_prefix: - module_prefix_path = ctx.attr.module_prefix.replace(".", "/") + # Convert module_prefix (e.g. "foo.bar") into a path ("foo/bar") and place + # the extension and stubs in the corresponding directory. + if module_prefix: + module_prefix_path = module_prefix.replace(".", "/") module_relpath = "{}/{}.{}".format(module_prefix_path, module_name, "pyd" if is_windows else "so") stub_relpath = "{}/{}.pyi".format(module_prefix_path, module_name) else: @@ -190,10 +214,7 @@ py_pyo3_library = rule( doc = "List of import directories to be added to the `PYTHONPATH`.", ), "module": attr.string( - doc = "The Python module name implemented by this extension.", - ), - "module_prefix": attr.string( - doc = "A dotted Python package prefix for the module.", + doc = "A full dotted Python module path implemented by this extension (e.g. `foo.bar`).", ), "stubs": attr.int( doc = "Whether or not to generate stubs. `-1` will default to the global config, `0` will never generate, and `1` will always generate stubs.", @@ -234,7 +255,6 @@ def pyo3_extension( version = None, compilation_mode = "opt", module = None, - module_prefix = None, **kwargs): """Define a PyO3 python extension module. @@ -276,8 +296,7 @@ def pyo3_extension( For more details see [rust_shared_library][rsl]. compilation_mode (str, optional): The [compilation_mode](https://bazel.build/reference/command-line-reference#flag--compilation_mode) value to build the extension for. If set to `"current"`, the current configuration will be used. - module (str, optional): The Python module name implemented by this extension. - module_prefix (str, optional): A dotted Python package prefix for the module. + module (str, optional): A full dotted Python module path implemented by this extension (e.g. `foo.bar`). **kwargs (dict): Additional keyword arguments. """ tags = kwargs.pop("tags", []) @@ -338,7 +357,6 @@ def pyo3_extension( stubs = stubs_int, imports = imports, module = module, - module_prefix = module_prefix, tags = tags, visibility = visibility, **kwargs diff --git a/extensions/pyo3/test/module_prefix/BUILD.bazel b/extensions/pyo3/test/module_prefix/BUILD.bazel index 75dfc8f1a0..75c2929a0e 100644 --- a/extensions/pyo3/test/module_prefix/BUILD.bazel +++ b/extensions/pyo3/test/module_prefix/BUILD.bazel @@ -6,8 +6,7 @@ pyo3_extension( srcs = ["bar.rs"], edition = "2021", imports = ["."], - module = "bar", - module_prefix = "foo", + module = "foo.bar", ) py_test( From f6e336a4bb03ba56c2a240a619750d101994fe78 Mon Sep 17 00:00:00 2001 From: Andy Scott Date: Fri, 28 Nov 2025 08:47:36 -0800 Subject: [PATCH 4/4] -> module_name --- extensions/pyo3/private/pyo3.bzl | 10 +++++----- extensions/pyo3/test/module_prefix/BUILD.bazel | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/extensions/pyo3/private/pyo3.bzl b/extensions/pyo3/private/pyo3.bzl index 110d040ed6..b16871efb6 100644 --- a/extensions/pyo3/private/pyo3.bzl +++ b/extensions/pyo3/private/pyo3.bzl @@ -97,7 +97,7 @@ def _py_pyo3_library_impl(ctx): # # `module_name` must match the `#[pymodule] fn (...)` in the Rust code # and is also what we pass to the stub generator. - module_path = ctx.attr.module if ctx.attr.module else ctx.label.name.replace("/", ".") + module_path = ctx.attr.module_name if ctx.attr.module_name else ctx.label.name.replace("/", ".") if module_path.startswith(".") or module_path.endswith(".") or ".." in module_path: fail("Invalid `module` value '{}': expected a dotted module path like 'foo.bar'.".format(module_path)) @@ -213,7 +213,7 @@ py_pyo3_library = rule( "imports": attr.string_list( doc = "List of import directories to be added to the `PYTHONPATH`.", ), - "module": attr.string( + "module_name": attr.string( doc = "A full dotted Python module path implemented by this extension (e.g. `foo.bar`).", ), "stubs": attr.int( @@ -254,7 +254,7 @@ def pyo3_extension( stubs = None, version = None, compilation_mode = "opt", - module = None, + module_name = None, **kwargs): """Define a PyO3 python extension module. @@ -296,7 +296,7 @@ def pyo3_extension( For more details see [rust_shared_library][rsl]. compilation_mode (str, optional): The [compilation_mode](https://bazel.build/reference/command-line-reference#flag--compilation_mode) value to build the extension for. If set to `"current"`, the current configuration will be used. - module (str, optional): A full dotted Python module path implemented by this extension (e.g. `foo.bar`). + module_name (str, optional): A full dotted Python module path implemented by this extension (e.g. `foo.bar`). **kwargs (dict): Additional keyword arguments. """ tags = kwargs.pop("tags", []) @@ -356,7 +356,7 @@ def pyo3_extension( compilation_mode = compilation_mode, stubs = stubs_int, imports = imports, - module = module, + module_name = module_name, tags = tags, visibility = visibility, **kwargs diff --git a/extensions/pyo3/test/module_prefix/BUILD.bazel b/extensions/pyo3/test/module_prefix/BUILD.bazel index 75c2929a0e..e21fb63e88 100644 --- a/extensions/pyo3/test/module_prefix/BUILD.bazel +++ b/extensions/pyo3/test/module_prefix/BUILD.bazel @@ -6,7 +6,7 @@ pyo3_extension( srcs = ["bar.rs"], edition = "2021", imports = ["."], - module = "foo.bar", + module_name = "foo.bar", ) py_test(