Skip to content

Commit 86827d6

Browse files
authored
get_request_headers combomethod (#3467)
* get_request_headers combomethod * Add newsfragment
1 parent 3237c6f commit 86827d6

File tree

6 files changed

+31
-10
lines changed

6 files changed

+31
-10
lines changed

newsfragments/3467.feature.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
HTTPProvider and AsyncHTTPProvider's get_request_headers is now available on both the class and the instance

tests/core/providers/test_async_http_provider.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ async def test_async_user_provided_session() -> None:
9494
assert cached_session == session
9595

9696

97-
def test_get_request_headers():
98-
provider = AsyncHTTPProvider()
97+
@pytest.mark.parametrize("provider", (AsyncHTTPProvider(), AsyncHTTPProvider))
98+
def test_get_request_headers(provider):
9999
headers = provider.get_request_headers()
100100
assert len(headers) == 2
101101
assert headers["Content-Type"] == "application/json"

tests/core/providers/test_http_provider.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ def test_user_provided_session():
101101
assert adapter._pool_maxsize == 20
102102

103103

104-
def test_get_request_headers():
105-
provider = HTTPProvider()
104+
@pytest.mark.parametrize("provider", (HTTPProvider(), HTTPProvider))
105+
def test_get_request_headers(provider):
106106
headers = provider.get_request_headers()
107107
assert len(headers) == 2
108108
assert headers["Content-Type"] == "application/json"

web3/_utils/http.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
DEFAULT_HTTP_TIMEOUT = 30.0
22

33

4-
def construct_user_agent(class_type: type) -> str:
4+
def construct_user_agent(
5+
module: str,
6+
class_name: str,
7+
) -> str:
58
from web3 import (
69
__version__ as web3_version,
710
)
811

9-
return f"web3.py/{web3_version}/{class_type.__module__}.{class_type.__qualname__}"
12+
return f"web3.py/{web3_version}/{module}.{class_name}"

web3/providers/rpc/async_rpc.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
URI,
2020
)
2121
from eth_utils import (
22+
combomethod,
2223
to_dict,
2324
)
2425

@@ -108,10 +109,17 @@ def get_request_kwargs(self) -> Iterable[Tuple[str, Any]]:
108109
yield "headers", self.get_request_headers()
109110
yield from self._request_kwargs.items()
110111

111-
def get_request_headers(self) -> Dict[str, str]:
112+
@combomethod
113+
def get_request_headers(cls) -> Dict[str, str]:
114+
if isinstance(cls, AsyncHTTPProvider):
115+
cls_name = cls.__class__.__name__
116+
else:
117+
cls_name = cls.__name__
118+
module = cls.__module__
119+
112120
return {
113121
"Content-Type": "application/json",
114-
"User-Agent": construct_user_agent(type(self)),
122+
"User-Agent": construct_user_agent(module, cls_name),
115123
}
116124

117125
async def _make_request(self, method: RPCEndpoint, request_data: bytes) -> bytes:

web3/providers/rpc/rpc.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
URI,
1717
)
1818
from eth_utils import (
19+
combomethod,
1920
to_dict,
2021
)
2122
import requests
@@ -116,10 +117,18 @@ def get_request_kwargs(self) -> Iterable[Tuple[str, Any]]:
116117
yield "headers", self.get_request_headers()
117118
yield from self._request_kwargs.items()
118119

119-
def get_request_headers(self) -> Dict[str, str]:
120+
@combomethod
121+
def get_request_headers(cls) -> Dict[str, str]:
122+
if isinstance(cls, HTTPProvider):
123+
cls_name = cls.__class__.__name__
124+
else:
125+
cls_name = cls.__name__
126+
127+
module = cls.__module__
128+
120129
return {
121130
"Content-Type": "application/json",
122-
"User-Agent": construct_user_agent(type(self)),
131+
"User-Agent": construct_user_agent(module, cls_name),
123132
}
124133

125134
def _make_request(self, method: RPCEndpoint, request_data: bytes) -> bytes:

0 commit comments

Comments
 (0)