Skip to content

Commit 0a85cf9

Browse files
refactor:
- aggregator: Support hooks. - callback: Enable field hook by default, support before* and after* hooks. - collection: Add fields field to Collection struct and pass it to finder, creator, updater, deleter, and aggregator objects. - field: Add field package to store metadata of the struct bound to the Collection. - hook: Refactor field hook, remove other hooks. - operation: Add before* and after* hook types. - model: Add mongox tag to ID field in the Model struct.
1 parent da90b0c commit 0a85cf9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+2214
-2752
lines changed

aggregator/aggregator.go

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@ package aggregator
1616

1717
import (
1818
"context"
19+
"time"
1920

21+
"github.com/chenmingyong0423/go-mongox/v2/operation"
22+
23+
"github.com/chenmingyong0423/go-mongox/v2/callback"
24+
"github.com/chenmingyong0423/go-mongox/v2/field"
2025
"go.mongodb.org/mongo-driver/v2/mongo"
2126
"go.mongodb.org/mongo-driver/v2/mongo/options"
2227
)
@@ -32,20 +37,52 @@ var _ IAggregator[any] = (*Aggregator[any])(nil)
3237
type Aggregator[T any] struct {
3338
collection *mongo.Collection
3439
pipeline any
40+
41+
dbCallbacks *callback.Callback
42+
fields []*field.Filed
43+
44+
modelHook any
45+
beforeHooks []beforeHookFn
46+
afterHooks []afterHookFn
3547
}
3648

37-
func NewAggregator[T any](collection *mongo.Collection) *Aggregator[T] {
49+
func NewAggregator[T any](collection *mongo.Collection, dbCallbacks *callback.Callback, fields []*field.Filed) *Aggregator[T] {
3850
return &Aggregator[T]{
39-
collection: collection,
51+
collection: collection,
52+
dbCallbacks: dbCallbacks,
53+
fields: fields,
4054
}
4155
}
4256

57+
func (a *Aggregator[T]) ModelHook(modelHook any) *Aggregator[T] {
58+
a.modelHook = modelHook
59+
return a
60+
}
61+
func (a *Aggregator[T]) RegisterBeforeHooks(hooks ...beforeHookFn) *Aggregator[T] {
62+
a.beforeHooks = append(a.beforeHooks, hooks...)
63+
return a
64+
}
65+
66+
func (a *Aggregator[T]) RegisterAfterHooks(hooks ...afterHookFn) *Aggregator[T] {
67+
a.afterHooks = append(a.afterHooks, hooks...)
68+
return a
69+
}
70+
4371
func (a *Aggregator[T]) Pipeline(pipeline any) *Aggregator[T] {
4472
a.pipeline = pipeline
4573
return a
4674
}
4775

4876
func (a *Aggregator[T]) Aggregate(ctx context.Context, opts ...options.Lister[options.AggregateOptions]) ([]*T, error) {
77+
currentTime := time.Now()
78+
globalOpContext := operation.NewOpContext(a.collection, operation.WithPipeline(a.pipeline), operation.WithMongoOptions(opts), operation.WithModelHook(a.modelHook), operation.WithStartTime(currentTime), operation.WithFields(a.fields))
79+
opContext := NewOpContext(a.collection, a.pipeline, WithMongoOptions(opts), WithModelHook(a.modelHook), WithStartTime(currentTime), WithFields(a.fields))
80+
81+
err := a.preActionHandler(ctx, globalOpContext, opContext, operation.OpTypeBeforeInsert)
82+
if err != nil {
83+
return nil, err
84+
}
85+
4986
cursor, err := a.collection.Aggregate(ctx, a.pipeline, opts...)
5087
if err != nil {
5188
return nil, err
@@ -57,12 +94,30 @@ func (a *Aggregator[T]) Aggregate(ctx context.Context, opts ...options.Lister[op
5794
if err != nil {
5895
return nil, err
5996
}
97+
98+
globalOpContext.Result = cursor
99+
opContext.Result = cursor
100+
err = a.postActionHandler(ctx, globalOpContext, opContext, operation.OpTypeAfterInsert)
101+
if err != nil {
102+
return nil, err
103+
}
104+
60105
return result, nil
61106
}
62107

63108
// AggregateWithParse is used to parse the result of the aggregation
64109
// result must be a pointer to a slice
65110
func (a *Aggregator[T]) AggregateWithParse(ctx context.Context, result any, opts ...options.Lister[options.AggregateOptions]) error {
111+
112+
currentTime := time.Now()
113+
globalOpContext := operation.NewOpContext(a.collection, operation.WithPipeline(a.pipeline), operation.WithMongoOptions(opts), operation.WithModelHook(a.modelHook), operation.WithStartTime(currentTime), operation.WithFields(a.fields))
114+
opContext := NewOpContext(a.collection, a.pipeline, WithMongoOptions(opts), WithModelHook(a.modelHook), WithStartTime(currentTime), WithFields(a.fields))
115+
116+
err := a.preActionHandler(ctx, globalOpContext, opContext, operation.OpTypeBeforeInsert)
117+
if err != nil {
118+
return err
119+
}
120+
66121
cursor, err := a.collection.Aggregate(ctx, a.pipeline, opts...)
67122
if err != nil {
68123
return err
@@ -72,5 +127,41 @@ func (a *Aggregator[T]) AggregateWithParse(ctx context.Context, result any, opts
72127
if err != nil {
73128
return err
74129
}
130+
131+
globalOpContext.Result = cursor
132+
opContext.Result = cursor
133+
err = a.postActionHandler(ctx, globalOpContext, opContext, operation.OpTypeAfterInsert)
134+
if err != nil {
135+
return err
136+
}
137+
138+
return nil
139+
}
140+
141+
func (a *Aggregator[T]) preActionHandler(ctx context.Context, globalOpContext *operation.OpContext, opContext *OpContext, opType operation.OpType) error {
142+
err := a.dbCallbacks.Execute(ctx, globalOpContext, opType)
143+
if err != nil {
144+
return err
145+
}
146+
for _, beforeHook := range a.beforeHooks {
147+
err = beforeHook(ctx, opContext)
148+
if err != nil {
149+
return err
150+
}
151+
}
152+
return nil
153+
}
154+
155+
func (a *Aggregator[T]) postActionHandler(ctx context.Context, globalOpContext *operation.OpContext, opContext *OpContext, opType operation.OpType) error {
156+
err := a.dbCallbacks.Execute(ctx, globalOpContext, opType)
157+
if err != nil {
158+
return err
159+
}
160+
for _, afterHook := range a.afterHooks {
161+
err = afterHook(ctx, opContext)
162+
if err != nil {
163+
return err
164+
}
165+
}
75166
return nil
76167
}

aggregator/aggregator_e2e_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ import (
2020
"context"
2121
"testing"
2222

23+
"github.com/chenmingyong0423/go-mongox/v2/callback"
24+
"github.com/chenmingyong0423/go-mongox/v2/field"
25+
2326
"github.com/chenmingyong0423/go-mongox/v2/bsonx"
2427
"github.com/stretchr/testify/require"
2528

@@ -47,14 +50,14 @@ func getCollection(t *testing.T) *mongo.Collection {
4750
func TestAggregator_e2e_New(t *testing.T) {
4851
collection := getCollection(t)
4952

50-
result := NewAggregator[TestUser](collection)
53+
result := NewAggregator[TestUser](collection, nil, nil)
5154
require.NotNil(t, result, "Expected non-nil Aggregator")
5255
require.Equal(t, collection, result.collection, "Expected collection field to be initialized correctly")
5356
}
5457

5558
func TestAggregator_e2e_Aggregation(t *testing.T) {
5659
collection := getCollection(t)
57-
aggregator := NewAggregator[TestUser](collection)
60+
aggregator := NewAggregator[TestUser](collection, callback.InitializeCallbacks(), field.ParseFields(TestUser{}))
5861

5962
testCases := []struct {
6063
name string
@@ -223,7 +226,7 @@ func TestAggregator_e2e_Aggregation(t *testing.T) {
223226

224227
func TestAggregator_e2e_AggregateWithParse(t *testing.T) {
225228
collection := getCollection(t)
226-
aggregator := NewAggregator[TestUser](collection)
229+
aggregator := NewAggregator[TestUser](collection, callback.InitializeCallbacks(), field.ParseFields(TestUser{}))
227230

228231
type User struct {
229232
Id string `bson:"_id"`

aggregator/aggregator_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ type UserName struct {
7171

7272
func TestAggregator_New(t *testing.T) {
7373
mongoCollection := &mongo.Collection{}
74-
aggregator := NewAggregator[any](mongoCollection)
74+
aggregator := NewAggregator[any](mongoCollection, nil, nil)
7575

7676
assert.NotNil(t, aggregator)
7777
assert.Equal(t, mongoCollection, aggregator.collection)

aggregator/types.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// Generated by [optioner] command-line tool; DO NOT EDIT
2+
// If you have any questions, please create issues and submit contributions at:
3+
// https://github.com/chenmingyong0423/go-optioner
4+
5+
// Copyright 2025 chenmingyong0423
6+
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
19+
package aggregator
20+
21+
import (
22+
"context"
23+
"time"
24+
25+
"github.com/chenmingyong0423/go-mongox/v2/field"
26+
"go.mongodb.org/mongo-driver/v2/mongo"
27+
)
28+
29+
//go:generate optioner -type OpContext -output types.go -mode append
30+
type OpContext struct {
31+
Col *mongo.Collection `opt:"-"`
32+
Pipeline any `opt:"-"`
33+
34+
Fields []*field.Filed
35+
36+
MongoOptions any
37+
ModelHook any
38+
StartTime time.Time
39+
40+
Result any
41+
}
42+
43+
type (
44+
beforeHookFn func(ctx context.Context, opContext *OpContext, opts ...any) error
45+
afterHookFn func(ctx context.Context, opContext *OpContext, opts ...any) error
46+
)
47+
48+
type OpContextOption func(*OpContext)
49+
50+
func NewOpContext(col *mongo.Collection, pipeline any, opts ...OpContextOption) *OpContext {
51+
opContext := &OpContext{
52+
Col: col,
53+
Pipeline: pipeline,
54+
}
55+
56+
for _, opt := range opts {
57+
opt(opContext)
58+
}
59+
60+
return opContext
61+
}
62+
63+
func WithFields(fields []*field.Filed) OpContextOption {
64+
return func(opContext *OpContext) {
65+
opContext.Fields = fields
66+
}
67+
}
68+
69+
func WithMongoOptions(mongoOptions any) OpContextOption {
70+
return func(opContext *OpContext) {
71+
opContext.MongoOptions = mongoOptions
72+
}
73+
}
74+
75+
func WithModelHook(modelHook any) OpContextOption {
76+
return func(opContext *OpContext) {
77+
opContext.ModelHook = modelHook
78+
}
79+
}
80+
81+
func WithStartTime(startTime time.Time) OpContextOption {
82+
return func(opContext *OpContext) {
83+
opContext.StartTime = startTime
84+
}
85+
}
86+
87+
func WithResult(result any) OpContextOption {
88+
return func(opContext *OpContext) {
89+
opContext.Result = result
90+
}
91+
}

callback/callback.go

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,46 @@ package callback
1717
import (
1818
"context"
1919

20+
"github.com/chenmingyong0423/go-mongox/v2/internal/hook/field"
21+
2022
"github.com/chenmingyong0423/go-mongox/v2/operation"
2123
)
2224

2325
type CbFn func(ctx context.Context, opCtx *operation.OpContext, opts ...any) error
2426

2527
func InitializeCallbacks() *Callback {
2628
return &Callback{
27-
beforeInsert: make([]callbackHandler, 0),
28-
afterInsert: make([]callbackHandler, 0),
29-
beforeUpdate: make([]callbackHandler, 0),
29+
beforeInsert: []callbackHandler{
30+
{
31+
name: "mongox:fieds",
32+
fn: func(ctx context.Context, opCtx *operation.OpContext, opts ...any) error {
33+
return field.Execute(ctx, opCtx, operation.OpTypeBeforeInsert, opts...)
34+
},
35+
},
36+
},
37+
afterInsert: make([]callbackHandler, 0),
38+
beforeUpdate: []callbackHandler{
39+
{
40+
name: "mongox:fieds",
41+
fn: func(ctx context.Context, opCtx *operation.OpContext, opts ...any) error {
42+
return field.Execute(ctx, opCtx, operation.OpTypeBeforeUpdate, opts...)
43+
},
44+
},
45+
},
3046
afterUpdate: make([]callbackHandler, 0),
3147
beforeDelete: make([]callbackHandler, 0),
3248
afterDelete: make([]callbackHandler, 0),
33-
beforeUpsert: make([]callbackHandler, 0),
34-
afterUpsert: make([]callbackHandler, 0),
35-
beforeFind: make([]callbackHandler, 0),
36-
afterFind: make([]callbackHandler, 0),
49+
beforeUpsert: []callbackHandler{
50+
{
51+
name: "mongox:fieds",
52+
fn: func(ctx context.Context, opCtx *operation.OpContext, opts ...any) error {
53+
return field.Execute(ctx, opCtx, operation.OpTypeBeforeUpsert, opts...)
54+
},
55+
},
56+
},
57+
afterUpsert: make([]callbackHandler, 0),
58+
beforeFind: make([]callbackHandler, 0),
59+
afterFind: make([]callbackHandler, 0),
3760
}
3861
}
3962

collection.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/chenmingyong0423/go-mongox/v2/callback"
2020
"github.com/chenmingyong0423/go-mongox/v2/creator"
2121
"github.com/chenmingyong0423/go-mongox/v2/deleter"
22+
"github.com/chenmingyong0423/go-mongox/v2/field"
2223
"github.com/chenmingyong0423/go-mongox/v2/finder"
2324
"github.com/chenmingyong0423/go-mongox/v2/updater"
2425
"go.mongodb.org/mongo-driver/v2/mongo"
@@ -27,8 +28,9 @@ import (
2728
func NewCollection[T any](db *Database, collection string) *Collection[T] {
2829
return &Collection[T]{
2930
db: db,
30-
collection: db.database().Collection(collection),
31+
collection: db.Database().Collection(collection),
3132
callbacks: db.callbacks,
33+
fields: field.ParseFields(new(T)),
3234
}
3335
}
3436

@@ -37,25 +39,27 @@ type Collection[T any] struct {
3739
collection *mongo.Collection
3840
// callbacks inherited from database
3941
callbacks *callback.Callback
42+
43+
fields []*field.Filed
4044
}
4145

4246
func (c *Collection[T]) Finder() *finder.Finder[T] {
43-
return finder.NewFinder[T](c.collection, c.callbacks)
47+
return finder.NewFinder[T](c.collection, c.callbacks, c.fields)
4448
}
4549

4650
func (c *Collection[T]) Creator() *creator.Creator[T] {
47-
return creator.NewCreator[T](c.collection, c.callbacks)
51+
return creator.NewCreator[T](c.collection, c.callbacks, c.fields)
4852
}
4953

5054
func (c *Collection[T]) Updater() *updater.Updater[T] {
51-
return updater.NewUpdater[T](c.collection, c.callbacks)
55+
return updater.NewUpdater[T](c.collection, c.callbacks, c.fields)
5256
}
5357

5458
func (c *Collection[T]) Deleter() *deleter.Deleter[T] {
55-
return deleter.NewDeleter[T](c.collection, c.callbacks)
59+
return deleter.NewDeleter[T](c.collection, c.callbacks, c.fields)
5660
}
5761
func (c *Collection[T]) Aggregator() *aggregator.Aggregator[T] {
58-
return aggregator.NewAggregator[T](c.collection)
62+
return aggregator.NewAggregator[T](c.collection, c.callbacks, c.fields)
5963
}
6064

6165
func (c *Collection[T]) Collection() *mongo.Collection {

0 commit comments

Comments
 (0)