diff --git a/src/torchcodec/__init__.py b/src/torchcodec/__init__.py index 144d3a67f..824e61102 100644 --- a/src/torchcodec/__init__.py +++ b/src/torchcodec/__init__.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import builtins +import os from pathlib import Path # Note: usort wants to put Frame and FrameBatch after decoders and samplers, @@ -25,3 +27,45 @@ # Similarly, these are exposed for downstream builds that use torchcodec as a # dependency. from ._core import core_library_path, ffmpeg_major_version # usort:skip + + +def _import_device_backends(): + """ + Leverage the Python plugin mechanism to load out-of-the-tree device extensions. + """ + from importlib.metadata import entry_points + + group_name = "torchcodec.backends" + backend_extensions = entry_points(group=group_name) + + for backend_extension in backend_extensions: + try: + # Load the extension + entrypoint = backend_extension.load() + # Call the entrypoint + entrypoint() + except Exception as err: + raise RuntimeError( + f"Failed to load the backend extension: {backend_extension.name}. " + f"You can disable extension auto-loading with TORCHCODEC_DEVICE_BACKEND_AUTOLOAD=0." + ) from err + + +def _is_device_backend_autoload_enabled() -> builtins.bool: + """ + Whether autoloading out-of-the-tree device extensions is enabled. + The switch depends on the value of the environment variable + `TORCHCODEC_DEVICE_BACKEND_AUTOLOAD`. + + Returns: + bool: Whether to enable autoloading the extensions. Enabled by default. + """ + # enabled by default + return os.getenv("TORCHCODEC_DEVICE_BACKEND_AUTOLOAD", "1") == "1" + + +# `_import_device_backends` should be kept at the end to ensure +# all the other functions in this module that may be accessed by +# an autoloaded backend are defined +if _is_device_backend_autoload_enabled(): + _import_device_backends() diff --git a/test/plugin/pyproject.toml b/test/plugin/pyproject.toml new file mode 100644 index 000000000..ec3c0375c --- /dev/null +++ b/test/plugin/pyproject.toml @@ -0,0 +1,12 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "torchcodec-test-plugin" +description = "Test extension for torchcodec" +requires-python = ">=3.8" +dynamic = ["version"] + +[project.entry-points.'torchcodec.backends'] +device_backend = 'torchcodec_test_plugin:_autoload' diff --git a/test/plugin/torchcodec_test_plugin/__init__.py b/test/plugin/torchcodec_test_plugin/__init__.py new file mode 100644 index 000000000..182e27f63 --- /dev/null +++ b/test/plugin/torchcodec_test_plugin/__init__.py @@ -0,0 +1,10 @@ +""" +This is a device backend extension used for testing. +""" + +import os + + +def _autoload(): + # Set the environment variable to true in this entrypoint + os.environ["IS_CUSTOM_DEVICE_BACKEND_IMPORTED"] = "1" diff --git a/test/test_autoload.py b/test/test_autoload.py new file mode 100644 index 000000000..28f14c3d2 --- /dev/null +++ b/test/test_autoload.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os + + +def test_autoload(): + switch = os.getenv("TORCHCODEC_DEVICE_BACKEND_AUTOLOAD", "0") + + # After importing the test extension, the value of this environment variable should be true + is_imported = os.getenv("IS_CUSTOM_DEVICE_BACKEND_IMPORTED", "0") + + # Both values should be equal + assert is_imported == switch diff --git a/test/test_plugins.py b/test/test_plugins.py new file mode 100644 index 000000000..e69cd10e7 --- /dev/null +++ b/test/test_plugins.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tre + +import os +import subprocess +import sys +from pathlib import Path + +import pytest + + +def _test_autoload(tmp_path, enable_autoload=True): + test_directory = Path(__file__).parent + + # Build the test plugin + cmd = [ + sys.executable, + "-m", + "pip", + "install", + "--root", + "./install", + test_directory / "plugin", + ] + return_code = subprocess.run(cmd, cwd=tmp_path, env=os.environ) + assert return_code.returncode == 0 + + # "install" the test modules and run tests + python_path = os.environ.get("PYTHONPATH", "") + torchcodec_autoload = os.environ.get("PYTHONPATH", "") + + try: + install_directory = "" + + # install directory is the one that is named site-packages + for path in (tmp_path / "install").rglob("*"): + if path.is_dir() and "-packages" in path.name: + install_directory = str(path) + + print(f">>>>> !!!! install_directory={install_directory}") + assert install_directory, "install_directory must not be empty" + os.environ["PYTHONPATH"] = os.pathsep.join([install_directory, python_path]) + os.environ["TORCHCODEC_DEVICE_BACKEND_AUTOLOAD"] = str(int(enable_autoload)) + + cmd = [sys.executable, "-m", "pytest", "test_autoload.py"] + return_code = subprocess.run(cmd, cwd=Path(__file__).parent, env=os.environ) + assert return_code.returncode == 0 + finally: + os.environ["PYTHONPATH"] = python_path + if torchcodec_autoload != "": + os.environ.pop("TORCHCODEC_DEVICE_BACKEND_AUTOLOAD") + + +@pytest.mark.parametrize("enable_autoload", [True, False]) +def test_plugin_autoload(tmp_path, enable_autoload): + return _test_autoload(tmp_path, enable_autoload=enable_autoload)