|
20 | 20 | ) |
21 | 21 | from http.client import IncompleteRead |
22 | 22 | from io import IOBase |
| 23 | +from oauth2client import GOOGLE_TOKEN_URI |
| 24 | +from oauth2client.client import GoogleCredentials |
23 | 25 | from rohmu.common.models import StorageOperation |
24 | 26 | from rohmu.common.statsd import StatsClient, StatsdConfig |
25 | 27 | from rohmu.errors import ( |
|
66 | 68 | import dataclasses |
67 | 69 | import datetime |
68 | 70 | import errno |
69 | | -import google.auth |
70 | | -import google_auth_httplib2 # type: ignore[import-untyped] |
| 71 | + |
| 72 | +# NOTE: this import is not needed per-se, but it's imported here first to point the |
| 73 | +# user to the most important possible missing dependency |
| 74 | +import googleapiclient # noqa: F401 |
71 | 75 | import httplib2 |
72 | 76 | import json |
73 | 77 | import logging |
|
77 | 81 | import ssl |
78 | 82 | import time |
79 | 83 |
|
| 84 | +try: |
| 85 | + from oauth2client.service_account import ServiceAccountCredentials |
| 86 | + |
| 87 | + ServiceAccountCredentials_from_dict = ServiceAccountCredentials.from_json_keyfile_dict |
| 88 | +except ImportError: |
| 89 | + from oauth2client.service_account import _ServiceAccountCredentials |
| 90 | + |
| 91 | + def ServiceAccountCredentials_from_dict( |
| 92 | + credentials: dict[str, Any], scopes: Optional[list[str]] = None |
| 93 | + ) -> GoogleCredentials: |
| 94 | + if scopes is None: |
| 95 | + scopes = [] |
| 96 | + return _ServiceAccountCredentials( |
| 97 | + service_account_id=credentials["client_id"], |
| 98 | + service_account_email=credentials["client_email"], |
| 99 | + private_key_id=credentials["private_key_id"], |
| 100 | + private_key_pkcs8_text=credentials["private_key"], |
| 101 | + scopes=scopes, |
| 102 | + ) |
| 103 | + |
| 104 | + |
80 | 105 | if TYPE_CHECKING: |
81 | | - from google.auth.credentials import Credentials |
82 | 106 | from googleapiclient._apis.storage.v1 import StorageResource |
83 | 107 |
|
84 | 108 | # Silence Google API client verbose spamming |
85 | 109 | logging.getLogger("googleapiclient.discovery_cache").setLevel(logging.ERROR) |
86 | 110 | logging.getLogger("googleapiclient").setLevel(logging.WARNING) |
| 111 | +logging.getLogger("oauth2client").setLevel(logging.WARNING) |
87 | 112 |
|
88 | 113 |
|
89 | | -def get_credentials(credential_file: Optional[TextIO] = None, credentials: Optional[dict[str, Any]] = None) -> Credentials: |
90 | | - SCOPES = ["https://www.googleapis.com/auth/cloud-platform"] |
91 | | - |
92 | | - # File takes precedence over credentials dict |
| 114 | +def get_credentials( |
| 115 | + credential_file: Optional[TextIO] = None, credentials: Optional[dict[str, Any]] = None |
| 116 | +) -> GoogleCredentials: |
93 | 117 | if credential_file: |
94 | | - cred_data = credential_file.read() |
95 | | - credentials = json.loads(cred_data) |
| 118 | + return GoogleCredentials.from_stream(credential_file) |
96 | 119 |
|
97 | | - # project_id is the second element of the returned tuple |
98 | | - if credentials: |
99 | | - # Ensure service account credentials have required fields with defaults |
100 | | - if credentials.get("type") == "service_account": |
101 | | - credentials = _ensure_service_account_fields(credentials) |
102 | | - gcreds, _ = google.auth.load_credentials_from_dict(credentials, scopes=SCOPES) |
103 | | - else: |
104 | | - gcreds, _ = google.auth.default(scopes=SCOPES) |
105 | | - return gcreds |
106 | | - |
107 | | - |
108 | | -def _ensure_service_account_fields(credentials: dict[str, Any]) -> dict[str, Any]: |
109 | | - """Ensure service account credentials have all required fields with sensible defaults.""" |
110 | | - creds = credentials.copy() |
111 | | - |
112 | | - required_defaults = { |
113 | | - "auth_uri": "https://accounts.google.com/o/oauth2/auth", |
114 | | - "token_uri": "https://oauth2.googleapis.com/token", |
115 | | - "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", |
116 | | - } |
| 120 | + if credentials and credentials["type"] == "service_account": |
| 121 | + return ServiceAccountCredentials_from_dict( |
| 122 | + credentials, |
| 123 | + scopes=["https://www.googleapis.com/auth/cloud-platform"], |
| 124 | + ) |
117 | 125 |
|
118 | | - for field, default_value in required_defaults.items(): |
119 | | - if field not in creds: |
120 | | - creds[field] = default_value |
| 126 | + if credentials and credentials["type"] == "authorized_user": |
| 127 | + return GoogleCredentials( |
| 128 | + access_token=None, |
| 129 | + client_id=credentials["client_id"], |
| 130 | + client_secret=credentials["client_secret"], |
| 131 | + refresh_token=credentials["refresh_token"], |
| 132 | + token_expiry=None, |
| 133 | + token_uri=GOOGLE_TOKEN_URI, |
| 134 | + user_agent="pghoard", |
| 135 | + ) |
121 | 136 |
|
122 | | - return creds |
| 137 | + return GoogleCredentials.get_application_default() |
123 | 138 |
|
124 | 139 |
|
125 | 140 | def base64_to_hex(b64val: Union[str, bytes]) -> str: |
@@ -245,7 +260,7 @@ def _init_google_client(self) -> StorageResource: |
245 | 260 | proxy_pass=self.proxy_info.get("pass"), |
246 | 261 | ) |
247 | 262 |
|
248 | | - http = google_auth_httplib2.AuthorizedHttp(self.google_creds, http=http) |
| 263 | + http = self.google_creds.authorize(http) |
249 | 264 |
|
250 | 265 | try: |
251 | 266 | # sometimes fails: httplib2.ServerNotFoundError: Unable to find the server at www.googleapis.com |
|
0 commit comments