Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,6 @@ pyvenv.cfg

# vim
*.swp

# uv
/uv.lock
9 changes: 9 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
# Pre-release

- Added type checking and automatic linting/formatting, https://github.com/open-quantum-safe/liboqs-python/pull/97
- Added a utility function for de-structuring version strings in `oqs.py`
- `version(version_str: str) -> tuple[str, str, str]:` - Returns a tuple
containing the (major, minor, patch) versions
- A warning is issued only if the liboqs-python version's major and minor
numbers differ from those of liboqs, ignoring the patch version

# Version 0.12.0 - January 15, 2025

- Fixes https://github.com/open-quantum-safe/liboqs-python/issues/98. The API
Expand Down
86 changes: 61 additions & 25 deletions oqs/oqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(stdout))

# Expected return value from native OQS functions
OQS_SUCCESS: Final[int] = 0
OQS_ERROR: Final[int] = -1


def oqs_python_version() -> Union[str, None]:
"""liboqs-python version string."""
Expand All @@ -50,12 +54,14 @@ def oqs_python_version() -> Union[str, None]:
OQS_VERSION = oqs_python_version()


def _countdown(seconds: int) -> None:
while seconds > 0:
logger.info("Installing in %s seconds...", seconds)
stdout.flush()
seconds -= 1
time.sleep(1)
def version(version_str: str) -> tuple[str, str, str]:
parts = version_str.split(".")

major = parts[0] if len(parts) > 0 else ""
minor = parts[1] if len(parts) > 1 else ""
patch = parts[2] if len(parts) > 2 else ""

return major, minor, patch


def _load_shared_obj(
Expand Down Expand Up @@ -100,6 +106,14 @@ def _load_shared_obj(
raise RuntimeError(msg)


def _countdown(seconds: int) -> None:
while seconds > 0:
logger.info("Installing in %s seconds...", seconds)
stdout.flush()
seconds -= 1
time.sleep(1)


def _install_liboqs(
target_directory: Path,
oqs_version_to_install: Union[str, None] = None,
Expand Down Expand Up @@ -188,7 +202,9 @@ def _load_liboqs() -> ct.CDLL:
assert liboqs # noqa: S101
except RuntimeError:
# We don't have liboqs, so we try to install it automatically
_install_liboqs(target_directory=oqs_install_dir, oqs_version_to_install=OQS_VERSION)
_install_liboqs(
target_directory=oqs_install_dir, oqs_version_to_install=OQS_VERSION
)
# Try loading it again
try:
liboqs = _load_shared_obj(
Expand All @@ -206,11 +222,6 @@ def _load_liboqs() -> ct.CDLL:
_liboqs = _load_liboqs()


# Expected return value from native OQS functions
OQS_SUCCESS: Final[int] = 0
OQS_ERROR: Final[int] = -1


def native() -> ct.CDLL:
"""Handle to native liboqs handler."""
return _liboqs
Expand All @@ -226,13 +237,24 @@ def oqs_version() -> str:
return ct.c_char_p(native().OQS_version()).value.decode("UTF-8") # type: ignore[union-attr]


# Warn the user if the liboqs version differs from liboqs-python version
if oqs_version() != oqs_python_version():
warnings.warn(
f"liboqs version {oqs_version()} differs from liboqs-python version "
f"{oqs_python_version()}",
stacklevel=2,
oqs_ver = oqs_version()
oqs_ver_major, oqs_ver_minor, oqs_ver_patch = version(oqs_ver)


oqs_python_ver = oqs_python_version()
if oqs_python_ver:
oqs_python_ver_major, oqs_python_ver_minor, oqs_python_ver_patch = version(
oqs_python_ver
)
# Warn the user if the liboqs version differs from liboqs-python version
if not (
oqs_ver_major == oqs_python_ver_major and oqs_ver_minor == oqs_python_ver_minor
):
warnings.warn(
f"liboqs version (major, minor) {oqs_version()} differs from liboqs-python version "
f"{oqs_python_version()}",
stacklevel=2,
)


class MechanismNotSupportedError(Exception):
Expand Down Expand Up @@ -281,7 +303,9 @@ class KeyEncapsulation(ct.Structure):
("decaps_cb", ct.c_void_p),
]

def __init__(self, alg_name: str, secret_key: Union[int, bytes, None] = None) -> None:
def __init__(
self, alg_name: str, secret_key: Union[int, bytes, None] = None
) -> None:
"""
Create new KeyEncapsulation with the given algorithm.

Expand Down Expand Up @@ -435,9 +459,15 @@ def is_kem_enabled(alg_name: str) -> bool:
return native().OQS_KEM_alg_is_enabled(ct.create_string_buffer(alg_name.encode()))


_KEM_alg_ids = [native().OQS_KEM_alg_identifier(i) for i in range(native().OQS_KEM_alg_count())]
_supported_KEMs: tuple[str, ...] = tuple([i.decode() for i in _KEM_alg_ids]) # noqa: N816
_enabled_KEMs: tuple[str, ...] = tuple([i for i in _supported_KEMs if is_kem_enabled(i)]) # noqa: N816
_KEM_alg_ids = [
native().OQS_KEM_alg_identifier(i) for i in range(native().OQS_KEM_alg_count())
]
_supported_KEMs: tuple[str, ...] = tuple(
[i.decode() for i in _KEM_alg_ids]
) # noqa: N816
_enabled_KEMs: tuple[str, ...] = tuple(
[i for i in _supported_KEMs if is_kem_enabled(i)]
) # noqa: N816


def get_enabled_kem_mechanisms() -> tuple[str, ...]:
Expand Down Expand Up @@ -478,7 +508,9 @@ class Signature(ct.Structure):
("verify_cb", ct.c_void_p),
]

def __init__(self, alg_name: str, secret_key: Union[int, bytes, None] = None) -> None:
def __init__(
self, alg_name: str, secret_key: Union[int, bytes, None] = None
) -> None:
"""
Create new Signature with the given algorithm.

Expand Down Expand Up @@ -723,9 +755,13 @@ def is_sig_enabled(alg_name: str) -> bool:
return native().OQS_SIG_alg_is_enabled(ct.create_string_buffer(alg_name.encode()))


_sig_alg_ids = [native().OQS_SIG_alg_identifier(i) for i in range(native().OQS_SIG_alg_count())]
_sig_alg_ids = [
native().OQS_SIG_alg_identifier(i) for i in range(native().OQS_SIG_alg_count())
]
_supported_sigs: tuple[str, ...] = tuple([i.decode() for i in _sig_alg_ids])
_enabled_sigs: tuple[str, ...] = tuple([i for i in _supported_sigs if is_sig_enabled(i)])
_enabled_sigs: tuple[str, ...] = tuple(
[i for i in _supported_sigs if is_sig_enabled(i)]
)


def get_enabled_sig_mechanisms() -> tuple[str, ...]:
Expand Down