Skip to content

Commit 47e4394

Browse files
authored
feat(storage): Add AsyncConnection class and unit tests (#1664)
Add AsyncConnection class and unit tests
1 parent 27d5e7d commit 47e4394

File tree

3 files changed

+601
-3
lines changed

3 files changed

+601
-3
lines changed
Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Create/interact with Google Cloud Storage connections in asynchronous manner."""
16+
17+
import json
18+
import collections
19+
import functools
20+
from urllib.parse import urlencode
21+
22+
import google.api_core.exceptions
23+
from google.cloud import _http
24+
from google.cloud.storage import _http as storage_http
25+
from google.cloud.storage import _helpers
26+
from google.api_core.client_info import ClientInfo
27+
from google.cloud.storage._opentelemetry_tracing import create_trace_span
28+
from google.cloud.storage import __version__
29+
from google.cloud.storage._http import AGENT_VERSION
30+
31+
32+
class AsyncConnection:
33+
"""Class for asynchronous connection using google.auth.aio.
34+
35+
This class handles the creation of API requests, header management,
36+
user agent configuration, and error handling for the Async Storage Client.
37+
38+
Args:
39+
client: The client that owns this connection.
40+
client_info: Information about the client library.
41+
api_endpoint: The API endpoint to use.
42+
"""
43+
44+
def __init__(self, client, client_info=None, api_endpoint=None):
45+
self._client = client
46+
47+
if client_info is None:
48+
client_info = ClientInfo()
49+
50+
self._client_info = client_info
51+
if self._client_info.user_agent is None:
52+
self._client_info.user_agent = AGENT_VERSION
53+
else:
54+
self._client_info.user_agent = (
55+
f"{self._client_info.user_agent} {AGENT_VERSION}"
56+
)
57+
self._client_info.client_library_version = __version__
58+
self._extra_headers = {}
59+
60+
self.API_BASE_URL = api_endpoint or storage_http.Connection.DEFAULT_API_ENDPOINT
61+
self.API_VERSION = storage_http.Connection.API_VERSION
62+
self.API_URL_TEMPLATE = storage_http.Connection.API_URL_TEMPLATE
63+
64+
@property
65+
def extra_headers(self):
66+
"""Returns extra headers to send with every request."""
67+
return self._extra_headers
68+
69+
@extra_headers.setter
70+
def extra_headers(self, value):
71+
"""Set the extra header property."""
72+
self._extra_headers = value
73+
74+
@property
75+
def async_http(self):
76+
"""Returns the AsyncAuthorizedSession from the client.
77+
78+
Returns:
79+
google.auth.aio.transport.sessions.AsyncAuthorizedSession: The async session.
80+
"""
81+
return self._client.async_http
82+
83+
@property
84+
def user_agent(self):
85+
"""Returns user_agent for async HTTP transport.
86+
87+
Returns:
88+
str: The user agent string.
89+
"""
90+
return self._client_info.to_user_agent()
91+
92+
@user_agent.setter
93+
def user_agent(self, value):
94+
"""Setter for user_agent in connection."""
95+
self._client_info.user_agent = value
96+
97+
async def _make_request(
98+
self,
99+
method,
100+
url,
101+
data=None,
102+
content_type=None,
103+
headers=None,
104+
target_object=None,
105+
timeout=_http._DEFAULT_TIMEOUT,
106+
extra_api_info=None,
107+
):
108+
"""A low level method to send a request to the API.
109+
110+
Args:
111+
method (str): The HTTP method (e.g., 'GET', 'POST').
112+
url (str): The specific API URL.
113+
data (Optional[Union[str, bytes, dict]]): The body of the request.
114+
content_type (Optional[str]): The Content-Type header.
115+
headers (Optional[dict]): Additional headers for the request.
116+
target_object (Optional[object]): (Unused in async impl) Reference to the target object.
117+
timeout (Optional[float]): The timeout in seconds.
118+
extra_api_info (Optional[str]): Extra info for the User-Agent / Client-Info.
119+
120+
Returns:
121+
google.auth.aio.transport.Response: The HTTP response object.
122+
"""
123+
headers = headers.copy() if headers else {}
124+
headers.update(self.extra_headers)
125+
headers["Accept-Encoding"] = "gzip"
126+
127+
if content_type:
128+
headers["Content-Type"] = content_type
129+
130+
if extra_api_info:
131+
headers[_http.CLIENT_INFO_HEADER] = f"{self.user_agent} {extra_api_info}"
132+
else:
133+
headers[_http.CLIENT_INFO_HEADER] = self.user_agent
134+
headers["User-Agent"] = self.user_agent
135+
136+
return await self._do_request(
137+
method, url, headers, data, target_object, timeout=timeout
138+
)
139+
140+
async def _do_request(
141+
self, method, url, headers, data, target_object, timeout=_http._DEFAULT_TIMEOUT
142+
):
143+
"""Low-level helper: perform the actual API request.
144+
145+
Args:
146+
method (str): HTTP method.
147+
url (str): API URL.
148+
headers (dict): HTTP headers.
149+
data (Optional[bytes]): Request body.
150+
target_object: Unused in this implementation, kept for compatibility.
151+
timeout (float): Request timeout.
152+
153+
Returns:
154+
google.auth.aio.transport.Response: The response object.
155+
"""
156+
return await self.async_http.request(
157+
method=method,
158+
url=url,
159+
headers=headers,
160+
data=data,
161+
timeout=timeout,
162+
)
163+
164+
async def api_request(self, *args, **kwargs):
165+
"""Perform an API request with retry and tracing support.
166+
167+
Args:
168+
*args: Positional arguments passed to _perform_api_request.
169+
**kwargs: Keyword arguments passed to _perform_api_request.
170+
Can include 'retry' (an AsyncRetry object).
171+
172+
Returns:
173+
Union[dict, bytes]: The parsed JSON response or raw bytes.
174+
"""
175+
retry = kwargs.pop("retry", None)
176+
invocation_id = _helpers._get_invocation_id()
177+
kwargs["extra_api_info"] = invocation_id
178+
span_attributes = {
179+
"gccl-invocation-id": invocation_id,
180+
}
181+
182+
call = functools.partial(self._perform_api_request, *args, **kwargs)
183+
184+
with create_trace_span(
185+
name="Storage.AsyncConnection.api_request",
186+
attributes=span_attributes,
187+
client=self._client,
188+
api_request=kwargs,
189+
retry=retry,
190+
):
191+
if retry:
192+
# Ensure the retry policy checks its conditions
193+
try:
194+
retry = retry.get_retry_policy_if_conditions_met(**kwargs)
195+
except AttributeError:
196+
pass
197+
if retry:
198+
call = retry(call)
199+
return await call()
200+
201+
def build_api_url(
202+
self, path, query_params=None, api_base_url=None, api_version=None
203+
):
204+
"""Construct an API URL.
205+
206+
Args:
207+
path (str): The API path (e.g. '/b/bucket-name').
208+
query_params (Optional[Union[dict, list]]): Query parameters.
209+
api_base_url (Optional[str]): Base URL override.
210+
api_version (Optional[str]): API version override.
211+
212+
Returns:
213+
str: The fully constructed URL.
214+
"""
215+
url = self.API_URL_TEMPLATE.format(
216+
api_base_url=(api_base_url or self.API_BASE_URL),
217+
api_version=(api_version or self.API_VERSION),
218+
path=path,
219+
)
220+
221+
query_params = query_params or {}
222+
223+
if isinstance(query_params, collections.abc.Mapping):
224+
query_params = query_params.copy()
225+
else:
226+
query_params_dict = collections.defaultdict(list)
227+
for key, value in query_params:
228+
query_params_dict[key].append(value)
229+
query_params = query_params_dict
230+
231+
query_params.setdefault("prettyPrint", "false")
232+
233+
url += "?" + urlencode(query_params, doseq=True)
234+
235+
return url
236+
237+
async def _perform_api_request(
238+
self,
239+
method,
240+
path,
241+
query_params=None,
242+
data=None,
243+
content_type=None,
244+
headers=None,
245+
api_base_url=None,
246+
api_version=None,
247+
expect_json=True,
248+
_target_object=None,
249+
timeout=_http._DEFAULT_TIMEOUT,
250+
extra_api_info=None,
251+
):
252+
"""Internal helper to prepare the URL/Body and execute the request.
253+
254+
This method handles JSON serialization of the body, URL construction,
255+
and converts HTTP errors into google.api_core.exceptions.
256+
257+
Args:
258+
method (str): HTTP method.
259+
path (str): URL path.
260+
query_params (Optional[dict]): Query params.
261+
data (Optional[Union[dict, bytes]]): Request body.
262+
content_type (Optional[str]): Content-Type header.
263+
headers (Optional[dict]): HTTP headers.
264+
api_base_url (Optional[str]): Base URL override.
265+
api_version (Optional[str]): API version override.
266+
expect_json (bool): If True, parses response as JSON. Defaults to True.
267+
_target_object: Internal use (unused here).
268+
timeout (float): Request timeout.
269+
extra_api_info (Optional[str]): Extra client info.
270+
271+
Returns:
272+
Union[dict, bytes]: Parsed JSON or raw bytes.
273+
274+
Raises:
275+
google.api_core.exceptions.GoogleAPICallError: If the API returns an error.
276+
"""
277+
url = self.build_api_url(
278+
path=path,
279+
query_params=query_params,
280+
api_base_url=api_base_url,
281+
api_version=api_version,
282+
)
283+
284+
if data and isinstance(data, dict):
285+
data = json.dumps(data)
286+
content_type = "application/json"
287+
288+
response = await self._make_request(
289+
method=method,
290+
url=url,
291+
data=data,
292+
content_type=content_type,
293+
headers=headers,
294+
target_object=_target_object,
295+
timeout=timeout,
296+
extra_api_info=extra_api_info,
297+
)
298+
299+
# Handle API Errors
300+
if not (200 <= response.status_code < 300):
301+
content = await response.read()
302+
payload = {}
303+
if content:
304+
try:
305+
payload = json.loads(content.decode("utf-8"))
306+
except (ValueError, UnicodeDecodeError):
307+
payload = {
308+
"error": {"message": content.decode("utf-8", errors="replace")}
309+
}
310+
raise google.api_core.exceptions.format_http_response_error(
311+
response, method, url, payload
312+
)
313+
314+
# Handle Success
315+
payload = await response.read()
316+
if expect_json:
317+
if not payload:
318+
return {}
319+
return json.loads(payload)
320+
else:
321+
return payload

google/cloud/storage/_http.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from google.cloud.storage import _helpers
2121
from google.cloud.storage._opentelemetry_tracing import create_trace_span
2222

23+
AGENT_VERSION = f"gcloud-python/{__version__}"
24+
2325

2426
class Connection(_http.JSONConnection):
2527
"""A connection to Google Cloud Storage via the JSON REST API.
@@ -54,9 +56,8 @@ def __init__(self, client, client_info=None, api_endpoint=None):
5456
# TODO: When metrics all use gccl, this should be removed #9552
5557
if self._client_info.user_agent is None: # pragma: no branch
5658
self._client_info.user_agent = ""
57-
agent_version = f"gcloud-python/{__version__}"
58-
if agent_version not in self._client_info.user_agent:
59-
self._client_info.user_agent += f" {agent_version} "
59+
if AGENT_VERSION not in self._client_info.user_agent:
60+
self._client_info.user_agent += f" {AGENT_VERSION} "
6061

6162
API_VERSION = _helpers._API_VERSION
6263
"""The version of the API, used in building the API call's URL."""

0 commit comments

Comments
 (0)