Skip to content

Commit fcab87d

Browse files
authored
feat(go-client): introduce Equal APIs for HostPort and RPCAddress (#2360)
#2358 Add `Equal` interfaces for the `HostPort` and `RPCAddress` structures to enable equality comparisons based on attributes such as host and port.
1 parent 848ed89 commit fcab87d

File tree

4 files changed

+192
-54
lines changed

4 files changed

+192
-54
lines changed

go-client/idl/base/host_port.go

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func NewHostPort(host string, port uint16) *HostPort {
4848
}
4949
}
5050

51-
func (r *HostPort) Read(iprot thrift.TProtocol) error {
51+
func (hp *HostPort) Read(iprot thrift.TProtocol) error {
5252
host, err := iprot.ReadString()
5353
if err != nil {
5454
return err
@@ -62,43 +62,68 @@ func (r *HostPort) Read(iprot thrift.TProtocol) error {
6262
return err
6363
}
6464

65-
r.host = host
66-
r.port = uint16(port)
67-
r.hpType = HostPortType(hpType)
65+
hp.host = host
66+
hp.port = uint16(port)
67+
hp.hpType = HostPortType(hpType)
6868
return nil
6969
}
7070

71-
func (r *HostPort) Write(oprot thrift.TProtocol) error {
72-
err := oprot.WriteString(r.host)
71+
func (hp *HostPort) Write(oprot thrift.TProtocol) error {
72+
err := oprot.WriteString(hp.host)
7373
if err != nil {
7474
return err
7575
}
76-
err = oprot.WriteI16(int16(r.port))
76+
err = oprot.WriteI16(int16(hp.port))
7777
if err != nil {
7878
return err
7979
}
80-
err = oprot.WriteByte(int8(r.hpType))
80+
err = oprot.WriteByte(int8(hp.hpType))
8181
if err != nil {
8282
return err
8383
}
8484
return nil
8585
}
8686

87-
func (r *HostPort) GetHost() string {
88-
return r.host
87+
func (hp *HostPort) GetHost() string {
88+
return hp.host
8989
}
9090

91-
func (r *HostPort) GetPort() uint16 {
92-
return r.port
91+
func (hp *HostPort) GetPort() uint16 {
92+
return hp.port
9393
}
9494

95-
func (r *HostPort) String() string {
96-
if r == nil {
95+
func (hp *HostPort) String() string {
96+
if hp == nil {
9797
return "<nil>"
9898
}
99-
return fmt.Sprintf("HostPort(%s:%d)", r.host, r.port)
99+
return fmt.Sprintf("HostPort(%s:%d)", hp.host, hp.port)
100100
}
101101

102-
func (r *HostPort) GetHostPort() string {
103-
return fmt.Sprintf("%s:%d", r.host, r.port)
102+
func (hp *HostPort) GetHostPort() string {
103+
return fmt.Sprintf("%s:%d", hp.host, hp.port)
104+
}
105+
106+
func (hp *HostPort) Equal(other *HostPort) bool {
107+
if hp == other {
108+
return true
109+
}
110+
111+
if hp == nil || other == nil {
112+
return false
113+
}
114+
115+
if hp.hpType != other.hpType {
116+
return false
117+
}
118+
119+
switch hp.hpType {
120+
case HOST_TYPE_IPV4:
121+
return hp.host == other.host &&
122+
hp.port == other.port
123+
case HOST_TYPE_GROUP:
124+
// TODO(wangdan): support HOST_TYPE_GROUP.
125+
return false
126+
default:
127+
return true
128+
}
104129
}

go-client/idl/base/host_port_test.go

Lines changed: 79 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,32 +20,96 @@
2020
package base
2121

2222
import (
23+
"fmt"
2324
"testing"
2425

2526
"github.com/apache/thrift/lib/go/thrift"
2627
"github.com/stretchr/testify/assert"
2728
)
2829

30+
func testNewHostPort(t *testing.T, host string, port uint16) *HostPort {
31+
hp := NewHostPort(host, port)
32+
assert.Equal(t, host, hp.GetHost())
33+
assert.Equal(t, port, hp.GetPort())
34+
assert.True(t, hp.Equal(hp))
35+
36+
return hp
37+
}
38+
39+
func stringify(host string, port uint16) string {
40+
return fmt.Sprintf("<%s:%d>", host, port)
41+
}
42+
2943
func TestHostPort(t *testing.T) {
30-
testCases := map[string]uint16{
31-
"localhost": 8080,
44+
tests := map[string]uint16{
45+
"localhost": 8080,
46+
"pegasus.apache.org": 443,
3247
}
3348

34-
for host, port := range testCases {
35-
hp := NewHostPort(host, port)
36-
assert.Equal(t, host, hp.GetHost())
37-
assert.Equal(t, port, hp.GetPort())
49+
runner := func(host string, port uint16) func(t *testing.T) {
50+
return func(t *testing.T) {
51+
t.Parallel()
52+
53+
hp := testNewHostPort(t, host, port)
3854

39-
// test HostPort serialize
40-
buf := thrift.NewTMemoryBuffer()
41-
oprot := thrift.NewTBinaryProtocolTransport(buf)
42-
hp.Write(oprot)
55+
// Test serialization.
56+
buf := thrift.NewTMemoryBuffer()
57+
oprot := thrift.NewTBinaryProtocolTransport(buf)
58+
assert.NoError(t, hp.Write(oprot))
4359

44-
// test HostPort deserialize
45-
readHostPort := NewHostPort("", 0)
46-
readHostPort.Read(oprot)
60+
// Test deserialization.
61+
peer := NewHostPort("", 0)
62+
assert.NoError(t, peer.Read(oprot))
63+
assert.True(t, peer.Equal(peer))
64+
65+
// Test equality.
66+
assert.Equal(t, hp, peer)
67+
assert.True(t, hp.Equal(peer))
68+
assert.True(t, peer.Equal(hp))
69+
}
70+
}
71+
72+
for host, port := range tests {
73+
t.Run(stringify(host, port), runner(host, port))
74+
}
75+
}
76+
77+
func TestHostPortEquality(t *testing.T) {
78+
type hpCase struct {
79+
host string
80+
port uint16
81+
}
82+
type testCase struct {
83+
x hpCase
84+
y hpCase
85+
equal bool
86+
}
87+
tests := []testCase{
88+
{hpCase{"localhost", 8080}, hpCase{"localhost", 8080}, true},
89+
{hpCase{"localhost", 8080}, hpCase{"pegasus.apache.org", 8080}, false},
90+
{hpCase{"localhost", 8080}, hpCase{"localhost", 8081}, false},
91+
}
92+
93+
testName := func(hpX hpCase, hpY hpCase) string {
94+
hpName := func(hp hpCase) string {
95+
return stringify(hp.host, hp.port)
96+
}
97+
return fmt.Sprintf("%s-vs-%s", hpName(hpX), hpName(hpY))
98+
}
99+
100+
runner := func(test testCase) func(t *testing.T) {
101+
return func(t *testing.T) {
102+
t.Parallel()
103+
104+
hpX := testNewHostPort(t, test.x.host, test.x.port)
105+
hpY := testNewHostPort(t, test.y.host, test.y.port)
106+
107+
assert.Equal(t, test.equal, hpX.Equal(hpY))
108+
assert.Equal(t, test.equal, hpY.Equal(hpX))
109+
}
110+
}
47111

48-
// check equals
49-
assert.Equal(t, readHostPort, hp)
112+
for _, test := range tests {
113+
t.Run(testName(test.x, test.y), runner(test))
50114
}
51115
}

go-client/idl/base/rpc_address.go

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,38 +37,42 @@ func NewRPCAddress(ip net.IP, port int) *RPCAddress {
3737
}
3838
}
3939

40-
func (r *RPCAddress) Read(iprot thrift.TProtocol) error {
40+
func (a *RPCAddress) Read(iprot thrift.TProtocol) error {
4141
address, err := iprot.ReadI64()
4242
if err != nil {
4343
return err
4444
}
45-
r.address = address
45+
a.address = address
4646
return nil
4747
}
4848

49-
func (r *RPCAddress) Write(oprot thrift.TProtocol) error {
50-
return oprot.WriteI64(r.address)
49+
func (a *RPCAddress) Write(oprot thrift.TProtocol) error {
50+
return oprot.WriteI64(a.address)
5151
}
5252

53-
func (r *RPCAddress) String() string {
54-
if r == nil {
53+
func (a *RPCAddress) String() string {
54+
if a == nil {
5555
return "<nil>"
5656
}
57-
return fmt.Sprintf("RPCAddress(%s)", r.GetAddress())
57+
return fmt.Sprintf("RPCAddress(%s)", a.GetAddress())
5858
}
5959

60-
func (r *RPCAddress) GetIP() net.IP {
61-
return net.IPv4(byte(0xff&(r.address>>56)), byte(0xff&(r.address>>48)), byte(0xff&(r.address>>40)), byte(0xff&(r.address>>32)))
60+
func (a *RPCAddress) GetIP() net.IP {
61+
return net.IPv4(byte(0xff&(a.address>>56)), byte(0xff&(a.address>>48)), byte(0xff&(a.address>>40)), byte(0xff&(a.address>>32)))
6262
}
6363

64-
func (r *RPCAddress) GetPort() int {
65-
return int(0xffff & (r.address >> 16))
64+
func (a *RPCAddress) GetPort() int {
65+
return int(0xffff & (a.address >> 16))
6666
}
6767

68-
func (r *RPCAddress) GetAddress() string {
69-
return fmt.Sprintf("%s:%d", r.GetIP(), r.GetPort())
68+
func (a *RPCAddress) GetAddress() string {
69+
return fmt.Sprintf("%s:%d", a.GetIP(), a.GetPort())
7070
}
7171

72-
func (r *RPCAddress) GetRawAddress() int64 {
73-
return r.address
72+
func (a *RPCAddress) GetRawAddress() int64 {
73+
return a.address
74+
}
75+
76+
func (a *RPCAddress) Equal(other *RPCAddress) bool {
77+
return a.address == other.address
7478
}

go-client/idl/base/rpc_address_test.go

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,25 +20,70 @@
2020
package base
2121

2222
import (
23+
"fmt"
2324
"net"
2425
"testing"
2526

2627
"github.com/stretchr/testify/assert"
2728
)
2829

30+
func testNewRPCAddress(t *testing.T, addrStr string) *RPCAddress {
31+
addr, err := net.ResolveTCPAddr("tcp", addrStr)
32+
assert.NoError(t, err)
33+
34+
rpcAddr := NewRPCAddress(addr.IP, addr.Port)
35+
assert.Equal(t, addrStr, rpcAddr.GetAddress())
36+
assert.True(t, rpcAddr.Equal(rpcAddr))
37+
38+
return rpcAddr
39+
}
40+
2941
func TestRPCAddress(t *testing.T) {
30-
testCases := []string{
42+
tests := []string{
3143
"127.0.0.1:8080",
3244
"192.168.0.1:123",
3345
"0.0.0.0:12345",
3446
}
3547

36-
for _, ts := range testCases {
37-
tcpAddrStr := ts
38-
addr, err := net.ResolveTCPAddr("tcp", tcpAddrStr)
39-
assert.NoError(t, err)
48+
runner := func(test string) func(t *testing.T) {
49+
return func(t *testing.T) {
50+
t.Parallel()
51+
52+
testNewRPCAddress(t, test)
53+
}
54+
}
55+
56+
for _, test := range tests {
57+
name := fmt.Sprintf("<%s>", test)
58+
t.Run(name, runner(test))
59+
}
60+
}
61+
62+
func TestRPCAddressEquality(t *testing.T) {
63+
tests := []struct {
64+
x string
65+
y string
66+
equal bool
67+
}{
68+
{"127.0.0.1:8080", "127.0.0.1:8080", true},
69+
{"127.0.0.1:8080", "192.168.0.1:8080", false},
70+
{"127.0.0.1:8080", "127.0.0.1:8081", false},
71+
}
72+
73+
runner := func(x string, y string, equal bool) func(t *testing.T) {
74+
return func(t *testing.T) {
75+
t.Parallel()
76+
77+
rpcAddrX := testNewRPCAddress(t, x)
78+
rpcAddrY := testNewRPCAddress(t, y)
79+
80+
assert.Equal(t, equal, rpcAddrX.Equal(rpcAddrY))
81+
assert.Equal(t, equal, rpcAddrY.Equal(rpcAddrX))
82+
}
83+
}
4084

41-
rpcAddr := NewRPCAddress(addr.IP, addr.Port)
42-
assert.Equal(t, rpcAddr.GetAddress(), tcpAddrStr)
85+
for _, test := range tests {
86+
name := fmt.Sprintf("<%s>-vs-<%s>", test.x, test.y)
87+
t.Run(name, runner(test.x, test.y, test.equal))
4388
}
4489
}

0 commit comments

Comments
 (0)