Skip to content

Commit b80e76c

Browse files
jbmscopybara-github
authored andcommitted
Add .pyi type stub generation for tensorstore
PiperOrigin-RevId: 823077227 Change-Id: Ib938c5e642cfde99443b30b4f8c88a837718648c
1 parent 845c6bc commit b80e76c

File tree

13 files changed

+1071
-215
lines changed

13 files changed

+1071
-215
lines changed

bazel/pybind11.bzl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Supports pybind11 extension modules"""
1616

1717
load("@bazel_skylib//rules:copy_file.bzl", "copy_file")
18+
load("@bazel_skylib//rules:write_file.bzl", "write_file")
1819
load("@rules_cc//cc:cc_binary.bzl", "cc_binary")
1920
load("@rules_cc//cc:cc_library.bzl", "cc_library")
2021
load(
@@ -73,17 +74,15 @@ def py_extension(
7374
exported_symbol = "PyInit_" + name
7475

7576
# Generate linker script used on non-macOS unix platforms.
76-
native.genrule(
77+
write_file(
7778
name = linker_script_name_rule,
78-
outs = [linker_script_name],
79-
cmd = "\n".join([
80-
"cat <<'EOF' >$@",
79+
out = linker_script_name,
80+
content = [
8181
"{",
8282
" global: " + exported_symbol + ";",
8383
" local: *;",
8484
"};",
85-
"EOF",
86-
]),
85+
],
8786
)
8887

8988
for cc_binary_name in [cc_binary_dll_name, cc_binary_so_name]:

bazel/pytype.bzl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ load("@rules_python//python:py_binary.bzl", "py_binary")
1818
load("@rules_python//python:py_library.bzl", "py_library")
1919
load("@rules_python//python:py_test.bzl", "py_test")
2020

21-
def pytype_strict_library(**kwargs):
21+
def pytype_strict_library(
22+
pytype_srcs = None, # @unused
23+
**kwargs):
2224
"""Python type checking not currently supported in open source builds."""
2325
py_library(**kwargs)
2426

@@ -29,3 +31,9 @@ def pytype_strict_binary(**kwargs):
2931
def pytype_strict_test(**kwargs):
3032
"""Python type checking not currently supported in open source builds."""
3133
py_test(**kwargs)
34+
35+
def pytype_stub_library(
36+
srcs = None, # @unused
37+
**kwargs):
38+
"""Python type checking not currently supported in open source builds."""
39+
py_library(srcs = [], **kwargs)
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
# Copyright 2025 The TensorStore Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Version of run_binary from bazel-skylib that may avoid a separate exec build."""
15+
16+
# The `value` is a list of Label objects indicating which platform constraints are satisfied.
17+
PlatformConstraintsInfo = provider(
18+
"List of satisfied platform constraints.",
19+
fields = ["value"],
20+
)
21+
22+
# This should include all constraints that are needed to identify cases where the target and exec
23+
# platforms are incompatible.
24+
PLATFORM_CONSTRAINTS = [
25+
"@platforms//cpu:x86_64",
26+
"@platforms//cpu:arm64",
27+
"@platforms//cpu:ppc64le",
28+
"@platforms//os:windows",
29+
"@platforms//os:linux",
30+
"@platforms//os:macos",
31+
"@platforms//os:ios",
32+
"@platforms//os:freebsd",
33+
"@platforms//os:android",
34+
]
35+
36+
def _platform_constraints_impl(ctx):
37+
return PlatformConstraintsInfo(value = [
38+
constraint.label
39+
for constraint in ctx.attr.constraints
40+
if ctx.target_platform_has_constraint(constraint[platform_common.ConstraintValueInfo])
41+
])
42+
43+
# The platform_constraints rule serves to populate a `PlatformConstraintsInfo` object for either the
44+
# exec or target platform.
45+
_platform_constraints = rule(
46+
implementation = _platform_constraints_impl,
47+
attrs = {
48+
"constraints": attr.label_list(
49+
mandatory = True,
50+
providers = [platform_common.ConstraintValueInfo],
51+
),
52+
},
53+
)
54+
55+
def _run_binary_impl(ctx):
56+
exec_constraints = ctx.attr.exec_platform_constraints[PlatformConstraintsInfo].value
57+
target_constraints = ctx.attr.target_platform_constraints[PlatformConstraintsInfo].value
58+
tool_cfg = "target" if exec_constraints == target_constraints else "exec"
59+
tool_attr_name = "tool_" + tool_cfg
60+
tool_as_list = [getattr(ctx.attr, tool_attr_name)]
61+
62+
# The implementation below is derived from bazel-skylib.
63+
args = [
64+
ctx.expand_location(a, tool_as_list)
65+
for a in ctx.attr.args
66+
]
67+
envs = {
68+
k: ctx.expand_location(v, tool_as_list)
69+
for k, v in ctx.attr.env.items()
70+
}
71+
ctx.actions.run(
72+
outputs = ctx.outputs.outs,
73+
inputs = ctx.files.srcs,
74+
tools = [getattr(ctx.executable, tool_attr_name)],
75+
executable = getattr(ctx.executable, tool_attr_name),
76+
arguments = args,
77+
mnemonic = "RunBinary",
78+
use_default_shell_env = False,
79+
env = ctx.configuration.default_shell_env | envs,
80+
)
81+
return DefaultInfo(
82+
files = depset(ctx.outputs.outs),
83+
runfiles = ctx.runfiles(files = ctx.outputs.outs),
84+
)
85+
86+
_run_binary = rule(
87+
implementation = _run_binary_impl,
88+
attrs = {
89+
"tool_exec": attr.label(
90+
executable = True,
91+
allow_files = True,
92+
mandatory = True,
93+
cfg = "exec",
94+
),
95+
"tool_target": attr.label(
96+
executable = True,
97+
allow_files = True,
98+
mandatory = True,
99+
cfg = "target",
100+
),
101+
"env": attr.string_dict(),
102+
"srcs": attr.label_list(
103+
allow_files = True,
104+
),
105+
"outs": attr.output_list(
106+
mandatory = True,
107+
),
108+
"args": attr.string_list(),
109+
"exec_platform_constraints": attr.label(
110+
mandatory = True,
111+
cfg = "exec",
112+
providers = [PlatformConstraintsInfo],
113+
),
114+
"target_platform_constraints": attr.label(
115+
mandatory = True,
116+
cfg = "target",
117+
providers = [PlatformConstraintsInfo],
118+
),
119+
},
120+
)
121+
122+
def run_binary(
123+
name,
124+
tool,
125+
env = {},
126+
srcs = [],
127+
outs = [],
128+
args = [],
129+
platform_constraints = PLATFORM_CONSTRAINTS,
130+
**kwargs):
131+
"""Runs a binary as a build step, avoiding separate exec build if possible.
132+
133+
Equivalent to the `run_binary` rule from bazel-skylib, except that the
134+
`tool` is built in the *target* configuration rather than the *exec*
135+
configuration if the target platform is the same as the exec platform.
136+
137+
In builds where many of the dependencies of `tool` are also needed in the
138+
`target` configuration, this avoids building those dependencies twice.
139+
140+
Args:
141+
name: Rule name
142+
tool: Label specifying the tool binary to run.
143+
env: Additional environment variables to set.
144+
srcs: Source files required.
145+
outs: Outputs.
146+
args: Command-line arguments for the tool.
147+
platform_constraints: List of platform constraints to check compatibility
148+
between exec and target configuration.
149+
**kwargs: Additional attributes common to all rules.
150+
"""
151+
constraints_rule_name = name + "_platform_constraints"
152+
_platform_constraints(
153+
name = constraints_rule_name,
154+
constraints = platform_constraints,
155+
visibility = ["//visibility:private"],
156+
)
157+
_run_binary(
158+
name = name,
159+
tool_exec = tool,
160+
tool_target = tool,
161+
exec_platform_constraints = constraints_rule_name,
162+
target_platform_constraints = constraints_rule_name,
163+
env = env,
164+
srcs = srcs,
165+
outs = outs,
166+
args = args,
167+
**kwargs
168+
)

py.typed

Whitespace-only changes.

pyproject.toml

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,9 @@ readme = "README.md"
1111
authors = [
1212
{ name = "TensorStore Team", email = "tensorstore-team@google.com" },
1313
]
14-
license = { file = "LICENSE" }
14+
license = "Apache-2.0"
1515
classifiers = [
1616
"Development Status :: 5 - Production/Stable",
17-
"License :: OSI Approved :: Apache Software License",
1817
"Topic :: Software Development :: Libraries",
1918
]
2019

@@ -28,13 +27,10 @@ requires = [
2827
"setuptools>=64",
2928
"wheel",
3029
"setuptools_scm>=8.1.0",
31-
"numpy>=2.0.0",
30+
"pip",
3231
]
3332
build-backend = "setuptools.build_meta"
3433

35-
[tool.cibuildwheel]
36-
build-frontend = "pip"
37-
3834
[tool.setuptools_scm]
3935
# It would be nice to include the commit hash in the version, but that
4036
# can't be done in a PEP 440-compatible way.

python/tensorstore/BUILD

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ load(
66
"pybind11_py_extension",
77
)
88
load("//bazel:pytest.bzl", "tensorstore_pytest_test")
9-
load("//bazel:pytype.bzl", "pytype_strict_binary")
9+
load("//bazel:pytype.bzl", "pytype_strict_binary", "pytype_strict_library")
10+
load("//bazel:run_binary_in_target_cfg_if_possible.bzl", "run_binary")
1011
load("//bazel:tensorstore.bzl", "tensorstore_cc_library")
1112
load("//docs:doctest.bzl", "doctest_test")
1213

@@ -21,6 +22,11 @@ exports_files([
2122
"cc_test_driver_main.py",
2223
])
2324

25+
PYI_FILES = [
26+
"__init__.pyi",
27+
"ocdbt.pyi",
28+
]
29+
2430
doctest_test(
2531
name = "doctest_test",
2632
srcs = glob([
@@ -67,6 +73,15 @@ pybind11_py_extension(
6773
],
6874
)
6975

76+
py_library(
77+
name = "_tensorstore_with_python_deps",
78+
deps = [
79+
":_tensorstore",
80+
"@pypa_ml_dtypes//:ml_dtypes",
81+
"@pypa_numpy//:numpy",
82+
],
83+
)
84+
7085
pybind11_cc_library(
7186
name = "tensorstore_module_components",
7287
srcs = ["tensorstore_module_components.cc"],
@@ -88,14 +103,13 @@ pybind11_cc_library(
88103
],
89104
)
90105

91-
py_library(
106+
pytype_strict_library(
92107
name = "core",
93108
srcs = ["__init__.py"],
109+
pytype_srcs = PYI_FILES,
94110
visibility = ["//visibility:public"],
95111
deps = [
96-
":_tensorstore",
97-
"@pypa_ml_dtypes//:ml_dtypes", # build_cleaner: keep
98-
"@pypa_numpy//:numpy",
112+
":_tensorstore_with_python_deps",
99113
],
100114
)
101115

@@ -1236,3 +1250,24 @@ tensorstore_pytest_test(
12361250
":tensorstore",
12371251
],
12381252
)
1253+
1254+
pytype_strict_binary(
1255+
name = "generate_type_stubs",
1256+
srcs = ["generate_type_stubs.py"],
1257+
data = ["__init__.py"],
1258+
main = "generate_type_stubs.py",
1259+
deps = [
1260+
":_tensorstore_with_python_deps", # build_cleaner: keep
1261+
"@pypa_pybind11_stubgen//:pybind11_stubgen",
1262+
],
1263+
)
1264+
1265+
run_binary(
1266+
name = "genrule_type_stubs",
1267+
outs = PYI_FILES,
1268+
args = [
1269+
"$(execpath __init__.pyi)",
1270+
"$(execpath ocdbt.pyi)",
1271+
],
1272+
tool = ":generate_type_stubs",
1273+
)

0 commit comments

Comments
 (0)