|
20 | 20 | ) |
21 | 21 | from http.client import IncompleteRead |
22 | 22 | from io import IOBase |
23 | | -from oauth2client import GOOGLE_TOKEN_URI |
24 | | -from oauth2client.client import GoogleCredentials |
25 | 23 | from rohmu.common.models import StorageOperation |
26 | 24 | from rohmu.common.statsd import StatsClient, StatsdConfig |
27 | 25 | from rohmu.errors import ( |
|
68 | 66 | import dataclasses |
69 | 67 | import datetime |
70 | 68 | import errno |
71 | | - |
72 | | -# NOTE: this import is not needed per-se, but it's imported here first to point the |
73 | | -# user to the most important possible missing dependency |
74 | | -import googleapiclient # noqa: F401 |
| 69 | +import google.auth |
| 70 | +import google_auth_httplib2 # type: ignore[import-untyped] |
75 | 71 | import httplib2 |
76 | 72 | import json |
77 | 73 | import logging |
|
81 | 77 | import ssl |
82 | 78 | import time |
83 | 79 |
|
84 | | -try: |
85 | | - from oauth2client.service_account import ServiceAccountCredentials |
86 | | - |
87 | | - ServiceAccountCredentials_from_dict = ServiceAccountCredentials.from_json_keyfile_dict |
88 | | -except ImportError: |
89 | | - from oauth2client.service_account import _ServiceAccountCredentials |
90 | | - |
91 | | - def ServiceAccountCredentials_from_dict( |
92 | | - credentials: dict[str, Any], scopes: Optional[list[str]] = None |
93 | | - ) -> GoogleCredentials: |
94 | | - if scopes is None: |
95 | | - scopes = [] |
96 | | - return _ServiceAccountCredentials( |
97 | | - service_account_id=credentials["client_id"], |
98 | | - service_account_email=credentials["client_email"], |
99 | | - private_key_id=credentials["private_key_id"], |
100 | | - private_key_pkcs8_text=credentials["private_key"], |
101 | | - scopes=scopes, |
102 | | - ) |
103 | | - |
104 | | - |
105 | 80 | if TYPE_CHECKING: |
| 81 | + from google.auth.credentials import Credentials |
106 | 82 | from googleapiclient._apis.storage.v1 import StorageResource |
107 | 83 |
|
108 | 84 | # Silence Google API client verbose spamming |
109 | 85 | logging.getLogger("googleapiclient.discovery_cache").setLevel(logging.ERROR) |
110 | 86 | logging.getLogger("googleapiclient").setLevel(logging.WARNING) |
111 | | -logging.getLogger("oauth2client").setLevel(logging.WARNING) |
112 | 87 |
|
113 | 88 |
|
114 | | -def get_credentials( |
115 | | - credential_file: Optional[TextIO] = None, credentials: Optional[dict[str, Any]] = None |
116 | | -) -> GoogleCredentials: |
| 89 | +def get_credentials(credential_file: Optional[TextIO] = None, credentials: Optional[dict[str, Any]] = None) -> Credentials: |
| 90 | + SCOPES = ["https://www.googleapis.com/auth/cloud-platform"] |
| 91 | + |
| 92 | + # File takes precedence over credentials dict |
117 | 93 | if credential_file: |
118 | | - return GoogleCredentials.from_stream(credential_file) |
| 94 | + cred_data = credential_file.read() |
| 95 | + credentials = json.loads(cred_data) |
119 | 96 |
|
120 | | - if credentials and credentials["type"] == "service_account": |
121 | | - return ServiceAccountCredentials_from_dict( |
122 | | - credentials, |
123 | | - scopes=["https://www.googleapis.com/auth/cloud-platform"], |
124 | | - ) |
| 97 | + # project_id is the second element of the returned tuple |
| 98 | + if credentials: |
| 99 | + # Ensure service account credentials have required fields with defaults |
| 100 | + if credentials.get("type") == "service_account": |
| 101 | + credentials = _ensure_service_account_fields(credentials) |
| 102 | + gcreds, _ = google.auth.load_credentials_from_dict(credentials, scopes=SCOPES) |
| 103 | + else: |
| 104 | + gcreds, _ = google.auth.default(scopes=SCOPES) |
| 105 | + return gcreds |
125 | 106 |
|
126 | | - if credentials and credentials["type"] == "authorized_user": |
127 | | - return GoogleCredentials( |
128 | | - access_token=None, |
129 | | - client_id=credentials["client_id"], |
130 | | - client_secret=credentials["client_secret"], |
131 | | - refresh_token=credentials["refresh_token"], |
132 | | - token_expiry=None, |
133 | | - token_uri=GOOGLE_TOKEN_URI, |
134 | | - user_agent="pghoard", |
135 | | - ) |
136 | 107 |
|
137 | | - return GoogleCredentials.get_application_default() |
| 108 | +def _ensure_service_account_fields(credentials: dict[str, Any]) -> dict[str, Any]: |
| 109 | + """Ensure service account credentials have all required fields with sensible defaults.""" |
| 110 | + creds = credentials.copy() |
| 111 | + |
| 112 | + required_defaults = { |
| 113 | + "auth_uri": "https://accounts.google.com/o/oauth2/auth", |
| 114 | + "token_uri": "https://oauth2.googleapis.com/token", |
| 115 | + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", |
| 116 | + } |
| 117 | + |
| 118 | + for field, default_value in required_defaults.items(): |
| 119 | + if field not in creds: |
| 120 | + creds[field] = default_value |
| 121 | + |
| 122 | + return creds |
138 | 123 |
|
139 | 124 |
|
140 | 125 | def base64_to_hex(b64val: Union[str, bytes]) -> str: |
@@ -260,7 +245,7 @@ def _init_google_client(self) -> StorageResource: |
260 | 245 | proxy_pass=self.proxy_info.get("pass"), |
261 | 246 | ) |
262 | 247 |
|
263 | | - http = self.google_creds.authorize(http) |
| 248 | + http = google_auth_httplib2.AuthorizedHttp(self.google_creds, http=http) |
264 | 249 |
|
265 | 250 | try: |
266 | 251 | # sometimes fails: httplib2.ServerNotFoundError: Unable to find the server at www.googleapis.com |
|
0 commit comments