Skip to content

Commit 0e3d02b

Browse files
committed
Release common v2.0.0
1 parent f1703c5 commit 0e3d02b

File tree

9 files changed

+415
-48
lines changed

9 files changed

+415
-48
lines changed

common/CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# Changelog
22

3+
## 2.0.0 - 2025-08-22
4+
5+
### Changed (3)
6+
7+
- Added custom REST headers
8+
- Added `subscribe_user_data`, `on` and `unsubscribe` method to `WebSocketAPIBase`
9+
- Updated `RequestStream` response type to `RequestStreamHandle`
10+
311
## 1.2.0 - 2025-08-07
412

513
### Changed (1)

common/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "binance-common"
3-
version = "1.2.0"
3+
version = "2.0.0"
44
description = "Binance Common Types and Utilities for Binance Connectors"
55
authors = ["Binance"]
66
license = "MIT"

common/src/binance_common/configuration.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import ssl
22

3-
from typing import Optional, Dict, Union
3+
from typing import Optional, Dict, Union, List
44
from http.client import HTTPSConnection
55

66
from binance_common.constants import TimeUnit, WebsocketMode
7+
from binance_common.headers import parse_custom_headers
78

89

910
class ConfigurationRestAPI:
@@ -19,6 +20,7 @@ class ConfigurationRestAPI:
1920
- Time Unit
2021
- Proxy support
2122
- Retries & Backoff
23+
- Custom REST headers
2224
"""
2325

2426
def __init__(
@@ -36,6 +38,7 @@ def __init__(
3638
time_unit: Optional[str] = None,
3739
private_key: Optional[Union[bytes, str]] = None,
3840
private_key_passphrase: Optional[str] = None,
41+
custom_headers: Optional[dict[str, Union[str, List[str]]]] = {},
3942
):
4043
"""
4144
Initialize the API configuration.
@@ -54,6 +57,7 @@ def __init__(
5457
time_unit (Optional[str]): Time unit for time-based responses (default: None).
5558
private_key (Optional[Union[bytes, str]]): Private key for authentication (default: None).
5659
private_key_passphrase (Optional[str]): Passphrase for private key (default: None).
60+
custom_headers (Optional[dict[str, Union[str, List[str]]]]): Custom REST headers (default: {}).
5761
"""
5862

5963
self.api_key = api_key
@@ -73,6 +77,7 @@ def __init__(
7377
self.base_headers = {
7478
"Accept": "application/json",
7579
"X-MBX-APIKEY": str(self.api_key) if self.api_key else "",
80+
**parse_custom_headers(custom_headers)
7681
}
7782

7883

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import logging
2+
3+
from typing import List, Optional, Union
4+
5+
def sanitize_header_value(value: Union[str, List[str]]) -> Union[str, List[str]]:
6+
"""Sanitizes a header value by checking for and preventing carriage return and line feed characters.
7+
8+
Args:
9+
value (Union[str, List[str]]): The header value or array of header values to sanitize.
10+
11+
Raises:
12+
ValueError: If the header value contains CR or LF characters.
13+
14+
Returns:
15+
Union[str, List[str]]: The sanitized header value(s).
16+
"""
17+
18+
if "\r" in value or "\n" in value:
19+
raise ValueError(f'Invalid header value (contains CR/LF): "{value}"')
20+
return value
21+
22+
def parse_custom_headers(custom_headers: Optional[dict[str, Union[str, List[str]]]]) -> dict[str, Union[str, List[str]]]:
23+
"""Parses custom headers for the API client.
24+
25+
Args:
26+
custom_headers (Optional[dict[str, Union[str, List[str]]]]): A dictionary of custom headers to parse.
27+
28+
Returns:
29+
dict[str, Union[str, List[str]]]: A dictionary of parsed custom headers.
30+
"""
31+
32+
forbidden_headers = {"host", "authorization", "cookie", ":method", ":path"}
33+
parsed_headers = {}
34+
35+
if not custom_headers:
36+
return {}
37+
38+
for key, value in custom_headers.items():
39+
header_name = key.strip()
40+
if header_name.lower() in forbidden_headers:
41+
logging.warning(f"Header '{header_name}' is not allowed to be set.")
42+
continue
43+
44+
try:
45+
if isinstance(value, list):
46+
parsed_headers[header_name] = [sanitize_header_value(v) for v in value]
47+
else:
48+
parsed_headers[header_name] = sanitize_header_value(value)
49+
except ValueError as e:
50+
logging.warning(f"Dropping header '{header_name}' due to invalid value.")
51+
continue
52+
53+
return parsed_headers

common/src/binance_common/models.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from pydantic import BaseModel
33

44
T = TypeVar("T")
5-
5+
T_Response = TypeVar("T_Response")
6+
T_Stream = TypeVar("T_Stream")
67

78
class RateLimit(BaseModel):
89
"""Represents a single rate limit entry.
@@ -90,3 +91,19 @@ def data(self) -> T:
9091
:return: The parsed data of type T.
9192
"""
9293
return self._data_function()
94+
95+
96+
class WebsocketApiUserDataStreamResponse(Generic[T_Response, T_Stream]):
97+
"""A wrapper for WebSocket API user data stream responses.
98+
99+
:param response: A callable that lazily returns a WebSocket API responses of type T_Response.
100+
:param stream: A callable that lazily returns a RequestStreamHandle responses of type T_Stream.
101+
"""
102+
103+
def __init__(
104+
self,
105+
response: WebsocketApiResponse[T_Response],
106+
stream: T_Stream,
107+
):
108+
self.response = response
109+
self.stream = stream

common/src/binance_common/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ def snake_to_camel(snake_str: str) -> str:
103103
str: The converted camelCase string.
104104
"""
105105

106-
parts = snake_str.split('_')
107-
return parts[0] + ''.join(word.capitalize() for word in parts[1:])
106+
parts = snake_str.split("_")
107+
return parts[0] + "".join(word.capitalize() for word in parts[1:])
108108

109109

110110
def make_serializable(val) -> Union[dict, list, str, int, float, bool]:
@@ -117,7 +117,7 @@ def make_serializable(val) -> Union[dict, list, str, int, float, bool]:
117117
"""
118118

119119
if isinstance(val, list):
120-
return [v.__dict__ if hasattr(v, '__dict__') else v for v in val]
120+
return [v.__dict__ if hasattr(v, "__dict__") else v for v in val]
121121
if isinstance(val, bool):
122122
return str(val).lower()
123123
if isinstance(val , Enum):
@@ -544,9 +544,9 @@ def convert(val, expected_type=None):
544544
elif expected_type == str:
545545
return val_stripped
546546
val_lower = val_stripped.lower()
547-
if val_lower == 'true':
547+
if val_lower == "true":
548548
return True
549-
elif val_lower == 'false':
549+
elif val_lower == "false":
550550
return False
551551
try:
552552
return int(val_stripped)

common/src/binance_common/websocket.py

Lines changed: 125 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from collections import OrderedDict
88
from pydantic import BaseModel
9-
from typing import Callable, Optional, Dict, Union, TypeVar, Type
9+
from typing import Callable, Optional, Dict, Generic, Union, TypeVar, Type
1010
from types import SimpleNamespace
1111
from urllib.parse import urlencode
1212

@@ -193,29 +193,35 @@ async def receive_loop(self, connection: WebSocketConnection):
193193
raise ValueError(f"Error received from server: {data['error']}")
194194

195195
stream = data.get("stream")
196-
callbacks = (
197-
connection.stream_callback_map.get(stream) if stream else None
198-
)
199-
response_model = (
200-
connection.response_types.get(stream) if stream else None
201-
)
196+
subscription_id = data.get("subscriptionId")
197+
198+
key = stream or subscription_id
199+
callbacks = connection.stream_callback_map.get(key) if key is not None else None
200+
202201
if callbacks:
203202
try:
204-
for callback in callbacks:
205-
if response_model is None:
206-
callback(data)
207-
else:
208-
data = data["data"]
209-
if isinstance(data, list):
210-
callback(
211-
[response_model.model_validate_json(json.dumps(item)) for item in data]
212-
)
203+
if stream:
204+
response_model = connection.response_types.get(stream)
205+
payload = data["data"] if response_model else data
206+
207+
for callback in callbacks:
208+
if response_model:
209+
if isinstance(payload, list):
210+
parsed = [
211+
response_model.model_validate_json(json.dumps(item))
212+
for item in payload
213+
]
214+
callback(parsed)
215+
else:
216+
callback(response_model.model_validate_json(json.dumps(payload)))
213217
else:
214-
callback(response_model.model_validate_json(json.dumps(data)))
218+
callback(payload)
219+
else:
220+
payload = data["event"]
221+
for callback in callbacks:
222+
callback(payload)
215223
except Exception as e:
216-
raise ValueError(
217-
f"Error in callback for stream {stream}: {e}"
218-
)
224+
raise ValueError(f"Error in callback for key {key}: {e}")
219225
else:
220226
logging.info(f"Received message: {data}")
221227
elif msg.type == aiohttp.WSMsgType.PING:
@@ -723,22 +729,111 @@ async def ping_ws_api(self, connection: WebSocketConnection):
723729

724730
await super().ping(connection)
725731

732+
async def subscribe_user_data(self, id: str):
733+
if self.configuration.mode == WebsocketMode.SINGLE:
734+
connection = self.connections[0]
735+
else:
736+
connection = self.connections[
737+
self.round_robin_index % len(self.connections)
738+
]
739+
self.round_robin_index = (self.round_robin_index + 1) % len(
740+
self.connections
741+
)
742+
global_stream_connections.stream_connections_map[id] = connection
743+
connection.stream_callback_map.update({id: []})
744+
745+
def on(self, event: str, callback: Callable[[T], None], id: str) -> None:
746+
"""Set the callback function for incoming messages on a specific ID.
747+
748+
Args:
749+
event (str): Event type.
750+
callback (Callable): Callback function.
751+
id (str): User Data ID.
752+
"""
753+
754+
if event != "message":
755+
raise ValueError(f"Unsupported event: {event}")
756+
757+
connection = (
758+
global_stream_connections.stream_connections_map[id]
759+
if id in global_stream_connections.stream_connections_map
760+
else None
761+
)
762+
763+
if connection:
764+
connection.stream_callback_map[id].append(callback)
765+
else:
766+
logging.warning(f"Stream {id} not connected.")
767+
768+
async def unsubscribe(self, id: str):
769+
"""Unsubscribe from a user data ID.
770+
771+
Args:
772+
id (str): user data ID to unsubscribe from.
773+
"""
774+
775+
if self.connections is None or len(self.connections) == 0:
776+
logging.warning("No user data connections available for unsubscription.")
777+
return
778+
779+
if id not in global_stream_connections.stream_connections_map:
780+
logging.warning(f"Stream {id} is not subscribed.")
781+
return
782+
783+
connection = (
784+
global_stream_connections.stream_connections_map[id]
785+
if id in global_stream_connections.stream_connections_map
786+
else None
787+
)
788+
if connection:
789+
global_stream_connections.stream_connections_map.pop(id, None)
790+
logging.info(f"Unsubscribed from stream: {id}")
791+
else:
792+
raise ValueError(f"Subscription id {id} not connected.")
793+
794+
795+
class RequestStreamHandle(Generic[T]):
796+
"""A wrapper for Request Stream Method.
797+
798+
:param websocket_base: WebSocket base.
799+
:param stream: Stream name.
800+
:param response_model: The Pydantic model to validate the response data.
801+
"""
802+
803+
def __init__(
804+
self,
805+
websocket_base: WebSocketStreamBase | WebSocketAPIBase,
806+
stream: str,
807+
response_model: Type[T] = None,
808+
):
809+
self._websocket_base = websocket_base
810+
self._stream = stream
811+
self._response_model = response_model
812+
813+
async def unsubscribe(self) -> None:
814+
if isinstance(self._websocket_base, WebSocketStreamBase):
815+
await self._websocket_base.unsubscribe(streams=self._stream)
816+
else:
817+
await self._websocket_base.unsubscribe(id=self._stream)
818+
819+
def on(self, event: str, callback: Callable[[T], None]) -> None:
820+
self._websocket_base.on(event, callback, self._stream)
821+
726822

727823
async def RequestStream(
728-
websocket_base: WebSocketStreamBase, stream: str, response_model: Type[T] = None
729-
) -> SimpleNamespace:
824+
websocket_base: WebSocketStreamBase | WebSocketAPIBase, stream: str, response_model: Type[T] = None
825+
) -> RequestStreamHandle[T]:
730826
"""Decorator to create a request stream for a specific stream.
731827
732828
Args:
733-
websocket_base (WebSocketStreamBase): WebSocket stream base.
829+
websocket_base (WebSocketStreamBase | WebSocketAPIBase): WebSocket base.
734830
stream (str): Stream name.
831+
response_model (Type[T], optional): Response model for the stream.
735832
"""
736-
await websocket_base.subscribe(streams=[stream], response_model=response_model)
737-
738-
def on(event: str, callback: Callable[[T], None]):
739-
websocket_base.on(event, callback, stream)
740833

741-
async def unsubscribe():
742-
await websocket_base.unsubscribe(streams=stream)
834+
if isinstance(websocket_base, WebSocketStreamBase):
835+
await websocket_base.subscribe(streams=[stream], response_model=response_model)
836+
else:
837+
await websocket_base.subscribe_user_data(id=stream)
743838

744-
return SimpleNamespace(on=on, unsubscribe=unsubscribe)
839+
return RequestStreamHandle(websocket_base, stream, response_model)

0 commit comments

Comments
 (0)