Skip to content

Commit 76444c8

Browse files
committed
tackle feedback
Signed-off-by: Filinto Duran <[email protected]>
1 parent 02ef910 commit 76444c8

File tree

4 files changed

+111
-52
lines changed

4 files changed

+111
-52
lines changed

durabletask/aio/internal/shared.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright (c) The Dapr Authors.
22
# Licensed under the MIT License.
33

4-
from typing import Dict, Optional, Sequence, Union
4+
from typing import Optional, Sequence, Union
55

66
import grpc
77
from grpc import aio as grpc_aio
@@ -51,19 +51,16 @@ def get_grpc_aio_channel(
5151
host_address = host_address[len(protocol) :]
5252
break
5353

54-
# channel interceptors/options
55-
channel_kwargs: Dict[str, ChannelArgumentType | Sequence[ClientInterceptor]] = dict(
56-
interceptors=interceptors
57-
)
5854
if options is not None:
5955
validate_grpc_options(options)
60-
channel_kwargs["options"] = options
6156

6257
if secure_channel:
6358
channel = grpc_aio.secure_channel(
64-
host_address, grpc.ssl_channel_credentials(), **channel_kwargs
59+
host_address, grpc.ssl_channel_credentials(), interceptors=interceptors, options=options
6560
)
6661
else:
67-
channel = grpc_aio.insecure_channel(host_address, **channel_kwargs)
62+
channel = grpc_aio.insecure_channel(
63+
host_address, interceptors=interceptors, options=options
64+
)
6865

6966
return channel

durabletask/internal/shared.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,11 @@ def get_grpc_channel(
9797
if options is not None:
9898
# validate all options keys prefix starts with `grpc.`
9999
validate_grpc_options(options)
100-
if secure_channel:
101-
channel = grpc.secure_channel(
102-
host_address, grpc.ssl_channel_credentials(), options=options
103-
)
104-
else:
105-
channel = grpc.insecure_channel(host_address, options=options)
100+
101+
if secure_channel:
102+
channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials(), options=options)
106103
else:
107-
if secure_channel:
108-
channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials())
109-
else:
110-
channel = grpc.insecure_channel(host_address)
104+
channel = grpc.insecure_channel(host_address, options=options)
111105

112106
# Apply interceptors ONLY if they exist
113107
if interceptors:

tests/durabletask/test_client.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import ANY, patch
1+
from unittest.mock import patch
22

33
from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
44
from durabletask.internal.shared import get_default_host_address, get_grpc_channel
@@ -11,7 +11,9 @@
1111
def test_get_grpc_channel_insecure():
1212
with patch("grpc.insecure_channel") as mock_channel:
1313
get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS)
14-
mock_channel.assert_called_once_with(HOST_ADDRESS)
14+
args, kwargs = mock_channel.call_args
15+
assert args[0] == HOST_ADDRESS
16+
assert "options" in kwargs and kwargs["options"] is None
1517

1618

1719
def test_get_grpc_channel_secure():
@@ -20,13 +22,18 @@ def test_get_grpc_channel_secure():
2022
patch("grpc.ssl_channel_credentials") as mock_credentials,
2123
):
2224
get_grpc_channel(HOST_ADDRESS, True, interceptors=INTERCEPTORS)
23-
mock_channel.assert_called_once_with(HOST_ADDRESS, mock_credentials.return_value)
25+
args, kwargs = mock_channel.call_args
26+
assert args[0] == HOST_ADDRESS
27+
assert args[1] == mock_credentials.return_value
28+
assert "options" in kwargs and kwargs["options"] is None
2429

2530

2631
def test_get_grpc_channel_default_host_address():
2732
with patch("grpc.insecure_channel") as mock_channel:
2833
get_grpc_channel(None, False, interceptors=INTERCEPTORS)
29-
mock_channel.assert_called_once_with(get_default_host_address())
34+
args, kwargs = mock_channel.call_args
35+
assert args[0] == get_default_host_address()
36+
assert "options" in kwargs and kwargs["options"] is None
3037

3138

3239
def test_get_grpc_channel_with_metadata():
@@ -35,7 +42,9 @@ def test_get_grpc_channel_with_metadata():
3542
patch("grpc.intercept_channel") as mock_intercept_channel,
3643
):
3744
get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS)
38-
mock_channel.assert_called_once_with(HOST_ADDRESS)
45+
args, kwargs = mock_channel.call_args
46+
assert args[0] == HOST_ADDRESS
47+
assert "options" in kwargs and kwargs["options"] is None
3948
mock_intercept_channel.assert_called_once()
4049

4150
# Capture and check the arguments passed to intercept_channel()
@@ -54,40 +63,60 @@ def test_grpc_channel_with_host_name_protocol_stripping():
5463

5564
prefix = "grpc://"
5665
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
57-
mock_insecure_channel.assert_called_with(host_name)
66+
args, kwargs = mock_insecure_channel.call_args
67+
assert args[0] == host_name
68+
assert "options" in kwargs and kwargs["options"] is None
5869

5970
prefix = "http://"
6071
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
61-
mock_insecure_channel.assert_called_with(host_name)
72+
args, kwargs = mock_insecure_channel.call_args
73+
assert args[0] == host_name
74+
assert "options" in kwargs and kwargs["options"] is None
6275

6376
prefix = "HTTP://"
6477
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
65-
mock_insecure_channel.assert_called_with(host_name)
78+
args, kwargs = mock_insecure_channel.call_args
79+
assert args[0] == host_name
80+
assert "options" in kwargs and kwargs["options"] is None
6681

6782
prefix = "GRPC://"
6883
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
69-
mock_insecure_channel.assert_called_with(host_name)
84+
args, kwargs = mock_insecure_channel.call_args
85+
assert args[0] == host_name
86+
assert "options" in kwargs and kwargs["options"] is None
7087

7188
prefix = ""
7289
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
73-
mock_insecure_channel.assert_called_with(host_name)
90+
args, kwargs = mock_insecure_channel.call_args
91+
assert args[0] == host_name
92+
assert "options" in kwargs and kwargs["options"] is None
7493

7594
prefix = "grpcs://"
7695
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
77-
mock_secure_channel.assert_called_with(host_name, ANY)
96+
args, kwargs = mock_secure_channel.call_args
97+
assert args[0] == host_name
98+
assert "options" in kwargs and kwargs["options"] is None
7899

79100
prefix = "https://"
80101
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
81-
mock_secure_channel.assert_called_with(host_name, ANY)
102+
args, kwargs = mock_secure_channel.call_args
103+
assert args[0] == host_name
104+
assert "options" in kwargs and kwargs["options"] is None
82105

83106
prefix = "HTTPS://"
84107
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
85-
mock_secure_channel.assert_called_with(host_name, ANY)
108+
args, kwargs = mock_secure_channel.call_args
109+
assert args[0] == host_name
110+
assert "options" in kwargs and kwargs["options"] is None
86111

87112
prefix = "GRPCS://"
88113
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
89-
mock_secure_channel.assert_called_with(host_name, ANY)
114+
args, kwargs = mock_secure_channel.call_args
115+
assert args[0] == host_name
116+
assert "options" in kwargs and kwargs["options"] is None
90117

91118
prefix = ""
92119
get_grpc_channel(prefix + host_name, True, interceptors=INTERCEPTORS)
93-
mock_secure_channel.assert_called_with(host_name, ANY)
120+
args, kwargs = mock_secure_channel.call_args
121+
assert args[0] == host_name
122+
assert "options" in kwargs and kwargs["options"] is None

tests/durabletask/test_client_async.py

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright (c) The Dapr Authors.
22
# Licensed under the MIT License.
33

4-
from unittest.mock import ANY, patch
4+
from unittest.mock import patch
55

66
from durabletask.aio.client import AsyncTaskHubGrpcClient
77
from durabletask.aio.internal.grpc_interceptor import DefaultClientInterceptorImpl
@@ -16,7 +16,10 @@
1616
def test_get_grpc_aio_channel_insecure():
1717
with patch("durabletask.aio.internal.shared.grpc_aio.insecure_channel") as mock_channel:
1818
get_grpc_aio_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS_AIO)
19-
mock_channel.assert_called_once_with(HOST_ADDRESS, interceptors=INTERCEPTORS_AIO)
19+
args, kwargs = mock_channel.call_args
20+
assert args[0] == HOST_ADDRESS
21+
assert kwargs.get("interceptors") == INTERCEPTORS_AIO
22+
assert "options" in kwargs and kwargs["options"] is None
2023

2124

2225
def test_get_grpc_aio_channel_secure():
@@ -25,23 +28,29 @@ def test_get_grpc_aio_channel_secure():
2528
patch("grpc.ssl_channel_credentials") as mock_credentials,
2629
):
2730
get_grpc_aio_channel(HOST_ADDRESS, True, interceptors=INTERCEPTORS_AIO)
28-
mock_channel.assert_called_once_with(
29-
HOST_ADDRESS, mock_credentials.return_value, interceptors=INTERCEPTORS_AIO
30-
)
31+
args, kwargs = mock_channel.call_args
32+
assert args[0] == HOST_ADDRESS
33+
assert args[1] == mock_credentials.return_value
34+
assert kwargs.get("interceptors") == INTERCEPTORS_AIO
35+
assert "options" in kwargs and kwargs["options"] is None
3136

3237

3338
def test_get_grpc_aio_channel_default_host_address():
3439
with patch("durabletask.aio.internal.shared.grpc_aio.insecure_channel") as mock_channel:
3540
get_grpc_aio_channel(None, False, interceptors=INTERCEPTORS_AIO)
36-
mock_channel.assert_called_once_with(
37-
get_default_host_address(), interceptors=INTERCEPTORS_AIO
38-
)
41+
args, kwargs = mock_channel.call_args
42+
assert args[0] == get_default_host_address()
43+
assert kwargs.get("interceptors") == INTERCEPTORS_AIO
44+
assert "options" in kwargs and kwargs["options"] is None
3945

4046

4147
def test_get_grpc_aio_channel_with_interceptors():
4248
with patch("durabletask.aio.internal.shared.grpc_aio.insecure_channel") as mock_channel:
4349
get_grpc_aio_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS_AIO)
44-
mock_channel.assert_called_once_with(HOST_ADDRESS, interceptors=INTERCEPTORS_AIO)
50+
args, kwargs = mock_channel.call_args
51+
assert args[0] == HOST_ADDRESS
52+
assert kwargs.get("interceptors") == INTERCEPTORS_AIO
53+
assert "options" in kwargs and kwargs["options"] is None
4554

4655
# Capture and check the arguments passed to insecure_channel()
4756
args, kwargs = mock_channel.call_args
@@ -61,43 +70,73 @@ def test_grpc_aio_channel_with_host_name_protocol_stripping():
6170

6271
prefix = "grpc://"
6372
get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO)
64-
mock_insecure_channel.assert_called_with(host_name, interceptors=INTERCEPTORS_AIO)
73+
args, kwargs = mock_insecure_channel.call_args
74+
assert args[0] == host_name
75+
assert kwargs.get("interceptors") == INTERCEPTORS_AIO
76+
assert "options" in kwargs and kwargs["options"] is None
6577

6678
prefix = "http://"
6779
get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO)
68-
mock_insecure_channel.assert_called_with(host_name, interceptors=INTERCEPTORS_AIO)
80+
args, kwargs = mock_insecure_channel.call_args
81+
assert args[0] == host_name
82+
assert kwargs.get("interceptors") == INTERCEPTORS_AIO
83+
assert "options" in kwargs and kwargs["options"] is None
6984

7085
prefix = "HTTP://"
7186
get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO)
72-
mock_insecure_channel.assert_called_with(host_name, interceptors=INTERCEPTORS_AIO)
87+
args, kwargs = mock_insecure_channel.call_args
88+
assert args[0] == host_name
89+
assert kwargs.get("interceptors") == INTERCEPTORS_AIO
90+
assert "options" in kwargs and kwargs["options"] is None
7391

7492
prefix = "GRPC://"
7593
get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO)
76-
mock_insecure_channel.assert_called_with(host_name, interceptors=INTERCEPTORS_AIO)
94+
args, kwargs = mock_insecure_channel.call_args
95+
assert args[0] == host_name
96+
assert kwargs.get("interceptors") == INTERCEPTORS_AIO
97+
assert "options" in kwargs and kwargs["options"] is None
7798

7899
prefix = ""
79100
get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO)
80-
mock_insecure_channel.assert_called_with(host_name, interceptors=INTERCEPTORS_AIO)
101+
args, kwargs = mock_insecure_channel.call_args
102+
assert args[0] == host_name
103+
assert kwargs.get("interceptors") == INTERCEPTORS_AIO
104+
assert "options" in kwargs and kwargs["options"] is None
81105

82106
prefix = "grpcs://"
83107
get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO)
84-
mock_secure_channel.assert_called_with(host_name, ANY, interceptors=INTERCEPTORS_AIO)
108+
args, kwargs = mock_secure_channel.call_args
109+
assert args[0] == host_name
110+
assert kwargs.get("interceptors") == INTERCEPTORS_AIO
111+
assert "options" in kwargs and kwargs["options"] is None
85112

86113
prefix = "https://"
87114
get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO)
88-
mock_secure_channel.assert_called_with(host_name, ANY, interceptors=INTERCEPTORS_AIO)
115+
args, kwargs = mock_secure_channel.call_args
116+
assert args[0] == host_name
117+
assert kwargs.get("interceptors") == INTERCEPTORS_AIO
118+
assert "options" in kwargs and kwargs["options"] is None
89119

90120
prefix = "HTTPS://"
91121
get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO)
92-
mock_secure_channel.assert_called_with(host_name, ANY, interceptors=INTERCEPTORS_AIO)
122+
args, kwargs = mock_secure_channel.call_args
123+
assert args[0] == host_name
124+
assert kwargs.get("interceptors") == INTERCEPTORS_AIO
125+
assert "options" in kwargs and kwargs["options"] is None
93126

94127
prefix = "GRPCS://"
95128
get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO)
96-
mock_secure_channel.assert_called_with(host_name, ANY, interceptors=INTERCEPTORS_AIO)
129+
args, kwargs = mock_secure_channel.call_args
130+
assert args[0] == host_name
131+
assert kwargs.get("interceptors") == INTERCEPTORS_AIO
132+
assert "options" in kwargs and kwargs["options"] is None
97133

98134
prefix = ""
99135
get_grpc_aio_channel(prefix + host_name, True, interceptors=INTERCEPTORS_AIO)
100-
mock_secure_channel.assert_called_with(host_name, ANY, interceptors=INTERCEPTORS_AIO)
136+
args, kwargs = mock_secure_channel.call_args
137+
assert args[0] == host_name
138+
assert kwargs.get("interceptors") == INTERCEPTORS_AIO
139+
assert "options" in kwargs and kwargs["options"] is None
101140

102141

103142
def test_async_client_construct_with_metadata():

0 commit comments

Comments
 (0)