Skip to content

Commit 262b16a

Browse files
authored
Add SocketIO support when generate token (Azure#38588)
* add socketio support when generate token * add documentation --------- Co-authored-by: chuongnguyen <[email protected]>
1 parent 87a8696 commit 262b16a

File tree

5 files changed

+77
-9
lines changed

5 files changed

+77
-9
lines changed

sdk/webpubsub/azure-messaging-webpubsubservice/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## 1.2.2 (Unreleased)
44

55
### Features Added
6+
- Added support for SocketIO when generating client access token
67

78
### Breaking Changes
89

sdk/webpubsub/azure-messaging-webpubsubservice/azure/messaging/webpubsubservice/_operations/_patch.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ def get_client_access_token(self, *, client_protocol: Optional[str] = "Default",
9696
:keyword groups: Groups that the connection will join when it connects. Default value is None.
9797
:paramtype groups: list[str]
9898
:keyword client_protocol: The type of client protocol. Case-insensitive. If not set, it's "Default". For Web
99-
PubSub for Socket.IO, only the default value is supported. For Web PubSub, the valid values are
100-
'Default' and 'MQTT'. Known values are: "Default" and "MQTT". Default value is "Default".
99+
PubSub for Socket.IO, "SocketIO" type is supported. For Web PubSub, the valid values are
100+
'Default', 'MQTT'. Known values are: "Default", "MQTT" and "SocketIO". Default value is "Default".
101101
:paramtype client_type: str
102102
:returns: JSON response containing the web socket endpoint, the token and a url with the generated access token.
103103
:rtype: JSON
@@ -124,9 +124,14 @@ def get_client_access_token(self, *, client_protocol: Optional[str] = "Default",
124124

125125
client_endpoint = "ws" + endpoint[4:]
126126
hub = self._config.hub
127-
path = "/clients/mqtt/hubs/" if client_protocol.lower() == "mqtt" else "/client/hubs/"
128127
# Example URL for Default Client Type: https://<service-name>.webpubsub.azure.com/client/hubs/<hub>
129-
# and for MQTT Client Type: https://<service-name>.webpubsub.azure.com/clients/mqtt/hubs/<hub>
128+
# MQTT Client Type: https://<service-name>.webpubsub.azure.com/clients/mqtt/hubs/<hub>
129+
# SocketIO Client Type: https://<service-name>.webpubsub.azure.com/clients/socketio/hubs/<hub>
130+
path = "/client/hubs/"
131+
if client_protocol.lower() == "mqtt":
132+
path = "/clients/mqtt/hubs/"
133+
elif client_protocol.lower() == "socketio":
134+
path = "/clients/socketio/hubs/"
130135
client_url = client_endpoint + path + hub
131136
jwt_headers = kwargs.pop("jwt_headers", {})
132137
if isinstance(self._config.credential, AzureKeyCredential):

sdk/webpubsub/azure-messaging-webpubsubservice/azure/messaging/webpubsubservice/aio/_operations/_patch.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ async def get_client_access_token( # pylint: disable=arguments-differ
5858
default value may result in unsupported behavior.
5959
:paramtype api_version: str
6060
:keyword client_protocol: The type of client protocol. Case-insensitive. If not set, it's "Default". For Web
61-
PubSub for Socket.IO, only the default value is supported. For Web PubSub, the valid values are
62-
'Default' and 'MQTT'. Known values are: "Default" and "MQTT". Default value is "Default".
61+
PubSub for Socket.IO, "SocketIO" type is supported. For Web PubSub, the valid values are
62+
'Default', 'MQTT'. Known values are: "Default", "MQTT" and "SocketIO". Default value is "Default".
6363
:paramtype client_type: str
6464
:return: JSON object
6565
:rtype: JSON
@@ -87,7 +87,11 @@ async def get_client_access_token( # pylint: disable=arguments-differ
8787

8888
client_endpoint = "ws" + endpoint[4:]
8989
hub = self._config.hub
90-
path = "/clients/mqtt/hubs/" if client_protocol.lower() == "mqtt" else "/client/hubs/"
90+
path = "/client/hubs/"
91+
if client_protocol.lower() == "mqtt":
92+
path = "/clients/mqtt/hubs/"
93+
elif client_protocol.lower() == "socketio":
94+
path = "/clients/socketio/hubs/"
9195
client_url = client_endpoint + path + hub
9296
if isinstance(self._config.credential, AzureKeyCredential):
9397
token = get_token_by_key(

sdk/webpubsub/azure-messaging-webpubsubservice/tests/test_jwt.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,22 @@ def test_generate_mqtt_token(connection_string, hub, expected_url):
136136

137137
assert len(decoded_token_1) == 3
138138
assert decoded_token_1['aud'] == expected_url.replace('ws', 'http')
139+
140+
test_cases = [
141+
("Endpoint=http://localhost;Port=8080;AccessKey={};Version=1.0;".format(access_key), "hub", "ws://localhost:8080/clients/socketio/hubs/hub"),
142+
("Endpoint=https://a;AccessKey={};Version=1.0;".format(access_key), "hub", "wss://a/clients/socketio/hubs/hub"),
143+
("Endpoint=http://a;AccessKey={};Version=1.0;".format(access_key), "hub", "ws://a/clients/socketio/hubs/hub")
144+
]
145+
@pytest.mark.parametrize("connection_string,hub,expected_url", test_cases)
146+
def test_generate_socketio_token(connection_string, hub, expected_url):
147+
client = WebPubSubServiceClient.from_connection_string(connection_string, hub)
148+
url_1 = client.get_client_access_token(client_protocol="SocketIO")['url']
149+
150+
assert url_1.split("?")[0] == expected_url
151+
152+
token_1 = urlparse(url_1).query[len("access_token="):]
153+
154+
decoded_token_1 = _decode_token(client, token_1, path="/clients/socketio/hubs/hub")
155+
156+
assert len(decoded_token_1) == 3
157+
assert decoded_token_1['aud'] == expected_url.replace('ws', 'http')

sdk/webpubsub/azure-messaging-webpubsubservice/tests/test_jwt_async.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
from urllib.parse import urlparse
1616

1717

18-
def _decode_token(client, token):
18+
def _decode_token(client, token, path="/client/hubs/hub"):
1919
return jwt.decode(
2020
token,
2121
client._config.credential.key,
2222
algorithms=["HS256"],
23-
audience=f"{client._config.endpoint}/client/hubs/hub"
23+
audience=client._config.endpoint + path
2424
)
2525

2626

@@ -119,3 +119,42 @@ async def test_pass_in_jwt_headers(connection_string):
119119
kid = '1234567890'
120120
token = (await client.get_client_access_token(jwt_headers={"kid":kid }))['token']
121121
assert jwt.get_unverified_header(token)['kid'] == kid
122+
123+
test_cases = [
124+
("Endpoint=http://localhost;Port=8080;AccessKey={};Version=1.0;".format(access_key), "hub", "ws://localhost:8080/clients/mqtt/hubs/hub"),
125+
("Endpoint=https://a;AccessKey={};Version=1.0;".format(access_key), "hub", "wss://a/clients/mqtt/hubs/hub"),
126+
("Endpoint=http://a;AccessKey={};Version=1.0;".format(access_key), "hub", "ws://a/clients/mqtt/hubs/hub")
127+
]
128+
@pytest.mark.parametrize("connection_string,hub,expected_url", test_cases)
129+
@pytest.mark.asyncio
130+
async def test_generate_mqtt_token(connection_string, hub, expected_url):
131+
client = WebPubSubServiceClient.from_connection_string(connection_string, hub)
132+
url_1 = (await client.get_client_access_token(client_protocol="MQTT"))['url']
133+
assert url_1.split("?")[0] == expected_url
134+
135+
token_1 = urlparse(url_1).query[len("access_token="):]
136+
137+
decoded_token_1 = _decode_token(client, token_1, path="/clients/mqtt/hubs/hub")
138+
139+
assert len(decoded_token_1) == 3
140+
assert decoded_token_1['aud'] == expected_url.replace('ws', 'http')
141+
142+
test_cases = [
143+
("Endpoint=http://localhost;Port=8080;AccessKey={};Version=1.0;".format(access_key), "hub", "ws://localhost:8080/clients/socketio/hubs/hub"),
144+
("Endpoint=https://a;AccessKey={};Version=1.0;".format(access_key), "hub", "wss://a/clients/socketio/hubs/hub"),
145+
("Endpoint=http://a;AccessKey={};Version=1.0;".format(access_key), "hub", "ws://a/clients/socketio/hubs/hub")
146+
]
147+
@pytest.mark.parametrize("connection_string,hub,expected_url", test_cases)
148+
@pytest.mark.asyncio
149+
async def test_generate_socketio_token(connection_string, hub, expected_url):
150+
client = WebPubSubServiceClient.from_connection_string(connection_string, hub)
151+
url_1 = (await client.get_client_access_token(client_protocol="SocketIO"))['url']
152+
153+
assert url_1.split("?")[0] == expected_url
154+
155+
token_1 = urlparse(url_1).query[len("access_token="):]
156+
157+
decoded_token_1 = _decode_token(client, token_1, path="/clients/socketio/hubs/hub")
158+
159+
assert len(decoded_token_1) == 3
160+
assert decoded_token_1['aud'] == expected_url.replace('ws', 'http')

0 commit comments

Comments
 (0)