|
35 | 35 |
|
36 | 36 | from __future__ import annotations |
37 | 37 |
|
38 | | -from typing import TYPE_CHECKING, cast |
| 38 | +from typing import TYPE_CHECKING, TypeVar, cast |
39 | 39 |
|
40 | 40 | from net_addr import ( |
41 | 41 | Ip4Address, |
42 | 42 | Ip4AddressFormatError, |
43 | 43 | Ip6Address, |
44 | 44 | Ip6AddressFormatError, |
45 | 45 | ) |
| 46 | +from net_addr import IpAddress |
46 | 47 | from pytcp import stack |
47 | 48 |
|
48 | 49 | if TYPE_CHECKING: |
49 | | - from net_addr import IpAddress |
50 | 50 | from pytcp.socket.socket import AddressFamily, SocketType |
51 | 51 |
|
| 52 | +T = TypeVar("T", bound=IpAddress) |
| 53 | + |
52 | 54 |
|
53 | 55 | EPHEMERAL_PORT_RANGE = range(32168, 60700, 2) |
54 | 56 |
|
@@ -85,20 +87,19 @@ def str_to_ip( |
85 | 87 | return None |
86 | 88 |
|
87 | 89 |
|
88 | | -def pick_local_ip_address( |
89 | | - remote_ip_address: IpAddress, |
90 | | -) -> Ip6Address | Ip4Address: |
| 90 | +def pick_local_ip_address(remote_ip_address: T) -> T: |
91 | 91 | """ |
92 | 92 | Pick appropriate source IP address based on provided |
93 | 93 | destination IP address. |
94 | 94 | """ |
95 | 95 |
|
96 | | - assert isinstance(remote_ip_address, (Ip6Address, Ip4Address)) |
97 | | - |
98 | 96 | if isinstance(remote_ip_address, Ip6Address): |
99 | 97 | return pick_local_ip6_address(remote_ip_address) |
100 | 98 |
|
101 | | - return pick_local_ip4_address(remote_ip_address) |
| 99 | + if isinstance(remote_ip_address, Ip4Address): |
| 100 | + return pick_local_ip4_address(remote_ip_address) |
| 101 | + |
| 102 | + raise TypeError(f"Unsupported IP address type: {type(remote_ip_address)}") |
102 | 103 |
|
103 | 104 |
|
104 | 105 | def pick_local_ip6_address( |
|
0 commit comments