44"""Test cases for the channel module."""
55
66from dataclasses import dataclass
7+ from typing import NotRequired , TypedDict
78from unittest import mock
89
910import pytest
1213from frequenz .client .base .channel import parse_grpc_uri
1314
1415VALID_URLS = [
15- ("grpc://localhost" , "localhost" , 9090 , False ),
16- ("grpc://localhost:1234" , "localhost" , 1234 , False ),
16+ ("grpc://localhost" , "localhost" , 9090 , True ),
17+ ("grpc://localhost:1234" , "localhost" , 1234 , True ),
1718 ("grpc://localhost:1234?ssl=true" , "localhost" , 1234 , True ),
1819 ("grpc://localhost:1234?ssl=false" , "localhost" , 1234 , False ),
1920 ("grpc://localhost:1234?ssl=1" , "localhost" , 1234 , True ),
2930]
3031
3132
33+ class _CreateChannelKwargs (TypedDict ):
34+ default_port : NotRequired [int ]
35+ default_ssl : NotRequired [bool ]
36+
37+
3238@pytest .mark .parametrize ("uri, host, port, ssl" , VALID_URLS )
33- def test_grpclib_parse_uri_ok (
39+ @pytest .mark .parametrize (
40+ "default_port" , [None , 9090 , 1234 ], ids = lambda x : f"default_port={ x } "
41+ )
42+ @pytest .mark .parametrize (
43+ "default_ssl" , [None , True , False ], ids = lambda x : f"default_ssl={ x } "
44+ )
45+ def test_grpclib_parse_uri_ok ( # pylint: disable=too-many-arguments
3446 uri : str ,
3547 host : str ,
3648 port : int ,
3749 ssl : bool ,
50+ default_port : int | None ,
51+ default_ssl : bool | None ,
3852) -> None :
3953 """Test successful parsing of gRPC URIs using grpclib."""
4054
@@ -44,24 +58,39 @@ class _FakeChannel:
4458 port : int
4559 ssl : bool
4660
61+ kwargs = _CreateChannelKwargs ()
62+ if default_port is not None :
63+ kwargs ["default_port" ] = default_port
64+ if default_ssl is not None :
65+ kwargs ["default_ssl" ] = default_ssl
66+
67+ expected_port = port if f":{ port } " in uri or default_port is None else default_port
68+ expected_ssl = ssl if "ssl" in uri or default_ssl is None else default_ssl
69+
4770 with mock .patch (
4871 "frequenz.client.base.channel._grpchacks.grpclib_create_channel" ,
4972 return_value = _FakeChannel (host , port , ssl ),
50- ):
51- channel = parse_grpc_uri (uri , _grpchacks .GrpclibChannel )
73+ ) as create_channel_mock :
74+ channel = parse_grpc_uri (uri , _grpchacks .GrpclibChannel , ** kwargs )
5275
5376 assert isinstance (channel , _FakeChannel )
54- assert channel .host == host
55- assert channel .port == port
56- assert channel .ssl == ssl
77+ create_channel_mock .assert_called_once_with (host , expected_port , expected_ssl )
5778
5879
5980@pytest .mark .parametrize ("uri, host, port, ssl" , VALID_URLS )
60- def test_grpcio_parse_uri_ok (
81+ @pytest .mark .parametrize (
82+ "default_port" , [None , 9090 , 1234 ], ids = lambda x : f"default_port={ x } "
83+ )
84+ @pytest .mark .parametrize (
85+ "default_ssl" , [None , True , False ], ids = lambda x : f"default_ssl={ x } "
86+ )
87+ def test_grpcio_parse_uri_ok ( # pylint: disable=too-many-arguments,too-many-locals
6188 uri : str ,
6289 host : str ,
6390 port : int ,
6491 ssl : bool ,
92+ default_port : int | None ,
93+ default_ssl : bool | None ,
6594) -> None :
6695 """Test successful parsing of gRPC URIs using grpcio."""
6796 expected_channel = mock .MagicMock (
@@ -70,6 +99,14 @@ def test_grpcio_parse_uri_ok(
7099 expected_credentials = mock .MagicMock (
71100 name = "mock_credentials" , spec = _grpchacks .GrpcioChannel
72101 )
102+ expected_port = port if f":{ port } " in uri or default_port is None else default_port
103+ expected_ssl = ssl if "ssl" in uri or default_ssl is None else default_ssl
104+
105+ kwargs = _CreateChannelKwargs ()
106+ if default_port is not None :
107+ kwargs ["default_port" ] = default_port
108+ if default_ssl is not None :
109+ kwargs ["default_ssl" ] = default_ssl
73110
74111 with (
75112 mock .patch (
@@ -85,11 +122,11 @@ def test_grpcio_parse_uri_ok(
85122 return_value = expected_credentials ,
86123 ) as ssl_channel_credentials_mock ,
87124 ):
88- channel = parse_grpc_uri (uri , _grpchacks .GrpcioChannel )
125+ channel = parse_grpc_uri (uri , _grpchacks .GrpcioChannel , ** kwargs )
89126
90127 assert channel == expected_channel
91- expected_target = f"{ host } :{ port } "
92- if ssl :
128+ expected_target = f"{ host } :{ expected_port } "
129+ if expected_ssl :
93130 ssl_channel_credentials_mock .assert_called_once_with ()
94131 secure_channel_mock .assert_called_once_with (
95132 expected_target , expected_credentials
0 commit comments