Skip to content

Commit 30ebb4f

Browse files
Merge pull request #33316 from jakevdp:doc-coverage-test
PiperOrigin-RevId: 833417531
2 parents ad960fd + cd11231 commit 30ebb4f

File tree

4 files changed

+326
-0
lines changed

4 files changed

+326
-0
lines changed

docs/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2025 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

jax/_src/test_util.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2409,3 +2409,13 @@ def setup_hypothesis(max_examples=30) -> None:
24092409
profile = HYPOTHESIS_PROFILE.value
24102410
logging.info("Using hypothesis profile: %s", profile)
24112411
hp.settings.load_profile(profile)
2412+
2413+
2414+
def runtime_environment() -> str | None:
2415+
"""Returns None, "bazel" or "pytest"."""
2416+
if sys.executable is None:
2417+
return None
2418+
elif 'bazel-out' in sys.executable:
2419+
return "bazel"
2420+
else:
2421+
return "pytest"

tests/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,20 @@ jax_py_test(
957957
] + py_deps("absl/testing"),
958958
)
959959

960+
jax_py_test(
961+
name = "documentation_coverage_test",
962+
srcs = [
963+
"documentation_coverage_test.py",
964+
],
965+
deps = [
966+
"//jax",
967+
"//jax/_src:config",
968+
"//jax/_src:internal_test_util",
969+
"//jax/_src:test_util",
970+
# "//jax/docs",
971+
] + py_deps("absl/testing"),
972+
)
973+
960974
jax_multiplatform_test(
961975
name = "linalg_test",
962976
srcs = ["linalg_test.py"],
Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
# Copyright 2025 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Test that public APIs are correctly documented."""
16+
17+
import collections
18+
from collections.abc import Iterator, Mapping, Sequence
19+
import importlib
20+
import functools
21+
import os
22+
import pkgutil
23+
import warnings
24+
25+
from absl.testing import absltest
26+
from absl.testing import parameterized
27+
28+
import jax
29+
import jax._src.test_util as jtu
30+
from jax._src import config
31+
32+
config.parse_flags_with_absl()
33+
34+
35+
CURRENTMODULE_TAG = '.. currentmodule::'
36+
AUTOMODULE_TAG = '.. automodule::'
37+
AUTOSUMMARY_TAG = '.. autosummary::'
38+
AUTOCLASS_TAG = '.. autoclass::'
39+
40+
41+
@functools.lru_cache()
42+
def jax_docs_dir() -> str:
43+
"""Return the string or path object pointing to the JAX docs."""
44+
try:
45+
# In bazel, access docs files via data dependencies of a jax.docs package.
46+
return importlib.resources.files('jax.docs')
47+
except ImportError:
48+
# Outside of bazel, assume code is layed out as in the github repository, where
49+
# the docs and tests subdirectories are both within the same top-level directory.
50+
return os.path.abspath(os.path.join(__file__, os.pardir, os.pardir, "docs"))
51+
52+
53+
UNDOCUMENTED_APIS = {
54+
'jax': ['NamedSharding', 'P', 'Ref', 'Shard', 'ad_checkpoint', 'api_util', 'checkpoint_policies', 'core', 'custom_derivatives', 'custom_transpose', 'debug_key_reuse', 'device_put_replicated', 'device_put_sharded', 'effects_barrier', 'example_libraries', 'explain_cache_misses', 'experimental', 'extend', 'float0', 'freeze', 'fwd_and_bwd', 'host_count', 'host_id', 'host_ids', 'interpreters', 'jax', 'jax2tf_associative_scan_reductions', 'legacy_prng_key', 'lib', 'make_user_context', 'new_ref', 'no_execution', 'numpy_dtype_promotion', 'remat', 'remove_size_one_mesh_axis_from_type', 'softmax_custom_jvp', 'threefry_partitionable', 'tools', 'transfer_guard_device_to_device', 'transfer_guard_device_to_host', 'transfer_guard_host_to_device', 'typeof', 'version'],
55+
'jax.custom_batching': ['custom_vmap', 'sequential_vmap'],
56+
'jax.custom_derivatives': ['CustomVJPPrimal', 'SymbolicZero', 'closure_convert', 'custom_gradient', 'custom_jvp', 'custom_jvp_call_p', 'custom_vjp', 'custom_vjp_call_p', 'custom_vjp_primal_tree_values', 'linear_call', 'remat_opt_p', 'zero_from_primal'],
57+
'jax.custom_transpose': ['custom_transpose'],
58+
'jax.debug': ['DebugEffect', 'log'],
59+
'jax.distributed': ['is_initialized'],
60+
'jax.dlpack': ['jax'],
61+
'jax.dtypes': ['extended', 'finfo', 'iinfo'],
62+
'jax.errors': ['JAXIndexError', 'JAXTypeError'],
63+
'jax.ffi': ['build_ffi_lowering_function', 'include_dir', 'register_ffi_target_as_batch_partitionable', 'register_ffi_type_id'],
64+
'jax.lax': ['all_gather_invariant', 'unreduced_psum', 'dce_sink', 'conv_transpose_shape_tuple', 'reduce_window_shape_tuple', 'preduced', 'conv_general_permutations', 'conv_general_shape_tuple', 'pbroadcast', 'padtype_to_pads', 'conv_shape_tuple', 'unreduced_psum_scatter', 'create_token', 'dtype', 'shape_as_value', 'all_gather_reduced', 'pvary', *(name for name in dir(jax.lax) if name.endswith('_p'))],
65+
'jax.lax.linalg': [api for api in dir(jax.lax.linalg) if api.endswith('_p')],
66+
'jax.memory': ['Space'],
67+
'jax.monitoring': ['clear_event_listeners', 'record_event', 'record_event_duration_secs', 'record_event_time_span', 'record_scalar', 'register_event_duration_secs_listener', 'register_event_listener', 'register_event_time_span_listener', 'register_scalar_listener', 'unregister_event_duration_listener', 'unregister_event_listener', 'unregister_event_time_span_listener', 'unregister_scalar_listener'],
68+
'jax.nn': ['tanh'],
69+
'jax.nn.initializers': ['Initializer', 'kaiming_normal', 'kaiming_uniform', 'xavier_normal', 'xavier_uniform'],
70+
'jax.numpy': ['bfloat16', 'bool', 'e', 'euler_gamma', 'float4_e2m1fn', 'float8_e3m4', 'float8_e4m3', 'float8_e4m3b11fnuz', 'float8_e4m3fn', 'float8_e4m3fnuz', 'float8_e5m2', 'float8_e5m2fnuz', 'float8_e8m0fnu', 'inf', 'int2', 'int4', 'nan', 'newaxis', 'pi', 'uint2', 'uint4'],
71+
'jax.profiler': ['ProfileData', 'ProfileEvent', 'ProfileOptions', 'ProfilePlane', 'stop_server'],
72+
'jax.random': ['key_impl', 'random_gamma_p'],
73+
'jax.scipy.special': ['bessel_jn', 'sph_harm_y'],
74+
'jax.sharding': ['AbstractDevice', 'AbstractMesh', 'AxisType', 'auto_axes', 'explicit_axes', 'get_abstract_mesh', 'reshard', 'set_mesh', 'use_abstract_mesh'],
75+
'jax.stages': ['ArgInfo', 'CompilerOptions'],
76+
'jax.tree_util': ['DictKey', 'FlattenedIndexKey', 'GetAttrKey', 'PyTreeDef', 'SequenceKey', 'default_registry'],
77+
}
78+
79+
# A list of modules to skip entirely, either because they cannot be imported
80+
# or because they are not expected to be documented.
81+
MODULES_TO_SKIP = [
82+
"jax.ad_checkpoint", # internal tools, not documented.
83+
"jax.api_util", # internal tools, not documented.
84+
"jax.cloud_tpu_init", # deprecated in JAX v0.8.1
85+
"jax.collect_profile", # fails when xprof is not available.
86+
"jax.core", # internal tools, not documented.
87+
"jax.example_libraries", # TODO(jakevdp): un-skip these.
88+
"jax.extend", # TODO(jakevdp): un-skip these.
89+
"jax.experimental", # Many non-public submodules.
90+
"jax.interpreters", # internal tools, not documented.
91+
"jax.jaxlib", # internal tools, not documented.
92+
"jax.lib", # deprecated in JAX v0.8.0
93+
"jax.tools", # internal tools, not documented.
94+
"jax.version", # no public APIs.
95+
]
96+
97+
98+
def extract_apis_from_rst_file(path: str) -> dict[str, list[str]]:
99+
"""Extract documented APIs from an RST file."""
100+
# We could do this more robustly by adding a docutils dependency, but that is
101+
# pretty heavy. Instead we use simple string-based file parsing, recognizing the
102+
# particular patterns used within the JAX documentation.
103+
currentmodule: str = '<none>'
104+
in_autosummary_block = False
105+
apis = collections.defaultdict(list)
106+
with open(path, 'r') as f:
107+
for line in f:
108+
stripped_line = line.strip()
109+
if not stripped_line:
110+
continue
111+
if line.startswith(CURRENTMODULE_TAG):
112+
currentmodule = line.removeprefix(CURRENTMODULE_TAG).strip()
113+
continue
114+
if line.startswith(AUTOMODULE_TAG):
115+
currentmodule = line.removeprefix(AUTOMODULE_TAG).strip()
116+
continue
117+
if line.startswith(AUTOCLASS_TAG):
118+
in_autosummary_block = False
119+
apis[currentmodule].append(line.removeprefix(AUTOCLASS_TAG).strip())
120+
continue
121+
if line.startswith(AUTOSUMMARY_TAG):
122+
in_autosummary_block = True
123+
continue
124+
if not in_autosummary_block:
125+
continue
126+
if not line.startswith(' '):
127+
in_autosummary_block = False
128+
continue
129+
if stripped_line.startswith(':'):
130+
continue
131+
apis[currentmodule].append(stripped_line)
132+
return dict(apis)
133+
134+
135+
@functools.lru_cache()
136+
def get_all_documented_jax_apis() -> Mapping[str, list[str]]:
137+
"""Get the list of APIs documented in all files in a directory (recursive)."""
138+
path = jax_docs_dir()
139+
140+
apis = collections.defaultdict(list)
141+
for root, _, files in os.walk(path):
142+
if (root.startswith(os.path.join(path, 'build'))
143+
or root.startswith(os.path.join(path, '_autosummary'))):
144+
continue
145+
for filename in files:
146+
if filename.endswith('.rst'):
147+
new_apis = extract_apis_from_rst_file(os.path.join(root, filename))
148+
for key, val in new_apis.items():
149+
apis[key].extend(val)
150+
return {key: sorted(vals) for key, vals in apis.items()}
151+
152+
153+
@functools.lru_cache()
154+
def list_public_jax_modules() -> Sequence[str]:
155+
"""Return a list of the public modules defined in jax."""
156+
# We could use pkgutil.walk_packages, but we want to avoid traversing modules
157+
# like `jax._src`, `jax.example_libraries`, etc. so we implement it manually.
158+
def walk_public_modules(paths: list[str], parent_package: str) -> Iterator[str]:
159+
for info in pkgutil.iter_modules(paths):
160+
pkg_name = f"{parent_package}.{info.name}"
161+
if pkg_name in MODULES_TO_SKIP or info.name == 'tests' or info.name.startswith('_'):
162+
continue
163+
yield pkg_name
164+
if not info.ispkg:
165+
continue
166+
try:
167+
submodule = importlib.import_module(pkg_name)
168+
except ImportError as e:
169+
warnings.warn(f"failed to import {pkg_name}: {e!r}")
170+
else:
171+
if path := getattr(submodule, '__path__', None):
172+
yield from walk_public_modules(path, pkg_name)
173+
return [jax.__name__, *walk_public_modules(jax.__path__, jax.__name__)]
174+
175+
176+
@functools.lru_cache()
177+
def list_public_apis(module_name: str) -> Sequence[str]:
178+
"""Return a list of public APIs within a specified module.
179+
180+
This will import the module as a side-effect.
181+
"""
182+
module = importlib.import_module(module_name)
183+
return [api for api in dir(module)
184+
if not api.startswith('_') # skip private members
185+
and not api.startswith('@') # skip injected pytest-related symbols
186+
]
187+
188+
189+
@functools.lru_cache()
190+
def get_all_public_jax_apis() -> Mapping[str, list[str]]:
191+
"""Return a dictionary mapping jax submodules to their list of public APIs."""
192+
apis = {}
193+
for module in list_public_jax_modules():
194+
try:
195+
apis[module] = list_public_apis(module)
196+
except ImportError as e:
197+
warnings.warn(f"failed to import {module}: {e}")
198+
return apis
199+
200+
201+
class DocumentationCoverageTest(jtu.JaxTestCase):
202+
203+
def setUp(self):
204+
if jtu.runtime_environment() == 'bazel':
205+
self.skipTest("Skipping test in bazel, because rst docs aren't accessible.")
206+
207+
def test_list_public_jax_modules(self):
208+
"""Simple smoke test for list_public_jax_modules()"""
209+
apis = list_public_jax_modules()
210+
211+
# A few submodules which should be included
212+
self.assertIn("jax", apis)
213+
self.assertIn("jax.numpy", apis)
214+
self.assertIn("jax.numpy.linalg", apis)
215+
216+
# A few submodules which should not be included
217+
self.assertNotIn("jax._src", apis)
218+
self.assertNotIn("jax._src.numpy", apis)
219+
self.assertNotIn("jax.example_libraries", apis)
220+
self.assertNotIn("jax.experimental.jax2tf", apis)
221+
222+
def test_list_public_apis(self):
223+
"""Simple smoketest for list_public_apis()"""
224+
jnp_apis = list_public_apis('jax.numpy')
225+
self.assertIn("array", jnp_apis)
226+
self.assertIn("zeros", jnp_apis)
227+
self.assertNotIn("jax.numpy.array", jnp_apis)
228+
self.assertNotIn("np", jnp_apis)
229+
self.assertNotIn("jax", jnp_apis)
230+
231+
def test_get_all_public_jax_apis(self):
232+
"""Simple smoketest for get_all_public_jax_apis()"""
233+
apis = get_all_public_jax_apis()
234+
self.assertIn("Array", apis["jax"])
235+
self.assertIn("array", apis["jax.numpy"])
236+
self.assertIn("eigh", apis["jax.numpy.linalg"])
237+
238+
def test_extract_apis_from_rst_file(self):
239+
"""Simple smoketest for extract_apis_from_rst_file()"""
240+
numpy_docs = os.path.join(jax_docs_dir(), "jax.numpy.rst")
241+
apis = extract_apis_from_rst_file(numpy_docs)
242+
243+
self.assertIn("jax.numpy", apis.keys())
244+
self.assertIn("jax.numpy.linalg", apis.keys())
245+
246+
self.assertIn("array", apis["jax.numpy"])
247+
self.assertIn("asarray", apis["jax.numpy"])
248+
self.assertIn("eigh", apis["jax.numpy.linalg"])
249+
self.assertNotIn("jax", apis["jax.numpy"])
250+
self.assertNotIn("jax.numpy", apis["jax.numpy"])
251+
252+
def test_get_all_documented_jax_apis(self):
253+
"""Simple smoketest of get_all_documented_jax_apis()"""
254+
apis = get_all_documented_jax_apis()
255+
self.assertIn("Array", apis["jax"])
256+
self.assertIn("arange", apis["jax.numpy"])
257+
self.assertIn("eigh", apis["jax.lax.linalg"])
258+
259+
@parameterized.parameters(list_public_jax_modules())
260+
def test_module_apis_documented(self, module):
261+
"""Test that the APIs in each module are appropriately documented."""
262+
public_apis = get_all_public_jax_apis()
263+
documented_apis = get_all_documented_jax_apis()
264+
265+
pub_apis = {f"{module}.{api}" for api in public_apis.get(module, ())}
266+
doc_apis = {f"{module}.{api}" for api in documented_apis.get(module, ())}
267+
undoc_apis = {f"{module}.{api}" for api in UNDOCUMENTED_APIS.get(module, ())}
268+
269+
# Remove submodules from list.
270+
pub_apis -= public_apis.keys()
271+
pub_apis -= set(MODULES_TO_SKIP)
272+
273+
if (notempty := undoc_apis & doc_apis):
274+
raise ValueError(
275+
f"Found stale values in the UNDOCUMENTED_APIS list: {notempty}."
276+
" If this fails, the fix is typically to remove the offending entries"
277+
" from the UNDOCUMENTED_APIS mapping.")
278+
279+
if (notempty := pub_apis - doc_apis - undoc_apis):
280+
raise ValueError(
281+
f"Found public APIs that are not listed within docs: {notempty}."
282+
" If this fails, it likely means a new public API has been added to the"
283+
" jax package without an associated entry in docs/*.rst. To fix this,"
284+
" either add the missing documentation entries, or add these names to the"
285+
" UNDOCUMENTED_APIS mapping to indicate it is deliberately undocumented.")
286+
287+
288+
if __name__ == "__main__":
289+
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)