|
| 1 | +# -------------------------------------------------------------------------- |
| 2 | +# |
| 3 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 4 | +# |
| 5 | +# The MIT License (MIT) |
| 6 | +# |
| 7 | +# Permission is hereby granted, free of charge, to any person obtaining a copy |
| 8 | +# of this software and associated documentation files (the ""Software""), to |
| 9 | +# deal in the Software without restriction, including without limitation the |
| 10 | +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or |
| 11 | +# sell copies of the Software, and to permit persons to whom the Software is |
| 12 | +# furnished to do so, subject to the following conditions: |
| 13 | +# |
| 14 | +# The above copyright notice and this permission notice shall be included in |
| 15 | +# all copies or substantial portions of the Software. |
| 16 | +# |
| 17 | +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 18 | +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 19 | +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 20 | +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 21 | +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING |
| 22 | +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS |
| 23 | +# IN THE SOFTWARE. |
| 24 | +# |
| 25 | +# -------------------------------------------------------------------------- |
| 26 | +"""Shared utilities for polling continuation token serialization.""" |
| 27 | + |
| 28 | +import base64 |
| 29 | +import binascii |
| 30 | +import json |
| 31 | +from typing import Any, Dict, Mapping |
| 32 | + |
| 33 | + |
| 34 | +# Current continuation token version |
| 35 | +_CONTINUATION_TOKEN_VERSION = 1 |
| 36 | + |
| 37 | +# Error message for incompatible continuation tokens from older versions |
| 38 | +_INCOMPATIBLE_TOKEN_ERROR_MESSAGE = ( |
| 39 | + "This continuation token is not compatible with this version of azure-core. " |
| 40 | + "It may have been generated by a previous version. " |
| 41 | + "See https://aka.ms/azsdk/python/core/troubleshoot for more information." |
| 42 | +) |
| 43 | + |
| 44 | +# Headers that are needed for LRO rehydration. |
| 45 | +# We use an allowlist approach for security - only include headers we actually need. |
| 46 | +_LRO_HEADERS = frozenset( |
| 47 | + [ |
| 48 | + "operation-location", |
| 49 | + # azure-asyncoperation is included only for back compat with mgmt-core<=1.6.0 |
| 50 | + "azure-asyncoperation", |
| 51 | + "location", |
| 52 | + "content-type", |
| 53 | + "retry-after", |
| 54 | + ] |
| 55 | +) |
| 56 | + |
| 57 | + |
| 58 | +def _filter_sensitive_headers(headers: Mapping[str, str]) -> Dict[str, str]: |
| 59 | + """Filter headers to only include those needed for LRO rehydration. |
| 60 | +
|
| 61 | + Uses an allowlist approach - only headers required for polling are included. |
| 62 | +
|
| 63 | + :param headers: The headers to filter. |
| 64 | + :type headers: Mapping[str, str] |
| 65 | + :return: A new dictionary with only allowed headers. |
| 66 | + :rtype: dict[str, str] |
| 67 | + """ |
| 68 | + return {k: v for k, v in headers.items() if k.lower() in _LRO_HEADERS} |
| 69 | + |
| 70 | + |
| 71 | +def _is_pickle_format(data: bytes) -> bool: |
| 72 | + """Check if the data appears to be in pickle format. |
| 73 | +
|
| 74 | + Pickle protocol markers start with \\x80 followed by a protocol version byte (1-5). |
| 75 | +
|
| 76 | + :param data: The bytes to check. |
| 77 | + :type data: bytes |
| 78 | + :return: True if the data appears to be pickled, False otherwise. |
| 79 | + :rtype: bool |
| 80 | + """ |
| 81 | + if not data or len(data) < 2: |
| 82 | + return False |
| 83 | + # Check for pickle protocol marker (0x80) followed by protocol version 1-5 |
| 84 | + return data[0:1] == b"\x80" and 1 <= data[1] <= 5 |
| 85 | + |
| 86 | + |
| 87 | +def _decode_continuation_token(continuation_token: str) -> Dict[str, Any]: |
| 88 | + """Decode a base64-encoded JSON continuation token. |
| 89 | +
|
| 90 | + :param continuation_token: The base64-encoded continuation token. |
| 91 | + :type continuation_token: str |
| 92 | + :return: The decoded JSON data as a dictionary (the "data" field from the token). |
| 93 | + :rtype: dict |
| 94 | + :raises ValueError: If the token is invalid or in an unsupported format. |
| 95 | + """ |
| 96 | + try: |
| 97 | + decoded_bytes = base64.b64decode(continuation_token) |
| 98 | + token = json.loads(decoded_bytes.decode("utf-8")) |
| 99 | + except binascii.Error: |
| 100 | + # Invalid base64 input |
| 101 | + raise ValueError("This doesn't look like a continuation token the sdk created.") from None |
| 102 | + except (json.JSONDecodeError, UnicodeDecodeError): |
| 103 | + # Check if the data appears to be from an older version |
| 104 | + if _is_pickle_format(decoded_bytes): |
| 105 | + raise ValueError(_INCOMPATIBLE_TOKEN_ERROR_MESSAGE) from None |
| 106 | + raise ValueError("Invalid continuation token format.") from None |
| 107 | + |
| 108 | + # Validate token schema - must be a dict with a version field |
| 109 | + if not isinstance(token, dict) or "version" not in token: |
| 110 | + raise ValueError("Invalid continuation token format.") from None |
| 111 | + |
| 112 | + # For now, we only support version 1 |
| 113 | + # Future versions can add handling for older versions here if needed |
| 114 | + if token["version"] != _CONTINUATION_TOKEN_VERSION: |
| 115 | + raise ValueError(_INCOMPATIBLE_TOKEN_ERROR_MESSAGE) from None |
| 116 | + |
| 117 | + return token["data"] |
| 118 | + |
| 119 | + |
| 120 | +def _encode_continuation_token(data: Any) -> str: |
| 121 | + """Encode data as a base64-encoded JSON continuation token. |
| 122 | +
|
| 123 | + The token includes a version field for future compatibility checking. |
| 124 | +
|
| 125 | + :param data: The data to encode. Must be JSON-serializable. |
| 126 | + :type data: any |
| 127 | + :return: The base64-encoded JSON string. |
| 128 | + :rtype: str |
| 129 | + :raises TypeError: If the data is not JSON-serializable. |
| 130 | + """ |
| 131 | + token = { |
| 132 | + "version": _CONTINUATION_TOKEN_VERSION, |
| 133 | + "data": data, |
| 134 | + } |
| 135 | + try: |
| 136 | + return base64.b64encode(json.dumps(token, separators=(",", ":")).encode("utf-8")).decode("ascii") |
| 137 | + except (TypeError, ValueError) as err: |
| 138 | + raise TypeError( |
| 139 | + "Unable to generate a continuation token for this operation. Payload is not JSON-serializable." |
| 140 | + ) from err |
0 commit comments