|
| 1 | +from pyeudiw.satosa.exceptions import DiscoveryFailedError |
| 2 | +from pyeudiw.federation.statements import EntityStatement, get_entity_configurations |
| 3 | +from pyeudiw.federation.trust_chain_builder import TrustChainBuilder |
| 4 | +from pyeudiw.federation.exceptions import ProtocolMetadataNotFound |
| 5 | +from copy import deepcopy |
| 6 | +from cryptojwt.jwk.ec import ECKey |
| 7 | +from cryptojwt.jwk.jwk import key_from_jwk_dict |
| 8 | +from cryptojwt.jwk.rsa import RSAKey |
| 9 | +from datetime import datetime |
| 10 | +from pyeudiw.federation.policy import combine |
| 11 | + |
| 12 | + |
| 13 | +def get_backend_trust_chain(self) -> list[str]: |
| 14 | + """ |
| 15 | + Get the backend trust chain. In case something raises an Exception (e.g. faulty storage), logs a warning message |
| 16 | + and returns an empty list. |
| 17 | +
|
| 18 | + :return: The trust chain |
| 19 | + :rtype: list |
| 20 | + """ |
| 21 | + |
| 22 | + try: |
| 23 | + trust_evaluation_helper = self.build_trust_chain_for_entity_id( |
| 24 | + storage=self.db_engine, |
| 25 | + entity_id=self.client_id, |
| 26 | + entity_configuration=self.entity_configuration, |
| 27 | + httpc_params=self.httpc_params, |
| 28 | + ) |
| 29 | + |
| 30 | + self.db_engine.add_or_update_trust_attestation( |
| 31 | + entity_id=self.client_id, |
| 32 | + attestation=trust_evaluation_helper.trust_chain, |
| 33 | + exp=trust_evaluation_helper.exp, |
| 34 | + ) |
| 35 | + return trust_evaluation_helper.trust_chain |
| 36 | + |
| 37 | + except (DiscoveryFailedError, EntryNotFound, Exception) as e: |
| 38 | + message = ( |
| 39 | + f"Error while building trust chain for client with id: {self.client_id}. " |
| 40 | + f"{e.__class__.__name__}: {e}" |
| 41 | + ) |
| 42 | + self._log_warning("Trust Chain", message) |
| 43 | + |
| 44 | + return [] |
| 45 | + |
| 46 | +@property |
| 47 | +def default_federation_private_jwk(self) -> dict: |
| 48 | + """Returns the default federation private jwk.""" |
| 49 | + return tuple(self.federations_jwks_by_kids.values())[0] |
| 50 | + |
| 51 | +# era class FederationTrustModel(TrustEvaluator): |
| 52 | + |
| 53 | +def get_public_keys(self, issuer): |
| 54 | + public_keys = [JWK(i).as_public_dict() for i in self.federation_jwks] |
| 55 | + |
| 56 | + return public_keys |
| 57 | + |
| 58 | +def get_verified_key( |
| 59 | + self, issuer: str, token_header: dict |
| 60 | +) -> ECKey | RSAKey | dict: |
| 61 | + # (1) verifica trust chain |
| 62 | + kid: str = token_header.get("kid", None) |
| 63 | + if not kid: |
| 64 | + raise ValueError("missing claim [kid] in token header") |
| 65 | + trust_chain: list[str] = token_header.get("trust_chain", None) |
| 66 | + if not trust_chain: |
| 67 | + raise ValueError("missing trust chain in federation token") |
| 68 | + if not isinstance(trust_chain, list): |
| 69 | + raise ValueError * ("invalid format of header claim [trust_claim]") |
| 70 | + # TODO: check whick exceptions this might raise |
| 71 | + self._verify_trust_chain(trust_chain) |
| 72 | + |
| 73 | + # (2) metadata parsing ed estrazione Jwk set |
| 74 | + # TODO: wrap in something that implements VciJwksSource |
| 75 | + # apply policy of traust anchor only? |
| 76 | + issuer_entity_configuration = trust_chain[0] |
| 77 | + anchor_entity_configuration = trust_chain[-1] |
| 78 | + issuer_payload: dict = decode_jwt_payload(issuer_entity_configuration) |
| 79 | + anchor_payload = decode_jwt_payload(anchor_entity_configuration) |
| 80 | + trust_anchor_policy = anchor_payload.get("metadata_policy", {}) |
| 81 | + final_issuer_metadata = self.metadata_policy_resolver.apply_policy( |
| 82 | + issuer_payload, trust_anchor_policy |
| 83 | + ) |
| 84 | + metadata: dict = final_issuer_metadata.get("metadata", None) |
| 85 | + if not metadata: |
| 86 | + raise ValueError( |
| 87 | + "missing or invalid claim [metadata] in entity configuration" |
| 88 | + ) |
| 89 | + issuer_metadata: dict = metadata.get(_ISSUER_METADATA_TYPE, None) |
| 90 | + if not issuer_metadata: |
| 91 | + raise ValueError( |
| 92 | + f"missing or invalid claim [metadata.{_ISSUER_METADATA_TYPE}] in entity configuration" |
| 93 | + ) |
| 94 | + issuer_keys: list[dict] = issuer_metadata.get("jwks", {}).get("keys", []) |
| 95 | + if not issuer_keys: |
| 96 | + raise ValueError( |
| 97 | + f"missing or invalid claim [metadata.{_ISSUER_METADATA_TYPE}.jwks.keys] in entity configuration" |
| 98 | + ) |
| 99 | + # check issuer = entity_id |
| 100 | + if issuer != (obt_iss := final_issuer_metadata.get("iss", "")): |
| 101 | + raise ValueError( |
| 102 | + f"invalid issuer metadata: expected '{issuer}', obtained '{obt_iss}'" |
| 103 | + ) |
| 104 | + |
| 105 | + # (3) dato il set completo, fa il match per kid tra l'header e il jwk set |
| 106 | + found_jwks: list[dict] = [] |
| 107 | + for key in issuer_keys: |
| 108 | + obt_kid: str = key.get("kid", "") |
| 109 | + if kid == obt_kid: |
| 110 | + found_jwks.append(key) |
| 111 | + if len(found_jwks) != 1: |
| 112 | + raise ValueError( |
| 113 | + f"unable to uniquely identify a key with kid {kid} in appropriate section of issuer entity configuration" |
| 114 | + ) |
| 115 | + try: |
| 116 | + return key_from_jwk_dict(**found_jwks[0]) |
| 117 | + except Exception as e: |
| 118 | + raise ValueError(f"unable to parse issuer jwk: {e}") |
| 119 | + |
| 120 | +def init_trust_resources(self) -> None: |
| 121 | + """ |
| 122 | + Initializes the trust resources. |
| 123 | + """ |
| 124 | + |
| 125 | + # private keys by kid |
| 126 | + self.federations_jwks_by_kids = { |
| 127 | + i["kid"]: i |
| 128 | + for i in self.config["federation_jwks"] |
| 129 | + } |
| 130 | + # dumps public jwks |
| 131 | + self.federation_public_jwks = [ |
| 132 | + key_from_jwk_dict(i).serialize() |
| 133 | + for i in self.config["federation_jwks"] |
| 134 | + ] |
| 135 | + # we close the connection in this constructor since it must be fork safe and |
| 136 | + # get reinitialized later on, within each fork |
| 137 | + self.update_trust_anchors() |
| 138 | + |
| 139 | + try: |
| 140 | + self.get_backend_trust_chain() |
| 141 | + except Exception as e: |
| 142 | + self._log_critical( |
| 143 | + "Backend Trust", f"Cannot fetch the trust anchor configuration: {e}" |
| 144 | + ) |
| 145 | + |
| 146 | + self.db_engine.close() |
| 147 | + self._db_engine = None |
| 148 | + |
| 149 | +def update_trust_anchors(self): |
| 150 | + """ |
| 151 | + Updates the trust anchors of current instance. |
| 152 | + """ |
| 153 | + |
| 154 | + tas = self.config["trust_anchors"] |
| 155 | + self._log_info("Trust Anchors updates", f"Trying to update: {tas}") |
| 156 | + |
| 157 | + for ta in tas: |
| 158 | + try: |
| 159 | + self.update_trust_anchors_ecs( |
| 160 | + db=self.db_engine, |
| 161 | + trust_anchors=[ta], |
| 162 | + httpc_params=self.config["httpc_params"], |
| 163 | + ) |
| 164 | + except Exception as e: |
| 165 | + self._log_warning("Trust Anchor updates", f"{ta} update failed: {e}") |
| 166 | + |
| 167 | + self._log_info("Trust Anchor updates", f"{ta} updated") |
| 168 | + |
| 169 | +def _update_chain( |
| 170 | + self, |
| 171 | + entity_id: str | None = None, |
| 172 | + exp: datetime | None = None, |
| 173 | + trust_chain: list | None = None, |
| 174 | +): |
| 175 | + if entity_id is not None: |
| 176 | + self.entity_id = entity_id |
| 177 | + |
| 178 | + if exp is not None: |
| 179 | + self.exp = exp |
| 180 | + |
| 181 | + if trust_chain is not None: |
| 182 | + self.trust_chain = trust_chain |
| 183 | + |
| 184 | +def _handle_federation_chain(self, trust_chain): |
| 185 | + _first_statement = decode_jwt_payload(trust_chain[-1]) |
| 186 | + trust_anchor_eid = self.trust_anchor or _first_statement.get("iss", None) |
| 187 | + |
| 188 | + if not trust_anchor_eid: |
| 189 | + raise UnknownTrustAnchor( |
| 190 | + "Unknown Trust Anchor: can't find 'iss' in the " |
| 191 | + f"first entity statement: {_first_statement} " |
| 192 | + ) |
| 193 | + |
| 194 | + try: |
| 195 | + trust_anchor = self.storage.get_trust_anchor(trust_anchor_eid) |
| 196 | + except EntryNotFound: |
| 197 | + raise UnknownTrustAnchor( |
| 198 | + f"Unknown Trust Anchor: '{trust_anchor_eid}' is not " |
| 199 | + "a recognizable Trust Anchor." |
| 200 | + ) |
| 201 | + |
| 202 | + decoded_ec = decode_jwt_payload( |
| 203 | + trust_anchor["federation"]["entity_configuration"] |
| 204 | + ) |
| 205 | + jwks = decoded_ec.get("jwks", {}).get("keys", []) |
| 206 | + |
| 207 | + if not jwks: |
| 208 | + raise MissingProtocolSpecificJwks(f"Cannot find any jwks in {decoded_ec}") |
| 209 | + |
| 210 | + tc = StaticTrustChainValidator(self.trust_chain, jwks, self.httpc_params) |
| 211 | + self._update_chain(entity_id=tc.entity_id, exp=tc.exp) |
| 212 | + |
| 213 | + _is_valid = False |
| 214 | + |
| 215 | + try: |
| 216 | + _is_valid = tc.validate() |
| 217 | + except TimeValidationError: |
| 218 | + logger.warn(f"Trust Chain {tc.entity_id} is expired") |
| 219 | + except Exception as e: |
| 220 | + logger.warn( |
| 221 | + f"Cannot validate Trust Chain {tc.entity_id} for the following reason: {e}" |
| 222 | + ) |
| 223 | + |
| 224 | + db_chain = None |
| 225 | + |
| 226 | + if not _is_valid: |
| 227 | + try: |
| 228 | + db_chain = self.storage.get_trust_attestation(self.entity_id)[ |
| 229 | + "federation" |
| 230 | + ]["chain"] |
| 231 | + if StaticTrustChainValidator( |
| 232 | + db_chain, jwks, self.httpc_params |
| 233 | + ).is_valid: |
| 234 | + self.is_trusted = True |
| 235 | + return self.is_trusted |
| 236 | + |
| 237 | + except (EntryNotFound, Exception): |
| 238 | + pass |
| 239 | + |
| 240 | + _is_valid = tc.update() |
| 241 | + |
| 242 | + self._update_chain(trust_chain=tc.trust_chain, exp=tc.exp) |
| 243 | + |
| 244 | + # the good trust chain is then stored |
| 245 | + self.storage.add_or_update_trust_attestation( |
| 246 | + entity_id=self.entity_id, |
| 247 | + attestation=tc.trust_chain, |
| 248 | + exp=datetime.fromtimestamp(tc.exp), |
| 249 | + ) |
| 250 | + |
| 251 | + self.is_trusted = _is_valid |
| 252 | + return _is_valid |
| 253 | + |
| 254 | +def get_final_metadata(self, metadata_type: str, policies: list[dict]) -> dict: |
| 255 | + policy_acc = {"metadata": {}, "metadata_policy": {}} |
| 256 | + |
| 257 | + for policy in policies: |
| 258 | + policy_acc = combine(policy, policy_acc) |
| 259 | + |
| 260 | + self.final_metadata = decode_jwt_payload(self.trust_chain[0]) |
| 261 | + |
| 262 | + try: |
| 263 | + # TODO: there are some cases where the jwks are taken from a uri ... |
| 264 | + selected_metadata = { |
| 265 | + "metadata": self.final_metadata["metadata"], |
| 266 | + "metadata_policy": {}, |
| 267 | + } |
| 268 | + |
| 269 | + self.final_metadata = TrustChainPolicy().apply_policy( |
| 270 | + selected_metadata, policy_acc |
| 271 | + ) |
| 272 | + |
| 273 | + return self.final_metadata["metadata"][metadata_type] |
| 274 | + except KeyError: |
| 275 | + raise ProtocolMetadataNotFound( |
| 276 | + f"{metadata_type} not found in the final metadata:" |
| 277 | + f" {self.final_metadata['metadata']}" |
| 278 | + ) |
| 279 | + |
| 280 | +def get_trusted_jwks( |
| 281 | + self, metadata_type: str, policies: list[dict] = [] |
| 282 | +) -> list[dict]: |
| 283 | + return ( |
| 284 | + self.get_final_metadata(metadata_type=metadata_type, policies=policies) |
| 285 | + .get("jwks", {}) |
| 286 | + .get("keys", []) |
| 287 | + ) |
| 288 | + |
| 289 | +def discovery( |
| 290 | + self, entity_id: str, entity_configuration: EntityStatement | None = None |
| 291 | +): |
| 292 | + """ |
| 293 | + Updates fields ``trust_chain`` and ``exp`` based on the discovery process. |
| 294 | +
|
| 295 | + :raises: DiscoveryFailedError: raises an error if the discovery fails. |
| 296 | + """ |
| 297 | + trust_anchor_eid = self.trust_anchor |
| 298 | + _ta_ec = self.storage.get_trust_anchor(entity_id=trust_anchor_eid) |
| 299 | + ta_ec = _ta_ec["federation"]["entity_configuration"] |
| 300 | + |
| 301 | + tcbuilder = TrustChainBuilder( |
| 302 | + subject=entity_id, |
| 303 | + trust_anchor=trust_anchor_eid, |
| 304 | + trust_anchor_configuration=ta_ec, |
| 305 | + subject_configuration=entity_configuration, |
| 306 | + httpc_params=self.httpc_params, |
| 307 | + ) |
| 308 | + |
| 309 | + self._update_chain(trust_chain=tcbuilder.get_trust_chain(), exp=tcbuilder.exp) |
| 310 | + is_good = tcbuilder.is_valid |
| 311 | + if not is_good: |
| 312 | + raise DiscoveryFailedError( |
| 313 | + f"Discovery failed for entity {entity_id} with configuration {entity_configuration}" |
| 314 | + ) |
| 315 | + |
| 316 | +def build_trust_chain_for_entity_id(self, entity_id: str): |
| 317 | + """ |
| 318 | + Builds a ``TrustEvaluationHelper`` and returns it if the trust chain is valid. |
| 319 | + In case the trust chain is invalid, tries to validate it in discovery before returning it. |
| 320 | +
|
| 321 | + :return: The svg data for html, base64 encoded |
| 322 | + :rtype: str |
| 323 | + """ |
| 324 | + db_chain: list = self.storage.get_trust_attestation(entity_id) |
| 325 | + |
| 326 | + if len(db_chain) == 0: |
| 327 | + db_chain = self.discovery(self.entity_id) |
| 328 | + else: |
| 329 | + self.is_valid = self._handle_federation_chain() |
| 330 | + return self.is_valid |
| 331 | + |
| 332 | + return False |
| 333 | + |
| 334 | +def update_trust_anchors_ecs(self, trust_anchors: list[str], db: DBEngine) -> None: |
| 335 | + """ |
| 336 | + Update the trust anchors entity configurations. |
| 337 | +
|
| 338 | + :param trust_anchors: The trust anchors |
| 339 | + :type trust_anchors: list |
| 340 | + :param db: The database engine |
| 341 | + :type db: DBEngine |
| 342 | + :param httpc_params: The HTTP client parameters |
| 343 | + :type httpc_params: dict |
| 344 | + """ |
| 345 | + |
| 346 | + ta_ecs = get_entity_configurations( |
| 347 | + trust_anchors, httpc_params=self.httpc_params |
| 348 | + ) |
| 349 | + |
| 350 | + for jwt in ta_ecs: |
| 351 | + if isinstance(jwt, bytes): |
| 352 | + jwt = jwt.decode() |
| 353 | + |
| 354 | + ec = EntityStatement(jwt, httpc_params=self.httpc_params) |
| 355 | + if not ec.validate_by_itself(): |
| 356 | + logger.warning( |
| 357 | + f"The trust anchor failed the validation of its EntityConfiguration {ec}" |
| 358 | + ) |
| 359 | + |
| 360 | + db.add_trust_anchor( |
| 361 | + entity_id=ec.sub, entity_configuration=ec.jwt, exp=ec.exp |
| 362 | + ) |
0 commit comments