Skip to content

Commit 99dd184

Browse files
authored
fix: correct OpType usage in Aggregator component (#107)
- Add OpTypeBeforeAggregate and OpTypeAfterAggregate constants - Update callback system to support aggregation operation types - Fix Aggregator methods to use correct OpTypes instead of insert types - Add comprehensive unit and E2E tests for aggregation hooks - Ensure insert hooks are not triggered during aggregation operations This fixes the critical bug where aggregation operations incorrectly triggered insert hooks, breaking the plugin system for aggregations. Signed-off-by: ramsyana <[email protected]>
1 parent ad76b10 commit 99dd184

File tree

5 files changed

+208
-26
lines changed

5 files changed

+208
-26
lines changed

aggregator/aggregator.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ func (a *Aggregator[T]) Aggregate(ctx context.Context, opts ...options.Lister[op
7878
globalOpContext := operation.NewOpContext(a.collection, operation.WithPipeline(a.pipeline), operation.WithMongoOptions(opts), operation.WithModelHook(a.modelHook), operation.WithStartTime(currentTime), operation.WithFields(a.fields))
7979
opContext := NewOpContext(a.collection, a.pipeline, WithMongoOptions(opts), WithModelHook(a.modelHook), WithStartTime(currentTime), WithFields(a.fields))
8080

81-
err := a.preActionHandler(ctx, globalOpContext, opContext, operation.OpTypeBeforeInsert)
81+
err := a.preActionHandler(ctx, globalOpContext, opContext, operation.OpTypeBeforeAggregate)
8282
if err != nil {
8383
return nil, err
8484
}
@@ -99,7 +99,7 @@ func (a *Aggregator[T]) Aggregate(ctx context.Context, opts ...options.Lister[op
9999

100100
globalOpContext.Result = cursor
101101
opContext.Result = cursor
102-
err = a.postActionHandler(ctx, globalOpContext, opContext, operation.OpTypeAfterInsert)
102+
err = a.postActionHandler(ctx, globalOpContext, opContext, operation.OpTypeAfterAggregate)
103103
if err != nil {
104104
return nil, err
105105
}
@@ -115,7 +115,7 @@ func (a *Aggregator[T]) AggregateWithParse(ctx context.Context, result any, opts
115115
globalOpContext := operation.NewOpContext(a.collection, operation.WithPipeline(a.pipeline), operation.WithMongoOptions(opts), operation.WithModelHook(a.modelHook), operation.WithStartTime(currentTime), operation.WithFields(a.fields))
116116
opContext := NewOpContext(a.collection, a.pipeline, WithMongoOptions(opts), WithModelHook(a.modelHook), WithStartTime(currentTime), WithFields(a.fields))
117117

118-
err := a.preActionHandler(ctx, globalOpContext, opContext, operation.OpTypeBeforeInsert)
118+
err := a.preActionHandler(ctx, globalOpContext, opContext, operation.OpTypeBeforeAggregate)
119119
if err != nil {
120120
return err
121121
}
@@ -134,7 +134,7 @@ func (a *Aggregator[T]) AggregateWithParse(ctx context.Context, result any, opts
134134

135135
globalOpContext.Result = cursor
136136
opContext.Result = cursor
137-
err = a.postActionHandler(ctx, globalOpContext, opContext, operation.OpTypeAfterInsert)
137+
err = a.postActionHandler(ctx, globalOpContext, opContext, operation.OpTypeAfterAggregate)
138138
if err != nil {
139139
return err
140140
}

aggregator/aggregator_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ import (
2626
"go.mongodb.org/mongo-driver/v2/bson"
2727
"go.mongodb.org/mongo-driver/v2/mongo"
2828
"go.mongodb.org/mongo-driver/v2/mongo/options"
29+
30+
"github.com/chenmingyong0423/go-mongox/v2/callback"
31+
"github.com/chenmingyong0423/go-mongox/v2/operation"
2932
"go.uber.org/mock/gomock"
3033
)
3134

@@ -221,3 +224,85 @@ func TestAggregator_AggregateWithParse(t *testing.T) {
221224
})
222225
}
223226
}
227+
228+
func TestAggregator_CorrectOpTypes(t *testing.T) {
229+
t.Run("aggregation should use correct OpTypes", func(t *testing.T) {
230+
// Setup
231+
col := &mongo.Collection{}
232+
dbCallbacks := callback.InitializeCallbacks()
233+
234+
// Track which hooks are called
235+
var calledHooks []string
236+
237+
// Register hooks for different operation types
238+
dbCallbacks.Register(operation.OpTypeBeforeInsert, "insert-before", func(ctx context.Context, opCtx *operation.OpContext, opts ...any) error {
239+
calledHooks = append(calledHooks, "insert-before")
240+
return nil
241+
})
242+
243+
dbCallbacks.Register(operation.OpTypeAfterInsert, "insert-after", func(ctx context.Context, opCtx *operation.OpContext, opts ...any) error {
244+
calledHooks = append(calledHooks, "insert-after")
245+
return nil
246+
})
247+
248+
dbCallbacks.Register(operation.OpTypeBeforeAggregate, "aggregate-before", func(ctx context.Context, opCtx *operation.OpContext, opts ...any) error {
249+
calledHooks = append(calledHooks, "aggregate-before")
250+
return nil
251+
})
252+
253+
dbCallbacks.Register(operation.OpTypeAfterAggregate, "aggregate-after", func(ctx context.Context, opCtx *operation.OpContext, opts ...any) error {
254+
calledHooks = append(calledHooks, "aggregate-after")
255+
return nil
256+
})
257+
258+
// Create aggregator
259+
aggregator := NewAggregator[map[string]interface{}](col, dbCallbacks, nil)
260+
pipeline := []interface{}{map[string]interface{}{"$match": map[string]interface{}{}}}
261+
262+
// Test preActionHandler and postActionHandler directly
263+
ctx := context.Background()
264+
globalOpContext := operation.NewOpContext(col)
265+
opContext := NewOpContext(col, pipeline)
266+
267+
// Test before handler
268+
err := aggregator.preActionHandler(ctx, globalOpContext, opContext, operation.OpTypeBeforeAggregate)
269+
assert.NoError(t, err)
270+
271+
// Test after handler
272+
err = aggregator.postActionHandler(ctx, globalOpContext, opContext, operation.OpTypeAfterAggregate)
273+
assert.NoError(t, err)
274+
275+
// Verify only aggregation hooks were called
276+
assert.Contains(t, calledHooks, "aggregate-before")
277+
assert.Contains(t, calledHooks, "aggregate-after")
278+
assert.NotContains(t, calledHooks, "insert-before")
279+
assert.NotContains(t, calledHooks, "insert-after")
280+
})
281+
282+
t.Run("insert hooks should not be triggered by aggregation", func(t *testing.T) {
283+
// Setup
284+
col := &mongo.Collection{}
285+
dbCallbacks := callback.InitializeCallbacks()
286+
287+
// Register an insert hook that should NOT be called
288+
insertHookCalled := false
289+
dbCallbacks.Register(operation.OpTypeBeforeInsert, "should-not-be-called", func(ctx context.Context, opCtx *operation.OpContext, opts ...any) error {
290+
insertHookCalled = true
291+
return errors.New("insert hook incorrectly called during aggregation")
292+
})
293+
294+
// Create aggregator
295+
aggregator := NewAggregator[map[string]interface{}](col, dbCallbacks, nil)
296+
pipeline := []interface{}{map[string]interface{}{"$match": map[string]interface{}{}}}
297+
298+
// Test that insert hooks are NOT called
299+
ctx := context.Background()
300+
globalOpContext := operation.NewOpContext(col)
301+
opContext := NewOpContext(col, pipeline)
302+
303+
// This should NOT trigger insert hooks
304+
err := aggregator.preActionHandler(ctx, globalOpContext, opContext, operation.OpTypeBeforeAggregate)
305+
assert.NoError(t, err)
306+
assert.False(t, insertHookCalled, "Insert hook should not be called during aggregation")
307+
})
308+
}

callback/callback.go

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -126,20 +126,38 @@ func InitializeCallbacks() *Callback {
126126
},
127127
},
128128
},
129+
beforeAggregate: []callbackHandler{
130+
{
131+
name: "mongox:model",
132+
fn: func(ctx context.Context, opCtx *operation.OpContext, opts ...any) error {
133+
return model.Execute(ctx, opCtx, operation.OpTypeBeforeAggregate, opts...)
134+
},
135+
},
136+
},
137+
afterAggregate: []callbackHandler{
138+
{
139+
name: "mongox:model",
140+
fn: func(ctx context.Context, opCtx *operation.OpContext, opts ...any) error {
141+
return model.Execute(ctx, opCtx, operation.OpTypeAfterAggregate, opts...)
142+
},
143+
},
144+
},
129145
}
130146
}
131147

132148
type Callback struct {
133-
beforeInsert []callbackHandler
134-
afterInsert []callbackHandler
135-
beforeUpdate []callbackHandler
136-
afterUpdate []callbackHandler
137-
beforeDelete []callbackHandler
138-
afterDelete []callbackHandler
139-
beforeUpsert []callbackHandler
140-
afterUpsert []callbackHandler
141-
beforeFind []callbackHandler
142-
afterFind []callbackHandler
149+
beforeInsert []callbackHandler
150+
afterInsert []callbackHandler
151+
beforeUpdate []callbackHandler
152+
afterUpdate []callbackHandler
153+
beforeDelete []callbackHandler
154+
afterDelete []callbackHandler
155+
beforeUpsert []callbackHandler
156+
afterUpsert []callbackHandler
157+
beforeFind []callbackHandler
158+
afterFind []callbackHandler
159+
beforeAggregate []callbackHandler
160+
afterAggregate []callbackHandler
143161
}
144162

145163
func (c *Callback) BeforeInsert() []callbackHandler {
@@ -182,6 +200,14 @@ func (c *Callback) AfterFind() []callbackHandler {
182200
return c.afterFind
183201
}
184202

203+
func (c *Callback) BeforeAggregate() []callbackHandler {
204+
return c.beforeAggregate
205+
}
206+
207+
func (c *Callback) AfterAggregate() []callbackHandler {
208+
return c.afterAggregate
209+
}
210+
185211
func (c *Callback) Execute(ctx context.Context, opCtx *operation.OpContext, opType operation.OpType, opts ...any) error {
186212
switch opType {
187213
case operation.OpTypeBeforeInsert:
@@ -204,6 +230,10 @@ func (c *Callback) Execute(ctx context.Context, opCtx *operation.OpContext, opTy
204230
return c.execute(ctx, opCtx, c.beforeFind, opts...)
205231
case operation.OpTypeAfterFind:
206232
return c.execute(ctx, opCtx, c.afterFind, opts...)
233+
case operation.OpTypeBeforeAggregate:
234+
return c.execute(ctx, opCtx, c.beforeAggregate, opts...)
235+
case operation.OpTypeAfterAggregate:
236+
return c.execute(ctx, opCtx, c.afterAggregate, opts...)
207237
}
208238
return nil
209239
}
@@ -269,6 +299,16 @@ func (c *Callback) Register(opType operation.OpType, name string, fn CbFn) {
269299
name: name,
270300
fn: fn,
271301
})
302+
case operation.OpTypeBeforeAggregate:
303+
c.beforeAggregate = append(c.beforeAggregate, callbackHandler{
304+
name: name,
305+
fn: fn,
306+
})
307+
case operation.OpTypeAfterAggregate:
308+
c.afterAggregate = append(c.afterAggregate, callbackHandler{
309+
name: name,
310+
fn: fn,
311+
})
272312
case operation.OpTypeBeforeAny:
273313
c.beforeInsert = append(c.beforeInsert, callbackHandler{
274314
name: name,
@@ -290,6 +330,10 @@ func (c *Callback) Register(opType operation.OpType, name string, fn CbFn) {
290330
name: name,
291331
fn: fn,
292332
})
333+
c.beforeAggregate = append(c.beforeAggregate, callbackHandler{
334+
name: name,
335+
fn: fn,
336+
})
293337
case operation.OpTypeAfterAny:
294338
c.afterInsert = append(c.afterInsert, callbackHandler{
295339
name: name,
@@ -311,6 +355,10 @@ func (c *Callback) Register(opType operation.OpType, name string, fn CbFn) {
311355
name: name,
312356
fn: fn,
313357
})
358+
c.afterAggregate = append(c.afterAggregate, callbackHandler{
359+
name: name,
360+
fn: fn,
361+
})
314362
}
315363
}
316364

@@ -336,18 +384,24 @@ func (c *Callback) Remove(opType operation.OpType, name string) {
336384
c.beforeFind = c.remove(c.beforeFind, name)
337385
case operation.OpTypeAfterFind:
338386
c.afterFind = c.remove(c.afterFind, name)
387+
case operation.OpTypeBeforeAggregate:
388+
c.beforeAggregate = c.remove(c.beforeAggregate, name)
389+
case operation.OpTypeAfterAggregate:
390+
c.afterAggregate = c.remove(c.afterAggregate, name)
339391
case operation.OpTypeBeforeAny:
340392
c.beforeInsert = c.remove(c.beforeInsert, name)
341393
c.beforeUpdate = c.remove(c.beforeUpdate, name)
342394
c.beforeDelete = c.remove(c.beforeDelete, name)
343395
c.beforeUpsert = c.remove(c.beforeUpsert, name)
344396
c.beforeFind = c.remove(c.beforeFind, name)
397+
c.beforeAggregate = c.remove(c.beforeAggregate, name)
345398
case operation.OpTypeAfterAny:
346399
c.afterInsert = c.remove(c.afterInsert, name)
347400
c.afterUpdate = c.remove(c.afterUpdate, name)
348401
c.afterDelete = c.remove(c.afterDelete, name)
349402
c.afterUpsert = c.remove(c.afterUpsert, name)
350403
c.afterFind = c.remove(c.afterFind, name)
404+
c.afterAggregate = c.remove(c.afterAggregate, name)
351405
}
352406
}
353407

database_e2e_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,3 +380,44 @@ func TestRegisterPlugin_BeforeAny(t *testing.T) {
380380
assert.False(t, isCalled)
381381
})
382382
}
383+
384+
func TestRegisterPlugin_Aggregate(t *testing.T) {
385+
c := getMongoClient(t)
386+
387+
// Track which hooks are called
388+
var calledHooks []string
389+
390+
t.Run("aggregation should use correct OpTypes", func(t *testing.T) {
391+
db := newDatabase(NewClient(c, &Config{}), "db-test")
392+
393+
// Register hooks for different operations
394+
db.RegisterPlugin("before-insert", func(ctx context.Context, opCtx *operation.OpContext, opts ...any) error {
395+
calledHooks = append(calledHooks, "before-insert")
396+
return nil
397+
}, operation.OpTypeBeforeInsert)
398+
399+
db.RegisterPlugin("before-aggregate", func(ctx context.Context, opCtx *operation.OpContext, opts ...any) error {
400+
calledHooks = append(calledHooks, "before-aggregate")
401+
return nil
402+
}, operation.OpTypeBeforeAggregate)
403+
404+
db.RegisterPlugin("after-aggregate", func(ctx context.Context, opCtx *operation.OpContext, opts ...any) error {
405+
calledHooks = append(calledHooks, "after-aggregate")
406+
return nil
407+
}, operation.OpTypeAfterAggregate)
408+
409+
// Test that aggregation hooks are called correctly
410+
err := db.callbacks.Execute(context.Background(), operation.NewOpContext(nil, operation.WithPipeline(bson.A{bson.M{"$match": bson.M{}}})), operation.OpTypeBeforeAggregate)
411+
require.Nil(t, err)
412+
413+
err = db.callbacks.Execute(context.Background(), operation.NewOpContext(nil, operation.WithPipeline(bson.A{bson.M{"$match": bson.M{}}})), operation.OpTypeAfterAggregate)
414+
require.Nil(t, err)
415+
416+
// Verify correct hooks were called
417+
assert.Contains(t, calledHooks, "before-aggregate")
418+
assert.Contains(t, calledHooks, "after-aggregate")
419+
420+
// Verify insert hooks were NOT called during aggregation operations
421+
assert.NotContains(t, calledHooks, "before-insert")
422+
})
423+
}

operation/operation_type.go

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,20 @@ import (
3030
type OpType string
3131

3232
const (
33-
OpTypeBeforeInsert OpType = "beforeInsert"
34-
OpTypeAfterInsert OpType = "afterInsert"
35-
OpTypeBeforeUpdate OpType = "beforeUpdate"
36-
OpTypeAfterUpdate OpType = "afterUpdate"
37-
OpTypeBeforeDelete OpType = "beforeDelete"
38-
OpTypeAfterDelete OpType = "afterDelete"
39-
OpTypeBeforeUpsert OpType = "beforeUpsert"
40-
OpTypeAfterUpsert OpType = "afterUpsert"
41-
OpTypeBeforeFind OpType = "beforeFind"
42-
OpTypeAfterFind OpType = "afterFind"
43-
OpTypeBeforeAny OpType = "before*"
44-
OpTypeAfterAny OpType = "after*"
33+
OpTypeBeforeInsert OpType = "beforeInsert"
34+
OpTypeAfterInsert OpType = "afterInsert"
35+
OpTypeBeforeUpdate OpType = "beforeUpdate"
36+
OpTypeAfterUpdate OpType = "afterUpdate"
37+
OpTypeBeforeDelete OpType = "beforeDelete"
38+
OpTypeAfterDelete OpType = "afterDelete"
39+
OpTypeBeforeUpsert OpType = "beforeUpsert"
40+
OpTypeAfterUpsert OpType = "afterUpsert"
41+
OpTypeBeforeFind OpType = "beforeFind"
42+
OpTypeAfterFind OpType = "afterFind"
43+
OpTypeBeforeAggregate OpType = "beforeAggregate"
44+
OpTypeAfterAggregate OpType = "afterAggregate"
45+
OpTypeBeforeAny OpType = "before*"
46+
OpTypeAfterAny OpType = "after*"
4547
)
4648

4749
//go:generate optioner -type OpContext -output operation_type.go -mode append

0 commit comments

Comments
 (0)