Skip to content

Commit f848c5d

Browse files
refactor(aggregator): AggregateWithCallback → AggregateWithParse
1 parent 4d98897 commit f848c5d

File tree

5 files changed

+41
-61
lines changed

5 files changed

+41
-61
lines changed

aggregator/aggregator.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,14 @@ package aggregator
1717
import (
1818
"context"
1919

20-
"github.com/chenmingyong0423/go-mongox/types"
2120
"go.mongodb.org/mongo-driver/mongo"
2221
"go.mongodb.org/mongo-driver/mongo/options"
2322
)
2423

2524
//go:generate mockgen -source=aggregator.go -destination=../mock/aggregator.mock.go -package=mocks
2625
type iAggregator[T any] interface {
2726
Aggregate(ctx context.Context, opts ...*options.AggregateOptions) ([]*T, error)
28-
AggregateWithCallback(ctx context.Context, handler types.ResultHandler, opts ...*options.AggregateOptions) error
27+
AggregateWithParse(ctx context.Context, result any, opts ...*options.AggregateOptions) error
2928
}
3029

3130
type Aggregator[T any] struct {
@@ -59,13 +58,16 @@ func (a *Aggregator[T]) Aggregate(ctx context.Context, opts ...*options.Aggregat
5958
return result, nil
6059
}
6160

62-
func (a *Aggregator[T]) AggregateWithCallback(ctx context.Context, handler types.ResultHandler, opts ...*options.AggregateOptions) error {
61+
// AggregateWithParse is used to parse the result of the aggregation
62+
// result must be a pointer to a slice
63+
func (a *Aggregator[T]) AggregateWithParse(ctx context.Context, result any, opts ...*options.AggregateOptions) error {
6364
cursor, err := a.collection.Aggregate(ctx, a.pipeline, opts...)
6465
if err != nil {
6566
return err
6667
}
6768
defer cursor.Close(ctx)
68-
if err = handler(ctx, cursor); err != nil {
69+
err = cursor.All(ctx, result)
70+
if err != nil {
6971
return err
7072
}
7173
return nil

aggregator/aggregator_e2e_test.go

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ package aggregator
1818

1919
import (
2020
"context"
21-
"errors"
2221
"testing"
2322

2423
"github.com/chenmingyong0423/go-mongox/bsonx"
@@ -191,7 +190,7 @@ func TestAggregator_e2e_Aggregation(t *testing.T) {
191190
}
192191
}
193192

194-
func TestAggregator_e2e_AggregationWithCallback(t *testing.T) {
193+
func TestAggregator_e2e_AggregateWithParse(t *testing.T) {
195194
collection := getCollection(t)
196195
aggregator := NewAggregator[types.TestUser](collection)
197196

@@ -209,8 +208,7 @@ func TestAggregator_e2e_AggregationWithCallback(t *testing.T) {
209208
pipeline any
210209
aggregationOptions *options.AggregateOptions
211210
ctx context.Context
212-
preUsers []*User
213-
callback types.ResultHandler
211+
result any
214212
want []*User
215213
wantErr assert.ErrorAssertionFunc
216214
}{
@@ -241,7 +239,7 @@ func TestAggregator_e2e_AggregationWithCallback(t *testing.T) {
241239
assert.Equal(t, int64(2), deleteResult.DeletedCount)
242240
},
243241
pipeline: aggregation.StageBsonBuilder().Set(bsonx.M("is_programmer", true)).Build(),
244-
preUsers: make([]*User, 0, 4),
242+
result: make([]*User, 0, 4),
245243
want: []*User{
246244
{Id: "1", Name: "cmy", Age: 24, IsProgrammer: true},
247245
{Id: "2", Name: "gopher", Age: 20, IsProgrammer: true},
@@ -265,7 +263,7 @@ func TestAggregator_e2e_AggregationWithCallback(t *testing.T) {
265263
assert.Equal(t, int64(2), deleteResult.DeletedCount)
266264
},
267265
pipeline: aggregation.StageBsonBuilder().Set(bsonx.M("is_programmer", true)).Sort(bsonx.M("name", 1)).Build(),
268-
preUsers: make([]*User, 0, 4),
266+
result: make([]*User, 0, 4),
269267
want: []*User{
270268
{Id: "1", Name: "cmy", Age: 24, IsProgrammer: true},
271269
{Id: "2", Name: "gopher", Age: 20, IsProgrammer: true},
@@ -289,11 +287,8 @@ func TestAggregator_e2e_AggregationWithCallback(t *testing.T) {
289287
assert.NoError(t, err)
290288
assert.Equal(t, int64(2), deleteResult.DeletedCount)
291289
},
292-
pipeline: aggregation.StageBsonBuilder().Set(bsonx.M("is_programmer", true)).Sort(bsonx.M("name", 1)).Build(),
293-
preUsers: make([]*User, 0),
294-
callback: func(ctx context.Context, cursor *mongo.Cursor) error {
295-
return errors.New("got error from cursor")
296-
},
290+
pipeline: aggregation.StageBsonBuilder().Set(bsonx.M("is_programmer", true)).Sort(bsonx.M("name", 1)).Build(),
291+
result: []string{},
297292
want: []*User{},
298293
aggregationOptions: options.Aggregate().SetCollation(&options.Collation{Locale: "en", Strength: 2}),
299294
ctx: context.Background(),
@@ -303,16 +298,10 @@ func TestAggregator_e2e_AggregationWithCallback(t *testing.T) {
303298
for _, tc := range testCases {
304299
t.Run(tc.name, func(t *testing.T) {
305300
tc.before(tc.ctx, t)
306-
callback := func(ctx context.Context, cursor *mongo.Cursor) error {
307-
return cursor.All(ctx, &tc.preUsers)
308-
}
309-
if tc.callback != nil {
310-
callback = tc.callback
311-
}
312-
err := aggregator.Pipeline(tc.pipeline).AggregateWithCallback(tc.ctx, callback, tc.aggregationOptions)
301+
err := aggregator.Pipeline(tc.pipeline).AggregateWithParse(tc.ctx, &tc.result, tc.aggregationOptions)
313302
tc.after(tc.ctx, t)
314303
if tc.wantErr(t, err) {
315-
assert.ElementsMatch(t, tc.want, tc.preUsers)
304+
assert.ElementsMatch(t, tc.want, tc.result)
316305
}
317306
})
318307
}

aggregator/aggregator_test.go

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -84,41 +84,41 @@ func TestAggregator_Aggregation(t *testing.T) {
8484
}
8585
}
8686

87-
func TestAggregator_AggregationWithCallback(t *testing.T) {
87+
func TestAggregator_AggregateWithParse(t *testing.T) {
8888
type User struct {
8989
Id string `bson:"_id"`
9090
Name string `bson:"name"`
9191
Age int64
9292
IsProgrammer bool `bson:"is_programmer"`
9393
}
9494
testCases := []struct {
95-
name string
96-
mock func(ctx context.Context, ctl *gomock.Controller) iAggregator[types.TestUser]
97-
ctx context.Context
98-
callbackParam []*User
99-
want []*User
100-
wantErr assert.ErrorAssertionFunc
95+
name string
96+
mock func(ctx context.Context, ctl *gomock.Controller, result any) iAggregator[types.TestUser]
97+
ctx context.Context
98+
result []*User
99+
want []*User
100+
wantErr assert.ErrorAssertionFunc
101101
}{
102102
{
103103
name: "got error",
104-
mock: func(ctx context.Context, ctl *gomock.Controller) iAggregator[types.TestUser] {
104+
mock: func(ctx context.Context, ctl *gomock.Controller, result any) iAggregator[types.TestUser] {
105105
aggregator := mocks.NewMockiAggregator[types.TestUser](ctl)
106-
aggregator.EXPECT().AggregateWithCallback(ctx, gomock.Any()).Return(errors.New("can only marshal slices and arrays into aggregation pipelines, but got invalid")).Times(1)
106+
aggregator.EXPECT().AggregateWithParse(ctx, result).Return(errors.New("can only marshal slices and arrays into aggregation pipelines, but got invalid")).Times(1)
107107
return aggregator
108108
},
109-
ctx: context.Background(),
110-
callbackParam: []*User{},
111-
wantErr: assert.Error,
109+
ctx: context.Background(),
110+
result: []*User{},
111+
wantErr: assert.Error,
112112
},
113113
{
114114
name: "got result",
115-
mock: func(ctx context.Context, ctl *gomock.Controller) iAggregator[types.TestUser] {
115+
mock: func(ctx context.Context, ctl *gomock.Controller, result any) iAggregator[types.TestUser] {
116116
aggregator := mocks.NewMockiAggregator[types.TestUser](ctl)
117-
aggregator.EXPECT().AggregateWithCallback(ctx, gomock.Any()).Return(nil).Times(1)
117+
aggregator.EXPECT().AggregateWithParse(ctx, result).Return(nil).Times(1)
118118
return aggregator
119119
},
120120
ctx: context.Background(),
121-
callbackParam: []*User{
121+
result: []*User{
122122
{Id: "1", Name: "cmy", Age: 24, IsProgrammer: true},
123123
},
124124
want: []*User{
@@ -131,13 +131,11 @@ func TestAggregator_AggregationWithCallback(t *testing.T) {
131131
t.Run(tc.name, func(t *testing.T) {
132132
ctl := gomock.NewController(t)
133133
defer ctl.Finish()
134-
var callback types.ResultHandler = func(ctx context.Context, cursor *mongo.Cursor) error {
135-
return cursor.All(ctx, &tc.callbackParam)
136-
}
137-
aggregator := tc.mock(tc.ctx, ctl)
138-
err := aggregator.AggregateWithCallback(tc.ctx, callback)
134+
135+
aggregator := tc.mock(tc.ctx, ctl, tc.result)
136+
err := aggregator.AggregateWithParse(tc.ctx, tc.result)
139137
if tc.wantErr(t, err) {
140-
assert.ElementsMatch(t, tc.want, tc.callbackParam)
138+
assert.ElementsMatch(t, tc.want, tc.result)
141139
}
142140
})
143141
}

mock/aggregator.mock.go

Lines changed: 8 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

types/types.go

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,6 @@
1414

1515
package types
1616

17-
import (
18-
"context"
19-
20-
"go.mongodb.org/mongo-driver/mongo"
21-
)
22-
2317
const (
2418
Id = "_id"
2519
Set = "$set"
@@ -208,8 +202,6 @@ type TextOptions struct {
208202
DiacriticSensitive bool
209203
}
210204

211-
type ResultHandler func(ctx context.Context, cursor *mongo.Cursor) error
212-
213205
type Numeric interface {
214206
~int | ~int8 | ~int16 | ~int32 | ~int64 |
215207
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr |

0 commit comments

Comments
 (0)