Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions src/torchcodec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import builtins
import os
from pathlib import Path

# Note: usort wants to put Frame and FrameBatch after decoders and samplers,
Expand All @@ -25,3 +27,45 @@
# Similarly, these are exposed for downstream builds that use torchcodec as a
# dependency.
from ._core import core_library_path, ffmpeg_major_version # usort:skip


def _import_device_backends():
"""
Leverage the Python plugin mechanism to load out-of-the-tree device extensions.
"""
from importlib.metadata import entry_points

group_name = "torchcodec.backends"
backend_extensions = entry_points(group=group_name)

for backend_extension in backend_extensions:
try:
# Load the extension
entrypoint = backend_extension.load()
# Call the entrypoint
entrypoint()
except Exception as err:
raise RuntimeError(
f"Failed to load the backend extension: {backend_extension.name}. "
f"You can disable extension auto-loading with TORCHCODEC_DEVICE_BACKEND_AUTOLOAD=0."
) from err


def _is_device_backend_autoload_enabled() -> builtins.bool:
"""
Whether autoloading out-of-the-tree device extensions is enabled.
The switch depends on the value of the environment variable
`TORCHCODEC_DEVICE_BACKEND_AUTOLOAD`.

Returns:
bool: Whether to enable autoloading the extensions. Enabled by default.
"""
# enabled by default
return os.getenv("TORCHCODEC_DEVICE_BACKEND_AUTOLOAD", "1") == "1"


# `_import_device_backends` should be kept at the end to ensure
# all the other functions in this module that may be accessed by
# an autoloaded backend are defined
if _is_device_backend_autoload_enabled():
_import_device_backends()
12 changes: 12 additions & 0 deletions test/plugin/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "torchcodec-test-plugin"
description = "Test extension for torchcodec"
requires-python = ">=3.8"
dynamic = ["version"]

[project.entry-points.'torchcodec.backends']
device_backend = 'torchcodec_test_plugin:_autoload'
10 changes: 10 additions & 0 deletions test/plugin/torchcodec_test_plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""
This is a device backend extension used for testing.
"""

import os


def _autoload():
# Set the environment variable to true in this entrypoint
os.environ["IS_CUSTOM_DEVICE_BACKEND_IMPORTED"] = "1"
17 changes: 17 additions & 0 deletions test/test_autoload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os


def test_autoload():
switch = os.getenv("TORCHCODEC_DEVICE_BACKEND_AUTOLOAD", "0")

# After importing the test extension, the value of this environment variable should be true
is_imported = os.getenv("IS_CUSTOM_DEVICE_BACKEND_IMPORTED", "0")

# Both values should be equal
assert is_imported == switch
59 changes: 59 additions & 0 deletions test/test_plugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tre

import os
import subprocess
import sys
from pathlib import Path

import pytest


def _test_autoload(tmp_path, enable_autoload=True):
test_directory = Path(__file__).parent

# Build the test plugin
cmd = [
sys.executable,
"-m",
"pip",
"install",
"--root",
"./install",
test_directory / "plugin",
]
return_code = subprocess.run(cmd, cwd=tmp_path, env=os.environ)
assert return_code.returncode == 0

# "install" the test modules and run tests
python_path = os.environ.get("PYTHONPATH", "")
torchcodec_autoload = os.environ.get("PYTHONPATH", "")

try:
install_directory = ""

# install directory is the one that is named site-packages
for path in (tmp_path / "install").rglob("*"):
if path.is_dir() and "-packages" in path.name:
install_directory = str(path)

print(f">>>>> !!!! install_directory={install_directory}")
assert install_directory, "install_directory must not be empty"
os.environ["PYTHONPATH"] = os.pathsep.join([install_directory, python_path])
os.environ["TORCHCODEC_DEVICE_BACKEND_AUTOLOAD"] = str(int(enable_autoload))

cmd = [sys.executable, "-m", "pytest", "test_autoload.py"]
return_code = subprocess.run(cmd, cwd=Path(__file__).parent, env=os.environ)
assert return_code.returncode == 0
finally:
os.environ["PYTHONPATH"] = python_path
if torchcodec_autoload != "":
os.environ.pop("TORCHCODEC_DEVICE_BACKEND_AUTOLOAD")


@pytest.mark.parametrize("enable_autoload", [True, False])
def test_plugin_autoload(tmp_path, enable_autoload):
return _test_autoload(tmp_path, enable_autoload=enable_autoload)
Loading