Skip to content

Commit 246ded3

Browse files
committed
[common] Add unit test for convertions
1 parent 2d4612d commit 246ded3

File tree

4 files changed

+160
-53
lines changed

4 files changed

+160
-53
lines changed

internal/common/convert.go

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,55 +38,60 @@ func Int64Ceil(v float64) int64 {
3838

3939
// Int32Ptr makes a copy and returns the pointer to an int32.
4040
func Int32Ptr(v int32) *int32 {
41-
return &v
41+
return PtrOf(v)
4242
}
4343

4444
// Float64Ptr makes a copy and returns the pointer to a float64.
4545
func Float64Ptr(v float64) *float64 {
46-
return &v
46+
return PtrOf(v)
4747
}
4848

4949
// Int64Ptr makes a copy and returns the pointer to an int64.
5050
func Int64Ptr(v int64) *int64 {
51-
return &v
51+
return PtrOf(v)
5252
}
5353

5454
// StringPtr makes a copy and returns the pointer to a string.
5555
func StringPtr(v string) *string {
56-
return &v
56+
return PtrOf(v)
5757
}
5858

5959
// BoolPtr makes a copy and returns the pointer to a string.
6060
func BoolPtr(v bool) *bool {
61-
return &v
61+
return PtrOf(v)
6262
}
6363

6464
// TaskListPtr makes a copy and returns the pointer to a TaskList.
6565
func TaskListPtr(v s.TaskList) *s.TaskList {
66-
return &v
66+
return PtrOf(v)
6767
}
6868

6969
// DecisionTypePtr makes a copy and returns the pointer to a DecisionType.
7070
func DecisionTypePtr(t s.DecisionType) *s.DecisionType {
71-
return &t
71+
return PtrOf(t)
7272
}
7373

7474
// EventTypePtr makes a copy and returns the pointer to a EventType.
7575
func EventTypePtr(t s.EventType) *s.EventType {
76-
return &t
76+
return PtrOf(t)
7777
}
7878

7979
// QueryTaskCompletedTypePtr makes a copy and returns the pointer to a QueryTaskCompletedType.
8080
func QueryTaskCompletedTypePtr(t s.QueryTaskCompletedType) *s.QueryTaskCompletedType {
81-
return &t
81+
return PtrOf(t)
8282
}
8383

8484
// TaskListKindPtr makes a copy and returns the pointer to a TaskListKind.
8585
func TaskListKindPtr(t s.TaskListKind) *s.TaskListKind {
86-
return &t
86+
return PtrOf(t)
8787
}
8888

8989
// QueryResultTypePtr makes a copy and returns the pointer to a QueryResultType.
9090
func QueryResultTypePtr(t s.QueryResultType) *s.QueryResultType {
91-
return &t
91+
return PtrOf(t)
92+
}
93+
94+
// PtrOf makes a copy and returns the pointer to a value.
95+
func PtrOf[T any](v T) *T {
96+
return &v
9297
}

internal/common/convert_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package common
2+
3+
import (
4+
s "go.uber.org/cadence/.gen/go/shared"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
func TestPtrOf(t *testing.T) {
11+
assert.Equal(t, "a", *PtrOf("a"))
12+
assert.Equal(t, 1, *PtrOf(1))
13+
assert.Equal(t, int32(1), *PtrOf(int32(1)))
14+
assert.Equal(t, int64(1), *PtrOf(int64(1)))
15+
assert.Equal(t, float64(1.1), *PtrOf(float64(1.1)))
16+
assert.Equal(t, true, *PtrOf(true))
17+
}
18+
19+
func TestPtrHelpers(t *testing.T) {
20+
assert.Equal(t, int32(1), *Int32Ptr(1))
21+
assert.Equal(t, int64(1), *Int64Ptr(1))
22+
assert.Equal(t, float64(1.1), *Float64Ptr(1.1))
23+
assert.Equal(t, true, *BoolPtr(true))
24+
assert.Equal(t, "a", *StringPtr("a"))
25+
assert.Equal(t, s.TaskList{Name: PtrOf("a")}, *TaskListPtr(s.TaskList{Name: PtrOf("a")}))
26+
assert.Equal(t, s.DecisionTypeScheduleActivityTask, *DecisionTypePtr(s.DecisionTypeScheduleActivityTask))
27+
assert.Equal(t, s.EventTypeWorkflowExecutionStarted, *EventTypePtr(s.EventTypeWorkflowExecutionStarted))
28+
assert.Equal(t, s.QueryTaskCompletedTypeCompleted, *QueryTaskCompletedTypePtr(s.QueryTaskCompletedTypeCompleted))
29+
assert.Equal(t, s.TaskListKindNormal, *TaskListKindPtr(s.TaskListKindNormal))
30+
assert.Equal(t, s.QueryResultTypeFailed, *QueryResultTypePtr(s.QueryResultTypeFailed))
31+
}
32+
33+
func TestCeilHelpers(t *testing.T) {
34+
assert.Equal(t, int32(2), Int32Ceil(1.1))
35+
assert.Equal(t, int64(2), Int64Ceil(1.1))
36+
}

internal/common/thrift_util.go

Lines changed: 12 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,13 @@ import (
2727
"github.com/apache/thrift/lib/go/thrift"
2828
)
2929

30-
// TSerialize is used to serialize thrift TStruct to []byte
31-
func TSerialize(ctx context.Context, t thrift.TStruct) (b []byte, err error) {
32-
return thrift.NewTSerializer().Write(ctx, t)
33-
}
34-
3530
// TListSerialize is used to serialize list of thrift TStruct to []byte
36-
func TListSerialize(ts []thrift.TStruct) (b []byte, err error) {
31+
func TListSerialize(ts []thrift.TStruct) ([]byte, error) {
3732
if ts == nil {
38-
return
33+
return nil, nil
3934
}
4035

4136
t := thrift.NewTSerializer()
42-
t.Transport.Reset()
4337

4438
// NOTE: we don't write any markers as thrift by design being a streaming protocol doesn't
4539
// recommend writing length.
@@ -48,26 +42,11 @@ func TListSerialize(ts []thrift.TStruct) (b []byte, err error) {
4842
ctx := context.Background()
4943
for _, v := range ts {
5044
if e := v.Write(ctx, t.Protocol); e != nil {
51-
err = thrift.PrependError("error writing TStruct: ", e)
52-
return
45+
return nil, thrift.PrependError("error writing TStruct: ", e)
5346
}
5447
}
5548

56-
if err = t.Protocol.Flush(ctx); err != nil {
57-
return
58-
}
59-
60-
if err = t.Transport.Flush(ctx); err != nil {
61-
return
62-
}
63-
64-
b = t.Transport.Bytes()
65-
return
66-
}
67-
68-
// TDeserialize is used to deserialize []byte to thrift TStruct
69-
func TDeserialize(ctx context.Context, t thrift.TStruct, b []byte) (err error) {
70-
return thrift.NewTDeserializer().Read(ctx, t, b)
49+
return t.Transport.Bytes(), t.Protocol.Flush(ctx)
7150
}
7251

7352
// TListDeserialize is used to deserialize []byte to list of thrift TStruct
@@ -94,13 +73,8 @@ func TListDeserialize(ts []thrift.TStruct, b []byte) (err error) {
9473
func IsUseThriftEncoding(objs []interface{}) bool {
9574
// NOTE: our criteria to use which encoder is simple if all the types are serializable using thrift then we use
9675
// thrift encoder. For everything else we default to gob.
97-
98-
if len(objs) == 0 {
99-
return false
100-
}
101-
102-
for i := 0; i < len(objs); i++ {
103-
if !IsThriftType(objs[i]) {
76+
for _, obj := range objs {
77+
if !IsThriftType(obj) {
10478
return false
10579
}
10680
}
@@ -111,14 +85,9 @@ func IsUseThriftEncoding(objs []interface{}) bool {
11185
func IsUseThriftDecoding(objs []interface{}) bool {
11286
// NOTE: our criteria to use which encoder is simple if all the types are de-serializable using thrift then we use
11387
// thrift decoder. For everything else we default to gob.
114-
115-
if len(objs) == 0 {
116-
return false
117-
}
118-
119-
for i := 0; i < len(objs); i++ {
120-
rVal := reflect.ValueOf(objs[i])
121-
if rVal.Kind() != reflect.Ptr || !IsThriftType(reflect.Indirect(rVal).Interface()) {
88+
for _, obj := range objs {
89+
rVal := reflect.ValueOf(obj)
90+
if rVal.Kind() != reflect.Ptr || !IsThriftType(obj) {
12291
return false
12392
}
12493
}
@@ -133,6 +102,7 @@ func IsThriftType(v interface{}) bool {
133102
if reflect.ValueOf(v).Kind() != reflect.Ptr {
134103
return false
135104
}
136-
t := reflect.TypeOf((*thrift.TStruct)(nil)).Elem()
137-
return reflect.TypeOf(v).Implements(t)
105+
return reflect.TypeOf(v).Implements(tStructType)
138106
}
107+
108+
var tStructType = reflect.TypeOf((*thrift.TStruct)(nil)).Elem()
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
package common
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/apache/thrift/lib/go/thrift"
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func TestTListSerialize(t *testing.T) {
12+
t.Run("nil", func(t *testing.T) {
13+
data, err := TListSerialize(nil)
14+
assert.NoError(t, err)
15+
assert.Nil(t, data)
16+
})
17+
t.Run("normal", func(t *testing.T) {
18+
ts := []thrift.TStruct{
19+
&mockThriftStruct{Field1: "value1", Field2: 1},
20+
&mockThriftStruct{Field1: "value2", Field2: 2},
21+
}
22+
23+
_, err := TListSerialize(ts)
24+
assert.NoError(t, err)
25+
})
26+
}
27+
28+
func TestTListDeserialize(t *testing.T) {
29+
ts := []thrift.TStruct{
30+
&mockThriftStruct{},
31+
&mockThriftStruct{},
32+
}
33+
34+
data, err := TListSerialize(ts)
35+
assert.NoError(t, err)
36+
37+
err = TListDeserialize(ts, data)
38+
assert.NoError(t, err)
39+
}
40+
41+
func TestIsUseThriftEncoding(t *testing.T) {
42+
ts := []interface{}{
43+
&mockThriftStruct{},
44+
&mockThriftStruct{},
45+
}
46+
47+
result := IsUseThriftEncoding(ts)
48+
assert.True(t, result)
49+
50+
ts = []interface{}{
51+
&mockThriftStruct{},
52+
"string",
53+
}
54+
55+
result = IsUseThriftEncoding(ts)
56+
assert.False(t, result)
57+
}
58+
59+
func TestIsUseThriftDecoding(t *testing.T) {
60+
ts := []interface{}{
61+
&mockThriftStruct{},
62+
&mockThriftStruct{},
63+
}
64+
65+
assert.True(t, IsUseThriftDecoding(ts))
66+
67+
ts = []interface{}{
68+
&mockThriftStruct{},
69+
"string",
70+
}
71+
72+
assert.False(t, IsUseThriftDecoding(ts))
73+
}
74+
75+
func TestIsThriftType(t *testing.T) {
76+
assert.True(t, IsThriftType(&mockThriftStruct{}))
77+
78+
assert.False(t, IsThriftType(mockThriftStruct{}))
79+
}
80+
81+
type mockThriftStruct struct {
82+
Field1 string
83+
Field2 int
84+
}
85+
86+
func (m *mockThriftStruct) Read(ctx context.Context, iprot thrift.TProtocol) error {
87+
return nil
88+
}
89+
90+
func (m *mockThriftStruct) Write(ctx context.Context, oprot thrift.TProtocol) error {
91+
return nil
92+
}
93+
94+
func (m *mockThriftStruct) String() string {
95+
return ""
96+
}

0 commit comments

Comments
 (0)