diff --git a/framework/py/flwr/supercore/auth/__init__.py b/framework/py/flwr/supercore/auth/__init__.py new file mode 100644 index 000000000000..73bc0a369a5a --- /dev/null +++ b/framework/py/flwr/supercore/auth/__init__.py @@ -0,0 +1,65 @@ +# Copyright 2025 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Transport-agnostic authentication primitives for AppIo services.""" + + +from .appio_auth import ( + AuthDecision, + AuthDecisionEngine, + AuthDecisionFailureReason, + Authenticator, + AuthInput, + CallerIdentity, + SignedMetadataAuthInput, + TokenAuthenticator, +) +from .constant import ( + APP_TOKEN_HEADER, + APPIO_SIGNED_METADATA_METHOD_HEADER, + APPIO_SIGNED_METADATA_PLUGIN_TYPE_HEADER, + APPIO_SIGNED_METADATA_PUBLIC_KEY_HEADER, + APPIO_SIGNED_METADATA_SIGNATURE_HEADER, + APPIO_SIGNED_METADATA_TIMESTAMP_HEADER, + AUTH_MECHANISM_SUPEREXEC_SIGNED_METADATA, + AUTH_MECHANISM_TOKEN, + AUTHENTICATION_FAILED_MESSAGE, + CALLER_TYPE_APP_EXECUTOR, + CALLER_TYPE_SUPEREXEC, +) +from .policy import MethodAuthPolicy, validate_method_auth_policy_map + +__all__ = [ + "APPIO_SIGNED_METADATA_METHOD_HEADER", + "APPIO_SIGNED_METADATA_PLUGIN_TYPE_HEADER", + "APPIO_SIGNED_METADATA_PUBLIC_KEY_HEADER", + "APPIO_SIGNED_METADATA_SIGNATURE_HEADER", + "APPIO_SIGNED_METADATA_TIMESTAMP_HEADER", + "APP_TOKEN_HEADER", + "AUTHENTICATION_FAILED_MESSAGE", + "AUTH_MECHANISM_SUPEREXEC_SIGNED_METADATA", + "AUTH_MECHANISM_TOKEN", + "AuthDecision", + "AuthDecisionEngine", + "AuthDecisionFailureReason", + "AuthInput", + "Authenticator", + "CALLER_TYPE_APP_EXECUTOR", + "CALLER_TYPE_SUPEREXEC", + "CallerIdentity", + "MethodAuthPolicy", + "SignedMetadataAuthInput", + "TokenAuthenticator", + "validate_method_auth_policy_map", +] diff --git a/framework/py/flwr/supercore/auth/appio_auth.py b/framework/py/flwr/supercore/auth/appio_auth.py new file mode 100644 index 000000000000..82ed0a421284 --- /dev/null +++ b/framework/py/flwr/supercore/auth/appio_auth.py @@ -0,0 +1,280 @@ +# Copyright 2025 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Transport-agnostic authentication primitives used by AppIo adapters.""" + +from collections.abc import Callable, Mapping +from dataclasses import dataclass +from enum import Enum +from typing import Protocol + +from .constant import ( + AUTH_MECHANISM_SUPEREXEC_SIGNED_METADATA, + AUTH_MECHANISM_TOKEN, + CALLER_TYPE_APP_EXECUTOR, +) +from .policy import MethodAuthPolicy + + +@dataclass(frozen=True) +class SignedMetadataAuthInput: + """Signed metadata payload extracted from request metadata. + + This is transport-normalized input only. Signature/timestamp verification remains in + authenticator implementations so policy logic stays mechanism-agnostic. This payload + is raw metadata and is not pre-verified at extraction time. + """ + + # Caller's public key from metadata, used for key identity + signature verify. + public_key: bytes + # Signature over the expected payload (for example, timestamp + method). + signature: bytes + # Caller-provided ISO timestamp used for freshness and replay-window checks. + timestamp_iso: str + # RPC method name bound into signature payload to prevent cross-method replay. + method: str + # Expected SuperExec plugin scope (for allowlist/policy checks). + plugin_type: str | None = None + + +@dataclass(frozen=True) +class AuthInput: + """Authentication data extracted from a transport-specific request. + + ``AuthInput`` is the single handoff object from transport adapters to the auth + layer. Keeping all optional inputs here makes it easy to add mechanisms + without changing policy or interceptor call signatures. + """ + + token: str | None = None + # True means signed-metadata auth material was supplied on the request path. + # This can be True while ``signed_metadata`` is None when extraction sees a + # partial/malformed signed-metadata payload. + signed_metadata_present: bool = False + signed_metadata: SignedMetadataAuthInput | None = None + + def __post_init__(self) -> None: + """Validate signed metadata presence invariants.""" + if self.signed_metadata is not None and not self.signed_metadata_present: + raise ValueError( + "signed_metadata_present must be True when signed_metadata is set." + ) + + +@dataclass(frozen=True) +class CallerIdentity: + """Normalized authenticated caller identity. + + This shape supports both app-executor and SuperExec callers. Fields are + intentionally optional so one identity type can represent multiple auth mechanisms. + """ + + # Auth mechanism that produced this identity (token, signed-metadata, ...). + mechanism: str + # Normalized caller category (for example, app_executor or superexec). + caller_type: str + # Authenticated run binding when applicable; None for non-run-bound callers. + run_id: int | None = None + # Stable key identifier for key-based callers; None for non-key mechanisms. + key_fingerprint: str | None = None + + +@dataclass(frozen=True) +class AuthDecision: + """Result of evaluating an ``AuthInput`` against a method policy. + + ``failure_reason`` is internal-only for tests/diagnostics. Interceptors still + map denials to canonical external responses (for example, PERMISSION_DENIED). + """ + + is_allowed: bool + caller_identity: CallerIdentity | None + failure_reason: "AuthDecisionFailureReason | None" = None + + +class AuthDecisionFailureReason(Enum): + """Internal reasons for auth denials.""" + + MISSING_AUTH_INPUT = "missing_auth_input" + INVALID_AUTH_INPUT = "invalid_auth_input" + NON_REQUIRED_MECHANISM_PRESENT = "non_required_mechanism_present" + POLICY_MISCONFIGURED = "policy_misconfigured" + + +class Authenticator(Protocol): + """Authentication primitive for one mechanism.""" + + mechanism: str + + def is_present(self, auth_input: AuthInput) -> bool: + """Return whether this mechanism's auth input is present.""" + + def authenticate(self, auth_input: AuthInput) -> CallerIdentity | None: + """Return caller identity if authentication succeeds.""" + + +class AuthDecisionEngine: + """Evaluate method policy against available authenticators. + + The engine is transport-independent and currently enforces one configured mechanism + per RPC. It decides mechanism compatibility and delegates cryptographic/token checks + to authenticators. + """ + + def __init__( + self, + authenticators: Mapping[str, Authenticator], + method_auth_policies: Mapping[str, MethodAuthPolicy], + ) -> None: + """Initialize decision engine and validate startup policy/authenticator shape. + + ``method_auth_policies`` is used for fail-fast startup validation only. + Runtime evaluation remains per-call via ``evaluate(policy, auth_input)``. + """ + self._authenticators = authenticators + # Validate at construction to fail fast on startup configuration bugs. + self._validate_policy_mechanisms(method_auth_policies) + + def _validate_policy_mechanisms( + self, method_auth_policies: Mapping[str, MethodAuthPolicy] + ) -> None: + """Fail fast if policy references unknown mechanisms.""" + invalid_policy_values: list[str] = [] + missing_by_method: dict[str, list[str]] = {} + for method, policy in method_auth_policies.items(): + if not isinstance(policy, MethodAuthPolicy): + invalid_policy_values.append(method) + continue + required_mechanism = policy.required_mechanism + if ( + required_mechanism is not None + and required_mechanism not in self._authenticators + ): + missing_by_method[method] = [required_mechanism] + if invalid_policy_values or missing_by_method: + invalid_entries = invalid_policy_values or "None" + missing_entries = missing_by_method or "None" + raise ValueError( + "Invalid method auth policies for AuthDecisionEngine.\n" + f"Entries with invalid policy objects: {invalid_entries}\n" + "Entries referencing mechanisms without authenticators: " + f"{missing_entries}" + ) + + @staticmethod + def _present_mechanisms_from_auth_input(auth_input: AuthInput) -> set[str]: + """Return present mechanisms based directly on normalized auth input.""" + present_mechanisms: set[str] = set() + if auth_input.token is not None: + present_mechanisms.add(AUTH_MECHANISM_TOKEN) + if auth_input.signed_metadata_present: + present_mechanisms.add(AUTH_MECHANISM_SUPEREXEC_SIGNED_METADATA) + return present_mechanisms + + def evaluate(self, policy: MethodAuthPolicy, auth_input: AuthInput) -> AuthDecision: + """Evaluate authentication for a single method invocation.""" + if not policy.requires_authentication: + return AuthDecision(is_allowed=True, caller_identity=None) + + required_mechanism = policy.required_mechanism + required_authenticator = ( + None + if required_mechanism is None + else self._authenticators.get(required_mechanism) + ) + failure_reason: AuthDecisionFailureReason | None = None + + if required_mechanism is None or required_authenticator is None: + # Defensive fallback for malformed policies and runtime safety. + failure_reason = AuthDecisionFailureReason.POLICY_MISCONFIGURED + else: + present_mechanisms = self._present_mechanisms_from_auth_input(auth_input) + + # Check if the required mechanism is missing. + if required_mechanism not in present_mechanisms: + failure_reason = AuthDecisionFailureReason.MISSING_AUTH_INPUT + # Check if any present mechanism is not the one required by policy. + # This is explicitly denied to keep one-mechanism-per-RPC semantics. + elif any( + mechanism != required_mechanism for mechanism in present_mechanisms + ): + failure_reason = ( + AuthDecisionFailureReason.NON_REQUIRED_MECHANISM_PRESENT + ) + # If required mechanism is present and no extra mechanism is present, + # attempt authentication with that mechanism. + else: + caller_identity = required_authenticator.authenticate(auth_input) + if caller_identity is None: + failure_reason = AuthDecisionFailureReason.INVALID_AUTH_INPUT + elif caller_identity.mechanism != required_mechanism: + # Defensive check: authenticator returned identity inconsistent + # with the mechanism required by policy. + failure_reason = AuthDecisionFailureReason.POLICY_MISCONFIGURED + else: + return AuthDecision( + is_allowed=True, + caller_identity=caller_identity, + failure_reason=None, + ) + + return AuthDecision( + is_allowed=False, + caller_identity=None, + failure_reason=failure_reason, + ) + + +class _TokenState(Protocol): + """State methods required for token authentication.""" + + def get_run_id_by_token(self, token: str) -> int | None: + """Return run_id for token or None.""" + + def verify_token(self, run_id: int, token: str) -> bool: + """Return whether token is valid for run_id.""" + + +class TokenAuthenticator: + """Token-based authenticator for AppIo callers. + + This is one concrete mechanism implementation registered into the decision engine. + Future SuperExec signed-metadata auth should follow the same pattern. + """ + + mechanism = AUTH_MECHANISM_TOKEN + + def __init__(self, state_provider: Callable[[], _TokenState]) -> None: + self._state_provider = state_provider + + def is_present(self, auth_input: AuthInput) -> bool: + """Return whether token auth input is present.""" + return auth_input.token is not None + + def authenticate(self, auth_input: AuthInput) -> CallerIdentity | None: + """Authenticate caller using AppIo token.""" + token = auth_input.token + if token is None: + return None + + state = self._state_provider() + run_id = state.get_run_id_by_token(token) + if run_id is None or not state.verify_token(run_id, token): + return None + + return CallerIdentity( + mechanism=self.mechanism, + caller_type=CALLER_TYPE_APP_EXECUTOR, + run_id=run_id, + ) diff --git a/framework/py/flwr/supercore/auth/appio_auth_test.py b/framework/py/flwr/supercore/auth/appio_auth_test.py new file mode 100644 index 000000000000..b2e90c2681fe --- /dev/null +++ b/framework/py/flwr/supercore/auth/appio_auth_test.py @@ -0,0 +1,512 @@ +# Copyright 2025 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for transport-agnostic AppIo auth primitives and policy logic.""" + +from typing import cast +from unittest import TestCase +from unittest.mock import Mock + +from flwr.supercore.auth.appio_auth import ( + AuthDecisionEngine, + AuthDecisionFailureReason, + AuthInput, + CallerIdentity, + SignedMetadataAuthInput, + TokenAuthenticator, +) +from flwr.supercore.auth.constant import ( + AUTH_MECHANISM_SUPEREXEC_SIGNED_METADATA, + AUTH_MECHANISM_TOKEN, + CALLER_TYPE_APP_EXECUTOR, + CALLER_TYPE_SUPEREXEC, +) +from flwr.supercore.auth.policy import MethodAuthPolicy, validate_method_auth_policy_map + + +class _SignedMetadataPresenceAuthenticator: + """Test authenticator using signed metadata presence only.""" + + mechanism = AUTH_MECHANISM_SUPEREXEC_SIGNED_METADATA + + def is_present(self, auth_input: AuthInput) -> bool: + """Return whether signed metadata auth input is present.""" + return auth_input.signed_metadata_present + + def authenticate(self, auth_input: AuthInput) -> CallerIdentity | None: + """Return synthetic identity when signed metadata is fully populated.""" + if auth_input.signed_metadata is None: + return None + return CallerIdentity( + mechanism=self.mechanism, + caller_type=CALLER_TYPE_SUPEREXEC, + key_fingerprint="test-fingerprint", + ) + + +class _BadTokenAuthenticator: + """Test authenticator that returns inconsistent mechanism identity.""" + + mechanism = AUTH_MECHANISM_TOKEN + + def is_present(self, auth_input: AuthInput) -> bool: + """Return whether token auth input is present.""" + return auth_input.token is not None + + def authenticate(self, auth_input: AuthInput) -> CallerIdentity | None: + """Return mismatched identity to simulate authenticator bug.""" + if auth_input.token is None: + return None + return CallerIdentity( + mechanism=AUTH_MECHANISM_SUPEREXEC_SIGNED_METADATA, + caller_type=CALLER_TYPE_APP_EXECUTOR, + run_id=1, + ) + + +class TestAuthDecisionEngine(TestCase): + """Unit tests for ``AuthDecisionEngine``.""" + + def test_no_auth_policy_always_allows(self) -> None: + """Methods with no auth policy are always allowed.""" + engine = AuthDecisionEngine(authenticators={}, method_auth_policies={}) + + decision = engine.evaluate( + policy=MethodAuthPolicy.no_auth(), + auth_input=AuthInput(token=None), + ) + + self.assertTrue(decision.is_allowed) + self.assertIsNone(decision.caller_identity) + self.assertIsNone(decision.failure_reason) + + def test_token_policy_allows_with_matching_authenticator(self) -> None: + """Token policy succeeds when token authenticator yields identity.""" + state = Mock() + state.get_run_id_by_token.return_value = 13 + state.verify_token.return_value = True + engine = AuthDecisionEngine( + authenticators={AUTH_MECHANISM_TOKEN: TokenAuthenticator(lambda: state)}, + method_auth_policies={}, + ) + + decision = engine.evaluate( + policy=MethodAuthPolicy.token_required(), + auth_input=AuthInput(token="valid-token"), + ) + + self.assertTrue(decision.is_allowed) + self.assertEqual( + decision.caller_identity, + CallerIdentity( + mechanism=AUTH_MECHANISM_TOKEN, + caller_type=CALLER_TYPE_APP_EXECUTOR, + run_id=13, + ), + ) + self.assertIsNone(decision.failure_reason) + + def test_engine_fails_fast_when_policy_references_missing_authenticator( + self, + ) -> None: + """Construction fails if any required mechanism has no authenticator.""" + with self.assertRaisesRegex( + ValueError, "Entries referencing mechanisms without authenticators" + ): + AuthDecisionEngine( + authenticators={}, + method_auth_policies={ + "/flwr.proto.ServerAppIo/GetNodes": ( + MethodAuthPolicy.token_required() + ) + }, + ) + + def test_engine_fails_fast_when_policy_value_is_wrong_type(self) -> None: + """Construction fails with actionable error for non-policy values.""" + with self.assertRaisesRegex(ValueError, "Entries with invalid policy objects"): + AuthDecisionEngine( + authenticators={}, + method_auth_policies=cast( + dict[str, MethodAuthPolicy], + {"/flwr.proto.ServerAppIo/GetNodes": "bad-policy"}, + ), + ) + + def test_token_policy_denies_when_authenticator_missing(self) -> None: + """Policy requiring an unavailable authenticator is denied.""" + engine = AuthDecisionEngine(authenticators={}, method_auth_policies={}) + + decision = engine.evaluate( + policy=MethodAuthPolicy.token_required(), + auth_input=AuthInput(token="token"), + ) + + self.assertFalse(decision.is_allowed) + self.assertIsNone(decision.caller_identity) + self.assertEqual( + decision.failure_reason, AuthDecisionFailureReason.POLICY_MISCONFIGURED + ) + + def test_token_policy_denies_when_token_missing(self) -> None: + """Token policy requires token input to be present.""" + state = Mock() + engine = AuthDecisionEngine( + authenticators={AUTH_MECHANISM_TOKEN: TokenAuthenticator(lambda: state)}, + method_auth_policies={}, + ) + + decision = engine.evaluate( + policy=MethodAuthPolicy.token_required(), + auth_input=AuthInput(token=None), + ) + + self.assertFalse(decision.is_allowed) + self.assertIsNone(decision.caller_identity) + self.assertEqual( + decision.failure_reason, AuthDecisionFailureReason.MISSING_AUTH_INPUT + ) + + def test_token_policy_denies_when_authenticator_returns_wrong_mechanism( + self, + ) -> None: + """Mismatched caller mechanism from authenticator is misconfiguration.""" + engine = AuthDecisionEngine( + authenticators={AUTH_MECHANISM_TOKEN: _BadTokenAuthenticator()}, + method_auth_policies={}, + ) + + decision = engine.evaluate( + policy=MethodAuthPolicy.token_required(), + auth_input=AuthInput(token="token"), + ) + + self.assertFalse(decision.is_allowed) + self.assertIsNone(decision.caller_identity) + self.assertEqual( + decision.failure_reason, AuthDecisionFailureReason.POLICY_MISCONFIGURED + ) + + def test_token_required_denies_when_extra_signed_metadata_is_present(self) -> None: + """Required token auth denies when another mechanism is simultaneously + present.""" + state = Mock() + state.get_run_id_by_token.return_value = 1 + state.verify_token.return_value = True + engine = AuthDecisionEngine( + authenticators={ + AUTH_MECHANISM_TOKEN: TokenAuthenticator(lambda: state), + AUTH_MECHANISM_SUPEREXEC_SIGNED_METADATA: ( + _SignedMetadataPresenceAuthenticator() + ), + }, + method_auth_policies={}, + ) + + decision = engine.evaluate( + policy=MethodAuthPolicy.token_required(), + auth_input=AuthInput( + token="valid-token", + signed_metadata=SignedMetadataAuthInput( + public_key=b"pk", + signature=b"sig", + timestamp_iso="2026-03-09T10:00:00", + method="/flwr.proto.ServerAppIo/GetNodes", + ), + signed_metadata_present=True, + ), + ) + + self.assertFalse(decision.is_allowed) + self.assertIsNone(decision.caller_identity) + self.assertEqual( + decision.failure_reason, + AuthDecisionFailureReason.NON_REQUIRED_MECHANISM_PRESENT, + ) + + def test_signed_metadata_required_denies_when_extra_token_is_present(self) -> None: + """Required signed-metadata auth denies when token input is also present.""" + state = Mock() + engine = AuthDecisionEngine( + authenticators={ + AUTH_MECHANISM_TOKEN: TokenAuthenticator(lambda: state), + AUTH_MECHANISM_SUPEREXEC_SIGNED_METADATA: ( + _SignedMetadataPresenceAuthenticator() + ), + }, + method_auth_policies={}, + ) + + decision = engine.evaluate( + policy=MethodAuthPolicy.signed_metadata_required(), + auth_input=AuthInput( + token="some-token", + signed_metadata=SignedMetadataAuthInput( + public_key=b"pk", + signature=b"sig", + timestamp_iso="2026-03-09T10:00:00", + method="/flwr.proto.ServerAppIo/GetNodes", + ), + signed_metadata_present=True, + ), + ) + + self.assertFalse(decision.is_allowed) + self.assertIsNone(decision.caller_identity) + self.assertEqual( + decision.failure_reason, + AuthDecisionFailureReason.NON_REQUIRED_MECHANISM_PRESENT, + ) + + def test_token_required_denies_extra_signed_metadata_without_authenticator( + self, + ) -> None: + """Extra signed-metadata input is denied even without its authenticator.""" + state = Mock() + state.get_run_id_by_token.return_value = 1 + state.verify_token.return_value = True + engine = AuthDecisionEngine( + authenticators={AUTH_MECHANISM_TOKEN: TokenAuthenticator(lambda: state)}, + method_auth_policies={}, + ) + + decision = engine.evaluate( + policy=MethodAuthPolicy.token_required(), + auth_input=AuthInput( + token="valid-token", + signed_metadata=SignedMetadataAuthInput( + public_key=b"pk", + signature=b"sig", + timestamp_iso="2026-03-09T10:00:00", + method="/flwr.proto.ServerAppIo/GetNodes", + ), + signed_metadata_present=True, + ), + ) + + self.assertFalse(decision.is_allowed) + self.assertIsNone(decision.caller_identity) + self.assertEqual( + decision.failure_reason, + AuthDecisionFailureReason.NON_REQUIRED_MECHANISM_PRESENT, + ) + state.get_run_id_by_token.assert_not_called() + state.verify_token.assert_not_called() + + def test_signed_metadata_required_denies_extra_token_without_authenticator( + self, + ) -> None: + """Extra token input is denied even without a token authenticator.""" + engine = AuthDecisionEngine( + authenticators={ + AUTH_MECHANISM_SUPEREXEC_SIGNED_METADATA: ( + _SignedMetadataPresenceAuthenticator() + ) + }, + method_auth_policies={}, + ) + + decision = engine.evaluate( + policy=MethodAuthPolicy.signed_metadata_required(), + auth_input=AuthInput( + token="token", + signed_metadata=SignedMetadataAuthInput( + public_key=b"pk", + signature=b"sig", + timestamp_iso="2026-03-09T10:00:00", + method="/flwr.proto.ServerAppIo/GetNodes", + ), + signed_metadata_present=True, + ), + ) + + self.assertFalse(decision.is_allowed) + self.assertIsNone(decision.caller_identity) + self.assertEqual( + decision.failure_reason, + AuthDecisionFailureReason.NON_REQUIRED_MECHANISM_PRESENT, + ) + + def test_signed_metadata_required_denies_when_input_is_malformed(self) -> None: + """Malformed signed metadata is invalid input, not missing input.""" + engine = AuthDecisionEngine( + authenticators={ + AUTH_MECHANISM_SUPEREXEC_SIGNED_METADATA: ( + _SignedMetadataPresenceAuthenticator() + ) + }, + method_auth_policies={}, + ) + + decision = engine.evaluate( + policy=MethodAuthPolicy.signed_metadata_required(), + auth_input=AuthInput(signed_metadata_present=True, signed_metadata=None), + ) + + self.assertFalse(decision.is_allowed) + self.assertIsNone(decision.caller_identity) + self.assertEqual( + decision.failure_reason, + AuthDecisionFailureReason.INVALID_AUTH_INPUT, + ) + + +class TestTokenAuthenticator(TestCase): + """Unit tests for ``TokenAuthenticator``.""" + + def test_missing_token_is_denied(self) -> None: + """No token in input returns no identity.""" + state = Mock() + authenticator = TokenAuthenticator(lambda: state) + + self.assertFalse(authenticator.is_present(AuthInput(token=None))) + caller_identity = authenticator.authenticate(AuthInput(token=None)) + + self.assertIsNone(caller_identity) + state.get_run_id_by_token.assert_not_called() + + def test_invalid_token_is_denied(self) -> None: + """Unknown token returns no identity.""" + state = Mock() + state.get_run_id_by_token.return_value = None + authenticator = TokenAuthenticator(lambda: state) + + self.assertTrue(authenticator.is_present(AuthInput(token="invalid-token"))) + caller_identity = authenticator.authenticate(AuthInput(token="invalid-token")) + + self.assertIsNone(caller_identity) + state.get_run_id_by_token.assert_called_once_with("invalid-token") + state.verify_token.assert_not_called() + + def test_valid_token_returns_identity(self) -> None: + """Valid token returns a normalized caller identity.""" + state = Mock() + state.get_run_id_by_token.return_value = 42 + state.verify_token.return_value = True + authenticator = TokenAuthenticator(lambda: state) + + caller_identity = authenticator.authenticate(AuthInput(token="valid-token")) + + self.assertEqual( + caller_identity, + CallerIdentity( + mechanism=AUTH_MECHANISM_TOKEN, + caller_type=CALLER_TYPE_APP_EXECUTOR, + run_id=42, + ), + ) + state.get_run_id_by_token.assert_called_once_with("valid-token") + state.verify_token.assert_called_once_with(42, "valid-token") + + +class TestSignedMetadataPresence(TestCase): + """Unit tests for signed metadata presence detection hooks.""" + + def test_signed_metadata_presence_detected(self) -> None: + """Signed metadata input should be detectable by a dedicated authenticator.""" + authenticator = _SignedMetadataPresenceAuthenticator() + + self.assertTrue( + authenticator.is_present( + AuthInput( + signed_metadata_present=True, + signed_metadata=SignedMetadataAuthInput( + public_key=b"pk", + signature=b"sig", + timestamp_iso="2026-03-09T10:00:00", + method="/flwr.proto.ServerAppIo/GetNodes", + ), + ) + ) + ) + + def test_signed_metadata_absence_detected(self) -> None: + """Missing signed metadata should not appear as present.""" + authenticator = _SignedMetadataPresenceAuthenticator() + + self.assertFalse( + authenticator.is_present(AuthInput(signed_metadata_present=False)) + ) + + def test_signed_metadata_partial_payload_detected_as_present(self) -> None: + """Presence flag distinguishes malformed input from a missing mechanism.""" + authenticator = _SignedMetadataPresenceAuthenticator() + + self.assertTrue( + authenticator.is_present(AuthInput(signed_metadata_present=True)) + ) + + +class TestAuthInputInvariant(TestCase): + """Unit tests for ``AuthInput`` invariants.""" + + def test_signed_metadata_requires_presence_flag(self) -> None: + """Setting signed metadata without presence flag should fail.""" + with self.assertRaisesRegex(ValueError, "signed_metadata_present must be True"): + AuthInput( + signed_metadata=SignedMetadataAuthInput( + public_key=b"pk", + signature=b"sig", + timestamp_iso="2026-03-09T10:00:00", + method="/flwr.proto.ServerAppIo/GetNodes", + ) + ) + + +class TestMethodAuthPolicyValidation(TestCase): + """Unit tests for method policy table validation.""" + + def test_method_auth_policy_rejects_run_match_without_mechanism(self) -> None: + """Run-id-match cannot be enabled when no mechanism is required.""" + with self.assertRaisesRegex( + ValueError, + "requires_run_id_match=True requires a non-None required_mechanism.", + ): + MethodAuthPolicy(required_mechanism=None, requires_run_id_match=True) + + def test_validate_method_auth_policy_map_accepts_matching_table(self) -> None: + """Validation passes when table exactly matches service RPC names.""" + table = { + "/flwr.proto.TestService/Foo": MethodAuthPolicy.no_auth(), + "/flwr.proto.TestService/Bar": MethodAuthPolicy.token_required(), + } + + validate_method_auth_policy_map( + service_name="TestService", + package_name="flwr.proto", + rpc_method_names=("Foo", "Bar"), + method_auth_policy=table, + table_name="TEST_POLICY", + table_location=__file__, + ) + + def test_validate_method_auth_policy_map_rejects_invalid_table(self) -> None: + """Validation fails when policy table misses or mis-types entries.""" + with self.assertRaisesRegex( + ValueError, "Invalid AppIo method auth policy table." + ): + validate_method_auth_policy_map( + service_name="TestService", + package_name="flwr.proto", + rpc_method_names=("Foo",), + method_auth_policy=cast( + dict[str, MethodAuthPolicy], + { + "/flwr.proto.TestService/Bar": MethodAuthPolicy.no_auth(), + "/flwr.proto.TestService/Foo": "bad-policy", + }, + ), + table_name="TEST_POLICY", + table_location=__file__, + ) diff --git a/framework/py/flwr/supercore/auth/constant.py b/framework/py/flwr/supercore/auth/constant.py new file mode 100644 index 000000000000..e9110de1a7e8 --- /dev/null +++ b/framework/py/flwr/supercore/auth/constant.py @@ -0,0 +1,31 @@ +# Copyright 2025 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Auth-related constants shared by AppIo auth abstractions.""" + +from typing import Final + +AUTH_MECHANISM_TOKEN: Final[str] = "token" +AUTH_MECHANISM_SUPEREXEC_SIGNED_METADATA: Final[str] = "superexec-signed-metadata" +CALLER_TYPE_APP_EXECUTOR: Final[str] = "app_executor" +CALLER_TYPE_SUPEREXEC: Final[str] = "superexec" +AUTHENTICATION_FAILED_MESSAGE: Final[str] = "Authentication failed." + +# gRPC metadata keys for AppIo auth input extraction (token-based and signed-metadata). +APP_TOKEN_HEADER: Final[str] = "flwr-app-token" +APPIO_SIGNED_METADATA_PUBLIC_KEY_HEADER: Final[str] = "flwr-superexec-public-key-bin" +APPIO_SIGNED_METADATA_SIGNATURE_HEADER: Final[str] = "flwr-superexec-signature-bin" +APPIO_SIGNED_METADATA_TIMESTAMP_HEADER: Final[str] = "flwr-superexec-timestamp" +APPIO_SIGNED_METADATA_METHOD_HEADER: Final[str] = "flwr-superexec-method" +APPIO_SIGNED_METADATA_PLUGIN_TYPE_HEADER: Final[str] = "flwr-superexec-plugin-type" diff --git a/framework/py/flwr/supercore/auth/policy.py b/framework/py/flwr/supercore/auth/policy.py new file mode 100644 index 000000000000..64440cac2e66 --- /dev/null +++ b/framework/py/flwr/supercore/auth/policy.py @@ -0,0 +1,109 @@ +# Copyright 2025 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Policy types and validation helpers for AppIo authentication.""" + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass + +from .constant import AUTH_MECHANISM_SUPEREXEC_SIGNED_METADATA, AUTH_MECHANISM_TOKEN + + +@dataclass(frozen=True) +class MethodAuthPolicy: + """Authentication policy for a single RPC method. + + Policy is intentionally separate from mechanism implementations. This keeps RPC + access decisions declarative and lets mechanisms evolve independently. + + This policy currently supports at most one mechanism per RPC. + `requires_run_id_match` is enforced in transport/handler layers that + have request context (for example, gRPC interceptors/servicers), not by + `AuthDecisionEngine`. + """ + + required_mechanism: str | None = None + requires_run_id_match: bool = False + + def __post_init__(self) -> None: + """Validate cross-field invariants.""" + if self.required_mechanism is None and self.requires_run_id_match: + raise ValueError( + "requires_run_id_match=True requires a non-None required_mechanism." + ) + + @property + def requires_authentication(self) -> bool: + """Return whether authentication is required for this method.""" + return self.required_mechanism is not None + + @classmethod + def no_auth(cls) -> "MethodAuthPolicy": + """Create a policy for methods that do not require authentication.""" + return cls(required_mechanism=None) + + @classmethod + def token_required( + cls, *, requires_run_id_match: bool = False + ) -> "MethodAuthPolicy": + """Create a policy for methods requiring token authentication.""" + return cls( + required_mechanism=AUTH_MECHANISM_TOKEN, + requires_run_id_match=requires_run_id_match, + ) + + @classmethod + def signed_metadata_required( + cls, *, requires_run_id_match: bool = False + ) -> "MethodAuthPolicy": + """Create a policy for methods requiring signed metadata authentication.""" + return cls( + required_mechanism=AUTH_MECHANISM_SUPEREXEC_SIGNED_METADATA, + requires_run_id_match=requires_run_id_match, + ) + + +# Keep explicit keyword arguments for clearer startup error messages. +def validate_method_auth_policy_map( # pylint: disable=too-many-arguments + *, + service_name: str, + package_name: str, + rpc_method_names: Sequence[str], + method_auth_policy: Mapping[str, MethodAuthPolicy], + table_name: str, + table_location: str, +) -> None: + """Validate that method auth policy table exactly matches service RPCs.""" + service_fqn = f"{package_name}.{service_name}" + expected = {f"/{service_fqn}/{rpc_name}" for rpc_name in rpc_method_names} + configured = set(method_auth_policy) + missing = sorted(expected - configured) + extra = sorted(configured - expected) + invalid_policy_values = sorted( + method_name + for method_name, policy in method_auth_policy.items() + if not isinstance(policy, MethodAuthPolicy) + ) + if missing or extra or invalid_policy_values: + raise ValueError( + "Invalid AppIo method auth policy table.\n" + f"Table: {table_name}\n" + f"Location: {table_location}\n" + f"Service: {service_fqn}\n" + f"Missing RPC entries: {missing or 'None'}\n" + f"Unexpected RPC entries: {extra or 'None'}\n" + f"Entries with invalid policy objects: {invalid_policy_values or 'None'}\n" + "How to fix: update the policy table to include exactly one " + "`MethodAuthPolicy` entry for each RPC exposed by the service." + )