Skip to content

Commit 70e54a4

Browse files
authored
Feature add connect to requests transport (#87)
1 parent 1948be5 commit 70e54a4

File tree

5 files changed

+323
-57
lines changed

5 files changed

+323
-57
lines changed

gql/client.py

Lines changed: 64 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
from inspect import isawaitable
23
from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union, cast
34

45
from graphql import (
@@ -35,11 +36,6 @@ def __init__(
3536
assert (
3637
not schema
3738
), "Cant fetch the schema from transport if is already provided"
38-
if isinstance(transport, Transport):
39-
# For sync transports, we fetch the schema directly
40-
execution_result = transport.execute(parse(get_introspection_query()))
41-
execution_result = cast(ExecutionResult, execution_result)
42-
introspection = execution_result.data
4339
if introspection:
4440
assert not schema, "Cant provide introspection and schema at the same time"
4541
schema = build_client_schema(introspection)
@@ -68,6 +64,10 @@ def __init__(
6864
# Enforced timeout of the execute function
6965
self.execute_timeout = execute_timeout
7066

67+
if isinstance(transport, Transport) and fetch_schema_from_transport:
68+
with self as session:
69+
session.fetch_schema()
70+
7171
def validate(self, document):
7272
if not self.schema:
7373
raise Exception(
@@ -77,6 +77,10 @@ def validate(self, document):
7777
if validation_errors:
7878
raise validation_errors[0]
7979

80+
def execute_sync(self, document: DocumentNode, *args, **kwargs) -> Dict:
81+
with self as session:
82+
return session.execute(document, *args, **kwargs)
83+
8084
async def execute_async(self, document: DocumentNode, *args, **kwargs) -> Dict:
8185
async with self as session:
8286
return await session.execute(document, *args, **kwargs)
@@ -107,22 +111,7 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:
107111
return data
108112

109113
else: # Sync transports
110-
111-
if self.schema:
112-
self.validate(document)
113-
114-
assert self.transport is not None, "Cannot execute without a transport"
115-
116-
result = self.transport.execute(document, *args, **kwargs)
117-
118-
if result.errors:
119-
raise TransportQueryError(str(result.errors[0]))
120-
121-
assert (
122-
result.data is not None
123-
), "Transport returned an ExecutionResult without data or errors"
124-
125-
return result.data
114+
return self.execute_sync(document, *args, **kwargs)
126115

127116
async def subscribe_async(
128117
self, document: DocumentNode, *args, **kwargs
@@ -170,30 +159,72 @@ async def __aenter__(self):
170159
await self.transport.connect()
171160

172161
if not hasattr(self, "session"):
173-
self.session = ClientSession(client=self)
162+
self.session = AsyncClientSession(client=self)
174163

175164
return self.session
176165

177166
async def __aexit__(self, exc_type, exc, tb):
178167

179168
await self.transport.close()
180169

181-
def close(self):
182-
"""Close the client and it's underlying transport (only for Sync transports)"""
183-
if not isinstance(self.transport, AsyncTransport):
184-
self.transport.close()
185-
186170
def __enter__(self):
171+
187172
assert not isinstance(
188173
self.transport, AsyncTransport
189174
), "Only a sync transport can be use. Use 'async with Client(...)' instead"
190-
return self
175+
176+
self.transport.connect()
177+
178+
if not hasattr(self, "session"):
179+
self.session = SyncClientSession(client=self)
180+
181+
return self.session
191182

192183
def __exit__(self, *args):
193-
self.close()
184+
self.transport.close()
185+
186+
187+
class SyncClientSession:
188+
"""An instance of this class is created when using 'with' on the client.
189+
190+
It contains the sync method execute to send queries
191+
with the sync transports.
192+
"""
193+
194+
def __init__(self, client: Client):
195+
self.client = client
196+
197+
def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:
198+
199+
# Validate document
200+
if self.client.schema:
201+
self.client.validate(document)
202+
203+
result = self.transport.execute(document, *args, **kwargs)
204+
205+
assert not isawaitable(result), "Transport returned an awaitable result."
206+
result = cast(ExecutionResult, result)
207+
208+
if result.errors:
209+
raise TransportQueryError(str(result.errors[0]))
210+
211+
assert (
212+
result.data is not None
213+
), "Transport returned an ExecutionResult without data or errors"
214+
215+
return result.data
216+
217+
def fetch_schema(self) -> None:
218+
execution_result = self.transport.execute(parse(get_introspection_query()))
219+
self.client.introspection = execution_result.data
220+
self.client.schema = build_client_schema(self.client.introspection)
221+
222+
@property
223+
def transport(self):
224+
return self.client.transport
194225

195226

196-
class ClientSession:
227+
class AsyncClientSession:
197228
"""An instance of this class is created when using 'async with' on the client.
198229
199230
It contains the async methods (execute, subscribe) to send queries
@@ -203,7 +234,7 @@ class ClientSession:
203234
def __init__(self, client: Client):
204235
self.client = client
205236

206-
async def validate(self, document: DocumentNode):
237+
async def fetch_and_validate(self, document: DocumentNode):
207238
"""Fetch schema from transport if needed and validate document.
208239
209240
If no schema is present, the validation will be skipped.
@@ -222,7 +253,7 @@ async def subscribe(
222253
) -> AsyncGenerator[Dict, None]:
223254

224255
# Fetch schema from transport if needed and validate document if possible
225-
await self.validate(document)
256+
await self.fetch_and_validate(document)
226257

227258
# Subscribe to the transport and yield data or raise error
228259
self._generator: AsyncGenerator[
@@ -243,7 +274,7 @@ async def subscribe(
243274
async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:
244275

245276
# Fetch schema from transport if needed and validate document if possible
246-
await self.validate(document)
277+
await self.fetch_and_validate(document)
247278

248279
# Execute the query with the transport with a timeout
249280
result = await asyncio.wait_for(

gql/transport/requests.py

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@
88

99
from gql.transport import Transport
1010

11+
from .exceptions import (
12+
TransportAlreadyConnected,
13+
TransportClosed,
14+
TransportProtocolError,
15+
TransportServerError,
16+
)
17+
1118

1219
class RequestsHTTPTransport(Transport):
1320
"""Transport to execute GraphQL queries on remote servers.
@@ -58,23 +65,32 @@ def __init__(
5865
self.use_json = use_json
5966
self.default_timeout = timeout
6067
self.verify = verify
68+
self.retries = retries
6169
self.method = method
6270
self.kwargs = kwargs
6371

64-
# Creating a session that can later be re-use to configure custom mechanisms
65-
self.session = requests.Session()
72+
self.session = None
73+
74+
def connect(self):
75+
76+
if self.session is None:
6677

67-
# If we specified some retries, we provide a predefined retry-logic
68-
if retries > 0:
69-
adapter = HTTPAdapter(
70-
max_retries=Retry(
71-
total=retries,
72-
backoff_factor=0.1,
73-
status_forcelist=[500, 502, 503, 504],
78+
# Creating a session that can later be re-use to configure custom mechanisms
79+
self.session = requests.Session()
80+
81+
# If we specified some retries, we provide a predefined retry-logic
82+
if self.retries > 0:
83+
adapter = HTTPAdapter(
84+
max_retries=Retry(
85+
total=self.retries,
86+
backoff_factor=0.1,
87+
status_forcelist=[500, 502, 503, 504],
88+
)
7489
)
75-
)
76-
for prefix in "http://", "https://":
77-
self.session.mount(prefix, adapter)
90+
for prefix in "http://", "https://":
91+
self.session.mount(prefix, adapter)
92+
else:
93+
raise TransportAlreadyConnected("Transport is already connected")
7894

7995
def execute( # type: ignore
8096
self,
@@ -94,6 +110,10 @@ def execute( # type: ignore
94110
`data` is the result of executing the query, `errors` is null
95111
if no errors occurred, and is a non-empty array if an error occurred.
96112
"""
113+
114+
if not self.session:
115+
raise TransportClosed("Transport is not connected")
116+
97117
query_str = print_ast(document)
98118
payload = {"query": query_str, "variables": variable_values or {}}
99119

@@ -116,18 +136,26 @@ def execute( # type: ignore
116136
)
117137
try:
118138
result = response.json()
119-
if not isinstance(result, dict):
120-
raise ValueError
121-
except ValueError:
122-
result = {}
139+
except Exception:
140+
# We raise a TransportServerError if the status code is 400 or higher
141+
# We raise a TransportProtocolError in the other cases
142+
143+
try:
144+
# Raise a requests.HTTPerror if response status is 400 or higher
145+
response.raise_for_status()
146+
147+
except requests.HTTPError as e:
148+
raise TransportServerError(str(e))
149+
150+
raise TransportProtocolError("Server did not return a GraphQL result")
123151

124152
if "errors" not in result and "data" not in result:
125-
response.raise_for_status()
126-
raise requests.HTTPError(
127-
"Server did not return a GraphQL result", response=response
128-
)
153+
raise TransportProtocolError("Server did not return a GraphQL result")
154+
129155
return ExecutionResult(errors=result.get("errors"), data=result.get("data"))
130156

131157
def close(self):
132158
"""Closing the transport by closing the inner session"""
133-
self.session.close()
159+
if self.session:
160+
self.session.close()
161+
self.session = None

gql/transport/transport.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult:
1717
"Any Transport subclass must implement execute method"
1818
) # pragma: no cover
1919

20+
def connect(self):
21+
"""Establish a session with the transport.
22+
"""
23+
pass
24+
2025
def close(self):
2126
"""Close the transport
2227

tests/test_client.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ def test_retries_on_transport(execute_mock):
6060
}
6161
"""
6262
)
63-
with client: # We're using the client as context manager
63+
with client as session: # We're using the client as context manager
6464
with pytest.raises(Exception):
65-
client.execute(query)
65+
session.execute(query)
6666

6767
# This might look strange compared to the previous test, but making 3 retries
6868
# means you're actually doing 4 calls.
@@ -98,7 +98,6 @@ def test_execute_result_error():
9898

9999
with pytest.raises(Exception) as exc_info:
100100
client.execute(failing_query)
101-
client.close()
102101
assert 'Cannot query field "id" on type "Continent".' in str(exc_info.value)
103102

104103

0 commit comments

Comments
 (0)