diff --git a/oqs/oqs.py b/oqs/oqs.py index 3f79f58..9a24f0a 100644 --- a/oqs/oqs.py +++ b/oqs/oqs.py @@ -567,6 +567,11 @@ def get_supported_kem_mechanisms() -> tuple[str, ...]: return _supported_KEMs +# Register the OQS_SIG_supports_ctx_str function from the C library +native().OQS_SIG_supports_ctx_str.restype = ct.c_bool +native().OQS_SIG_supports_ctx_str.argtypes = [ct.c_char_p] + + class Signature(ct.Structure): """ An OQS Signature wraps native/C liboqs OQS_SIG structs. @@ -788,7 +793,7 @@ def verify_with_ctx_str( :param context: the context string. :param public_key: the signer's public key. """ - if context and not self._sig.contents.sig_with_ctx_support: + if context and not self.sig_with_ctx_support: msg = "Verifying with context string not supported" raise RuntimeError(msg) diff --git a/tests/test_sig.py b/tests/test_sig.py index b579e1a..185f6a2 100644 --- a/tests/test_sig.py +++ b/tests/test_sig.py @@ -2,7 +2,7 @@ import random import oqs -from oqs.oqs import Signature +from oqs.oqs import Signature, native # Sigs for which unit testing is disabled disabled_sig_patterns = [] @@ -44,6 +44,31 @@ def check_correctness_with_ctx_str(alg_name: str) -> None: assert sig.verify_with_ctx_str(message, signature, context, public_key) # noqa: S101 +def test_sig_with_ctx_support_detection() -> None: + """ + Test that sig_with_ctx_support matches the C API and that sign_with_ctx_str + raises on unsupported algorithms. + """ + for alg_name in oqs.get_enabled_sig_mechanisms(): + with Signature(alg_name) as sig: + # Check Python attribute matches C API + c_api_result = native().OQS_SIG_supports_ctx_str(sig.method_name) + assert bool(sig.sig_with_ctx_support) == bool(c_api_result), ( # noqa: S101 + f"sig_with_ctx_support mismatch for {alg_name}" + ) + # If not supported, sign_with_ctx_str should raise + if not sig.sig_with_ctx_support: + try: + sig.sign_with_ctx_str(b"msg", b"context") + except RuntimeError as e: + if "not supported" not in str(e): + msg = f"Unexpected exception message: {e}" + raise AssertionError(msg) from e + else: + msg = f"sign_with_ctx_str did not raise for {alg_name} without context support" + raise AssertionError(msg) + + def test_wrong_message() -> tuple[None, str]: for alg_name in oqs.get_enabled_sig_mechanisms(): if any(item in alg_name for item in disabled_sig_patterns):