|
| 1 | +# Copyright 2025 The JAX Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Test that public APIs are correctly documented.""" |
| 16 | + |
| 17 | +import collections |
| 18 | +from collections.abc import Iterator, Mapping, Sequence |
| 19 | +import importlib |
| 20 | +import functools |
| 21 | +import os |
| 22 | +import pkgutil |
| 23 | +import warnings |
| 24 | + |
| 25 | +from absl.testing import absltest |
| 26 | +from absl.testing import parameterized |
| 27 | + |
| 28 | +import jax |
| 29 | +import jax._src.test_util as jtu |
| 30 | +from jax._src import config |
| 31 | + |
| 32 | +config.parse_flags_with_absl() |
| 33 | + |
| 34 | + |
| 35 | +CURRENTMODULE_TAG = '.. currentmodule::' |
| 36 | +AUTOMODULE_TAG = '.. automodule::' |
| 37 | +AUTOSUMMARY_TAG = '.. autosummary::' |
| 38 | +AUTOCLASS_TAG = '.. autoclass::' |
| 39 | + |
| 40 | + |
| 41 | +@functools.lru_cache() |
| 42 | +def jax_docs_dir() -> str: |
| 43 | + """Return the string or path object pointing to the JAX docs.""" |
| 44 | + try: |
| 45 | + # In bazel, access docs files via data dependencies of a jax.docs package. |
| 46 | + return importlib.resources.files('jax.docs') |
| 47 | + except ImportError: |
| 48 | + # Outside of bazel, assume code is layed out as in the github repository, where |
| 49 | + # the docs and tests subdirectories are both within the same top-level directory. |
| 50 | + return os.path.abspath(os.path.join(__file__, os.pardir, os.pardir, "docs")) |
| 51 | + |
| 52 | + |
| 53 | +UNDOCUMENTED_APIS = { |
| 54 | + 'jax': ['NamedSharding', 'P', 'Ref', 'Shard', 'ad_checkpoint', 'api_util', 'checkpoint_policies', 'core', 'custom_derivatives', 'custom_transpose', 'debug_key_reuse', 'device_put_replicated', 'device_put_sharded', 'effects_barrier', 'example_libraries', 'explain_cache_misses', 'experimental', 'extend', 'float0', 'freeze', 'fwd_and_bwd', 'host_count', 'host_id', 'host_ids', 'interpreters', 'jax', 'jax2tf_associative_scan_reductions', 'legacy_prng_key', 'lib', 'make_user_context', 'new_ref', 'no_execution', 'numpy_dtype_promotion', 'remat', 'remove_size_one_mesh_axis_from_type', 'softmax_custom_jvp', 'threefry_partitionable', 'tools', 'transfer_guard_device_to_device', 'transfer_guard_device_to_host', 'transfer_guard_host_to_device', 'typeof', 'version'], |
| 55 | + 'jax.custom_batching': ['custom_vmap', 'sequential_vmap'], |
| 56 | + 'jax.custom_derivatives': ['CustomVJPPrimal', 'SymbolicZero', 'closure_convert', 'custom_gradient', 'custom_jvp', 'custom_jvp_call_p', 'custom_vjp', 'custom_vjp_call_p', 'custom_vjp_primal_tree_values', 'linear_call', 'remat_opt_p', 'zero_from_primal'], |
| 57 | + 'jax.custom_transpose': ['custom_transpose'], |
| 58 | + 'jax.debug': ['DebugEffect', 'log'], |
| 59 | + 'jax.distributed': ['is_initialized'], |
| 60 | + 'jax.dlpack': ['jax'], |
| 61 | + 'jax.dtypes': ['extended', 'finfo', 'iinfo'], |
| 62 | + 'jax.errors': ['JAXIndexError', 'JAXTypeError'], |
| 63 | + 'jax.ffi': ['build_ffi_lowering_function', 'include_dir', 'register_ffi_target_as_batch_partitionable', 'register_ffi_type_id'], |
| 64 | + 'jax.lax': ['all_gather_invariant', 'unreduced_psum', 'dce_sink', 'conv_transpose_shape_tuple', 'reduce_window_shape_tuple', 'preduced', 'conv_general_permutations', 'conv_general_shape_tuple', 'pbroadcast', 'padtype_to_pads', 'conv_shape_tuple', 'unreduced_psum_scatter', 'create_token', 'dtype', 'shape_as_value', 'all_gather_reduced', 'pvary', *(name for name in dir(jax.lax) if name.endswith('_p'))], |
| 65 | + 'jax.lax.linalg': [api for api in dir(jax.lax.linalg) if api.endswith('_p')], |
| 66 | + 'jax.memory': ['Space'], |
| 67 | + 'jax.monitoring': ['clear_event_listeners', 'record_event', 'record_event_duration_secs', 'record_event_time_span', 'record_scalar', 'register_event_duration_secs_listener', 'register_event_listener', 'register_event_time_span_listener', 'register_scalar_listener', 'unregister_event_duration_listener', 'unregister_event_listener', 'unregister_event_time_span_listener', 'unregister_scalar_listener'], |
| 68 | + 'jax.nn': ['tanh'], |
| 69 | + 'jax.nn.initializers': ['Initializer', 'kaiming_normal', 'kaiming_uniform', 'xavier_normal', 'xavier_uniform'], |
| 70 | + 'jax.numpy': ['bfloat16', 'bool', 'e', 'euler_gamma', 'float4_e2m1fn', 'float8_e3m4', 'float8_e4m3', 'float8_e4m3b11fnuz', 'float8_e4m3fn', 'float8_e4m3fnuz', 'float8_e5m2', 'float8_e5m2fnuz', 'float8_e8m0fnu', 'inf', 'int2', 'int4', 'nan', 'newaxis', 'pi', 'uint2', 'uint4'], |
| 71 | + 'jax.profiler': ['ProfileData', 'ProfileEvent', 'ProfileOptions', 'ProfilePlane', 'stop_server'], |
| 72 | + 'jax.random': ['key_impl', 'random_gamma_p'], |
| 73 | + 'jax.scipy.special': ['bessel_jn', 'sph_harm_y'], |
| 74 | + 'jax.sharding': ['AbstractDevice', 'AbstractMesh', 'AxisType', 'auto_axes', 'explicit_axes', 'get_abstract_mesh', 'reshard', 'set_mesh', 'use_abstract_mesh'], |
| 75 | + 'jax.stages': ['ArgInfo', 'CompilerOptions'], |
| 76 | + 'jax.tree_util': ['DictKey', 'FlattenedIndexKey', 'GetAttrKey', 'PyTreeDef', 'SequenceKey', 'default_registry'], |
| 77 | +} |
| 78 | + |
| 79 | +# A list of modules to skip entirely, either because they cannot be imported |
| 80 | +# or because they are not expected to be documented. |
| 81 | +MODULES_TO_SKIP = [ |
| 82 | + "jax.ad_checkpoint", # internal tools, not documented. |
| 83 | + "jax.api_util", # internal tools, not documented. |
| 84 | + "jax.cloud_tpu_init", # deprecated in JAX v0.8.1 |
| 85 | + "jax.collect_profile", # fails when xprof is not available. |
| 86 | + "jax.core", # internal tools, not documented. |
| 87 | + "jax.example_libraries", # TODO(jakevdp): un-skip these. |
| 88 | + "jax.extend", # TODO(jakevdp): un-skip these. |
| 89 | + "jax.experimental", # Many non-public submodules. |
| 90 | + "jax.interpreters", # internal tools, not documented. |
| 91 | + "jax.jaxlib", # internal tools, not documented. |
| 92 | + "jax.lib", # deprecated in JAX v0.8.0 |
| 93 | + "jax.tools", # internal tools, not documented. |
| 94 | + "jax.version", # no public APIs. |
| 95 | +] |
| 96 | + |
| 97 | + |
| 98 | +def extract_apis_from_rst_file(path: str) -> dict[str, list[str]]: |
| 99 | + """Extract documented APIs from an RST file.""" |
| 100 | + # We could do this more robustly by adding a docutils dependency, but that is |
| 101 | + # pretty heavy. Instead we use simple string-based file parsing, recognizing the |
| 102 | + # particular patterns used within the JAX documentation. |
| 103 | + currentmodule: str = '<none>' |
| 104 | + in_autosummary_block = False |
| 105 | + apis = collections.defaultdict(list) |
| 106 | + with open(path, 'r') as f: |
| 107 | + for line in f: |
| 108 | + stripped_line = line.strip() |
| 109 | + if not stripped_line: |
| 110 | + continue |
| 111 | + if line.startswith(CURRENTMODULE_TAG): |
| 112 | + currentmodule = line.removeprefix(CURRENTMODULE_TAG).strip() |
| 113 | + continue |
| 114 | + if line.startswith(AUTOMODULE_TAG): |
| 115 | + currentmodule = line.removeprefix(AUTOMODULE_TAG).strip() |
| 116 | + continue |
| 117 | + if line.startswith(AUTOCLASS_TAG): |
| 118 | + in_autosummary_block = False |
| 119 | + apis[currentmodule].append(line.removeprefix(AUTOCLASS_TAG).strip()) |
| 120 | + continue |
| 121 | + if line.startswith(AUTOSUMMARY_TAG): |
| 122 | + in_autosummary_block = True |
| 123 | + continue |
| 124 | + if not in_autosummary_block: |
| 125 | + continue |
| 126 | + if not line.startswith(' '): |
| 127 | + in_autosummary_block = False |
| 128 | + continue |
| 129 | + if stripped_line.startswith(':'): |
| 130 | + continue |
| 131 | + apis[currentmodule].append(stripped_line) |
| 132 | + return dict(apis) |
| 133 | + |
| 134 | + |
| 135 | +@functools.lru_cache() |
| 136 | +def get_all_documented_jax_apis() -> Mapping[str, list[str]]: |
| 137 | + """Get the list of APIs documented in all files in a directory (recursive).""" |
| 138 | + path = jax_docs_dir() |
| 139 | + |
| 140 | + apis = collections.defaultdict(list) |
| 141 | + for root, _, files in os.walk(path): |
| 142 | + if (root.startswith(os.path.join(path, 'build')) |
| 143 | + or root.startswith(os.path.join(path, '_autosummary'))): |
| 144 | + continue |
| 145 | + for filename in files: |
| 146 | + if filename.endswith('.rst'): |
| 147 | + new_apis = extract_apis_from_rst_file(os.path.join(root, filename)) |
| 148 | + for key, val in new_apis.items(): |
| 149 | + apis[key].extend(val) |
| 150 | + return {key: sorted(vals) for key, vals in apis.items()} |
| 151 | + |
| 152 | + |
| 153 | +@functools.lru_cache() |
| 154 | +def list_public_jax_modules() -> Sequence[str]: |
| 155 | + """Return a list of the public modules defined in jax.""" |
| 156 | + # We could use pkgutil.walk_packages, but we want to avoid traversing modules |
| 157 | + # like `jax._src`, `jax.example_libraries`, etc. so we implement it manually. |
| 158 | + def walk_public_modules(paths: list[str], parent_package: str) -> Iterator[str]: |
| 159 | + for info in pkgutil.iter_modules(paths): |
| 160 | + pkg_name = f"{parent_package}.{info.name}" |
| 161 | + if pkg_name in MODULES_TO_SKIP or info.name == 'tests' or info.name.startswith('_'): |
| 162 | + continue |
| 163 | + yield pkg_name |
| 164 | + if not info.ispkg: |
| 165 | + continue |
| 166 | + try: |
| 167 | + submodule = importlib.import_module(pkg_name) |
| 168 | + except ImportError as e: |
| 169 | + warnings.warn(f"failed to import {pkg_name}: {e!r}") |
| 170 | + else: |
| 171 | + if path := getattr(submodule, '__path__', None): |
| 172 | + yield from walk_public_modules(path, pkg_name) |
| 173 | + return [jax.__name__, *walk_public_modules(jax.__path__, jax.__name__)] |
| 174 | + |
| 175 | + |
| 176 | +@functools.lru_cache() |
| 177 | +def list_public_apis(module_name: str) -> Sequence[str]: |
| 178 | + """Return a list of public APIs within a specified module. |
| 179 | +
|
| 180 | + This will import the module as a side-effect. |
| 181 | + """ |
| 182 | + module = importlib.import_module(module_name) |
| 183 | + return [api for api in dir(module) |
| 184 | + if not api.startswith('_') # skip private members |
| 185 | + and not api.startswith('@') # skip injected pytest-related symbols |
| 186 | + ] |
| 187 | + |
| 188 | + |
| 189 | +@functools.lru_cache() |
| 190 | +def get_all_public_jax_apis() -> Mapping[str, list[str]]: |
| 191 | + """Return a dictionary mapping jax submodules to their list of public APIs.""" |
| 192 | + apis = {} |
| 193 | + for module in list_public_jax_modules(): |
| 194 | + try: |
| 195 | + apis[module] = list_public_apis(module) |
| 196 | + except ImportError as e: |
| 197 | + warnings.warn(f"failed to import {module}: {e}") |
| 198 | + return apis |
| 199 | + |
| 200 | + |
| 201 | +class DocumentationCoverageTest(jtu.JaxTestCase): |
| 202 | + |
| 203 | + def setUp(self): |
| 204 | + if jtu.runtime_environment() == 'bazel': |
| 205 | + self.skipTest("Skipping test in bazel, because rst docs aren't accessible.") |
| 206 | + |
| 207 | + def test_list_public_jax_modules(self): |
| 208 | + """Simple smoke test for list_public_jax_modules()""" |
| 209 | + apis = list_public_jax_modules() |
| 210 | + |
| 211 | + # A few submodules which should be included |
| 212 | + self.assertIn("jax", apis) |
| 213 | + self.assertIn("jax.numpy", apis) |
| 214 | + self.assertIn("jax.numpy.linalg", apis) |
| 215 | + |
| 216 | + # A few submodules which should not be included |
| 217 | + self.assertNotIn("jax._src", apis) |
| 218 | + self.assertNotIn("jax._src.numpy", apis) |
| 219 | + self.assertNotIn("jax.example_libraries", apis) |
| 220 | + self.assertNotIn("jax.experimental.jax2tf", apis) |
| 221 | + |
| 222 | + def test_list_public_apis(self): |
| 223 | + """Simple smoketest for list_public_apis()""" |
| 224 | + jnp_apis = list_public_apis('jax.numpy') |
| 225 | + self.assertIn("array", jnp_apis) |
| 226 | + self.assertIn("zeros", jnp_apis) |
| 227 | + self.assertNotIn("jax.numpy.array", jnp_apis) |
| 228 | + self.assertNotIn("np", jnp_apis) |
| 229 | + self.assertNotIn("jax", jnp_apis) |
| 230 | + |
| 231 | + def test_get_all_public_jax_apis(self): |
| 232 | + """Simple smoketest for get_all_public_jax_apis()""" |
| 233 | + apis = get_all_public_jax_apis() |
| 234 | + self.assertIn("Array", apis["jax"]) |
| 235 | + self.assertIn("array", apis["jax.numpy"]) |
| 236 | + self.assertIn("eigh", apis["jax.numpy.linalg"]) |
| 237 | + |
| 238 | + def test_extract_apis_from_rst_file(self): |
| 239 | + """Simple smoketest for extract_apis_from_rst_file()""" |
| 240 | + numpy_docs = os.path.join(jax_docs_dir(), "jax.numpy.rst") |
| 241 | + apis = extract_apis_from_rst_file(numpy_docs) |
| 242 | + |
| 243 | + self.assertIn("jax.numpy", apis.keys()) |
| 244 | + self.assertIn("jax.numpy.linalg", apis.keys()) |
| 245 | + |
| 246 | + self.assertIn("array", apis["jax.numpy"]) |
| 247 | + self.assertIn("asarray", apis["jax.numpy"]) |
| 248 | + self.assertIn("eigh", apis["jax.numpy.linalg"]) |
| 249 | + self.assertNotIn("jax", apis["jax.numpy"]) |
| 250 | + self.assertNotIn("jax.numpy", apis["jax.numpy"]) |
| 251 | + |
| 252 | + def test_get_all_documented_jax_apis(self): |
| 253 | + """Simple smoketest of get_all_documented_jax_apis()""" |
| 254 | + apis = get_all_documented_jax_apis() |
| 255 | + self.assertIn("Array", apis["jax"]) |
| 256 | + self.assertIn("arange", apis["jax.numpy"]) |
| 257 | + self.assertIn("eigh", apis["jax.lax.linalg"]) |
| 258 | + |
| 259 | + @parameterized.parameters(list_public_jax_modules()) |
| 260 | + def test_module_apis_documented(self, module): |
| 261 | + """Test that the APIs in each module are appropriately documented.""" |
| 262 | + public_apis = get_all_public_jax_apis() |
| 263 | + documented_apis = get_all_documented_jax_apis() |
| 264 | + |
| 265 | + pub_apis = {f"{module}.{api}" for api in public_apis.get(module, ())} |
| 266 | + doc_apis = {f"{module}.{api}" for api in documented_apis.get(module, ())} |
| 267 | + undoc_apis = {f"{module}.{api}" for api in UNDOCUMENTED_APIS.get(module, ())} |
| 268 | + |
| 269 | + # Remove submodules from list. |
| 270 | + pub_apis -= public_apis.keys() |
| 271 | + pub_apis -= set(MODULES_TO_SKIP) |
| 272 | + |
| 273 | + if (notempty := undoc_apis & doc_apis): |
| 274 | + raise ValueError( |
| 275 | + f"Found stale values in the UNDOCUMENTED_APIS list: {notempty}." |
| 276 | + " If this fails, the fix is typically to remove the offending entries" |
| 277 | + " from the UNDOCUMENTED_APIS mapping.") |
| 278 | + |
| 279 | + if (notempty := pub_apis - doc_apis - undoc_apis): |
| 280 | + raise ValueError( |
| 281 | + f"Found public APIs that are not listed within docs: {notempty}." |
| 282 | + " If this fails, it likely means a new public API has been added to the" |
| 283 | + " jax package without an associated entry in docs/*.rst. To fix this," |
| 284 | + " either add the missing documentation entries, or add these names to the" |
| 285 | + " UNDOCUMENTED_APIS mapping to indicate it is deliberately undocumented.") |
| 286 | + |
| 287 | + |
| 288 | +if __name__ == "__main__": |
| 289 | + absltest.main(testLoader=jtu.JaxTestLoader()) |
0 commit comments