Skip to content
33 changes: 20 additions & 13 deletions creator/creator.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ import (
type ICreator[T any] interface {
InsertOne(ctx context.Context, docs *T, opts ...options.Lister[options.InsertOneOptions]) (*mongo.InsertOneResult, error)
InsertMany(ctx context.Context, docs []*T, opts ...options.Lister[options.InsertManyOptions]) (*mongo.InsertManyResult, error)
ModelHook(modelHook any) ICreator[T]
RegisterAfterHooks(hooks ...HookFn[T]) ICreator[T]
RegisterBeforeHooks(hooks ...HookFn[T]) ICreator[T]
GetCollection() *mongo.Collection
}

var _ ICreator[any] = (*Creator[any])(nil)
Expand All @@ -44,45 +48,45 @@ type Creator[T any] struct {

modelHook any

dbCallbacks *callback.Callback
beforeHooks []hookFn[T]
afterHooks []hookFn[T]
DBCallbacks *callback.Callback
BeforeHooks []HookFn[T]
AfterHooks []HookFn[T]

fields []*field.Filed
}

func NewCreator[T any](collection *mongo.Collection, dbCallbacks *callback.Callback, fields []*field.Filed) *Creator[T] {
return &Creator[T]{
collection: collection,
dbCallbacks: dbCallbacks,
DBCallbacks: dbCallbacks,
fields: fields,
}
}

func (c *Creator[T]) ModelHook(modelHook any) *Creator[T] {
func (c *Creator[T]) ModelHook(modelHook any) ICreator[T] {
c.modelHook = modelHook
return c
}

// RegisterBeforeHooks is used to set the after hooks of the insert operation
// If you register the hook for InsertOne, the opContext.Docs will be nil
// If you register the hook for InsertMany, the opContext.Doc will be nil
func (c *Creator[T]) RegisterBeforeHooks(hooks ...hookFn[T]) *Creator[T] {
c.beforeHooks = append(c.beforeHooks, hooks...)
func (c *Creator[T]) RegisterBeforeHooks(hooks ...HookFn[T]) ICreator[T] {
c.BeforeHooks = append(c.BeforeHooks, hooks...)
return c
}

func (c *Creator[T]) RegisterAfterHooks(hooks ...hookFn[T]) *Creator[T] {
c.afterHooks = append(c.afterHooks, hooks...)
func (c *Creator[T]) RegisterAfterHooks(hooks ...HookFn[T]) ICreator[T] {
c.AfterHooks = append(c.AfterHooks, hooks...)
return c
}

func (c *Creator[T]) preActionHandler(ctx context.Context, globalOpContext *operation.OpContext, opContext *OpContext[T], opType operation.OpType) error {
err := c.dbCallbacks.Execute(ctx, globalOpContext, opType)
err := c.DBCallbacks.Execute(ctx, globalOpContext, opType)
if err != nil {
return err
}
for _, beforeHook := range c.beforeHooks {
for _, beforeHook := range c.BeforeHooks {
err = beforeHook(ctx, opContext)
if err != nil {
return err
Expand All @@ -92,11 +96,11 @@ func (c *Creator[T]) preActionHandler(ctx context.Context, globalOpContext *oper
}

func (c *Creator[T]) postActionHandler(ctx context.Context, globalOpContext *operation.OpContext, opContext *OpContext[T], opType operation.OpType) error {
err := c.dbCallbacks.Execute(ctx, globalOpContext, opType)
err := c.DBCallbacks.Execute(ctx, globalOpContext, opType)
if err != nil {
return err
}
for _, afterHook := range c.afterHooks {
for _, afterHook := range c.AfterHooks {
err = afterHook(ctx, opContext)
if err != nil {
return err
Expand Down Expand Up @@ -157,3 +161,6 @@ func (c *Creator[T]) InsertMany(ctx context.Context, docs []*T, opts ...options.
}
return result, nil
}
func (c *Creator[T]) GetCollection() *mongo.Collection {
return c.collection
}
71 changes: 36 additions & 35 deletions creator/creator_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@

//go:build e2e

package creator
package creator_test

import (
"context"
"errors"
"testing"
"time"

xcreator "github.com/chenmingyong0423/go-mongox/v2/creator"
"github.com/chenmingyong0423/go-mongox/v2/field"

"github.com/chenmingyong0423/go-mongox/v2/callback"
Expand Down Expand Up @@ -70,7 +71,7 @@ func newCollection(t *testing.T) *mongo.Collection {

func TestCreator_e2e_One(t *testing.T) {
collection := newCollection(t)
creator := NewCreator[User](collection, callback.InitializeCallbacks(), field.ParseFields(User{}))
creator := xcreator.NewCreator[User](collection, callback.InitializeCallbacks(), field.ParseFields(User{}))

type globalHook struct {
opType operation.OpType
Expand All @@ -86,8 +87,8 @@ func TestCreator_e2e_One(t *testing.T) {
ctx context.Context
doc *User
globalHook []globalHook
beforeHook []hookFn[User]
afterHook []hookFn[User]
beforeHook []xcreator.HookFn[User]
afterHook []xcreator.HookFn[User]

wantError assert.ErrorAssertionFunc
}{
Expand Down Expand Up @@ -226,8 +227,8 @@ func TestCreator_e2e_One(t *testing.T) {
options.InsertOne().SetComment("test"),
},
doc: nil,
beforeHook: []hookFn[User]{
func(ctx context.Context, opContext *OpContext[User], opts ...any) error {
beforeHook: []xcreator.HookFn[User]{
func(ctx context.Context, opContext *xcreator.OpContext[User], opts ...any) error {
return errors.New("before hook error")
},
},
Expand All @@ -251,8 +252,8 @@ func TestCreator_e2e_One(t *testing.T) {
Name: "Mingyong Chen",
Age: 18,
},
afterHook: []hookFn[User]{
func(ctx context.Context, opContext *OpContext[User], opts ...any) error {
afterHook: []xcreator.HookFn[User]{
func(ctx context.Context, opContext *xcreator.OpContext[User], opts ...any) error {
return errors.New("after hook error")
},
},
Expand All @@ -276,16 +277,16 @@ func TestCreator_e2e_One(t *testing.T) {
Name: "Mingyong Chen",
Age: 18,
},
beforeHook: []hookFn[User]{
func(ctx context.Context, opContext *OpContext[User], opts ...any) error {
beforeHook: []xcreator.HookFn[User]{
func(ctx context.Context, opContext *xcreator.OpContext[User], opts ...any) error {
if opContext.Doc == nil {
return errors.New("before hook error")
}
return nil
},
},
afterHook: []hookFn[User]{
func(ctx context.Context, opContext *OpContext[User], opts ...any) error {
afterHook: []xcreator.HookFn[User]{
func(ctx context.Context, opContext *xcreator.OpContext[User], opts ...any) error {
if opContext == nil {
return errors.New("after hook error")
}
Expand Down Expand Up @@ -322,8 +323,8 @@ func TestCreator_e2e_One(t *testing.T) {
Name: "Mingyong Chen",
Age: 18,
},
afterHook: []hookFn[User]{
func(ctx context.Context, opContext *OpContext[User], opts ...any) error {
afterHook: []xcreator.HookFn[User]{
func(ctx context.Context, opContext *xcreator.OpContext[User], opts ...any) error {
user := opContext.Doc
if user == nil {
return errors.New("user is nil")
Expand Down Expand Up @@ -372,7 +373,7 @@ func TestCreator_e2e_One(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
tc.before(tc.ctx, t)
for _, hook := range tc.globalHook {
creator.dbCallbacks.Register(hook.opType, hook.name, hook.fn)
creator.DBCallbacks.Register(hook.opType, hook.name, hook.fn)
}
insertOneResult, err := creator.RegisterBeforeHooks(tc.beforeHook...).
RegisterAfterHooks(tc.afterHook...).InsertOne(tc.ctx, tc.doc, tc.opts...)
Expand All @@ -384,17 +385,17 @@ func TestCreator_e2e_One(t *testing.T) {
require.NotNil(t, insertOneResult.InsertedID)
}
for _, hook := range tc.globalHook {
creator.dbCallbacks.Remove(hook.opType, hook.name)
creator.DBCallbacks.Remove(hook.opType, hook.name)
}
creator.beforeHooks = nil
creator.afterHooks = nil
creator.BeforeHooks = nil
creator.AfterHooks = nil
})
}
}

func TestCreator_e2e_Many(t *testing.T) {
collection := newCollection(t)
creator := NewCreator[User](collection, callback.InitializeCallbacks(), field.ParseFields(User{}))
creator := xcreator.NewCreator[User](collection, callback.InitializeCallbacks(), field.ParseFields(User{}))

type globalHook struct {
opType operation.OpType
Expand All @@ -411,8 +412,8 @@ func TestCreator_e2e_Many(t *testing.T) {
docs []*User
opts []options.Lister[options.InsertManyOptions]
globalHook []globalHook
beforeHook []hookFn[User]
afterHook []hookFn[User]
beforeHook []xcreator.HookFn[User]
afterHook []xcreator.HookFn[User]

wantIdsLength int
wantError assert.ErrorAssertionFunc
Expand Down Expand Up @@ -566,8 +567,8 @@ func TestCreator_e2e_Many(t *testing.T) {
options.InsertMany().SetComment("test"),
},
docs: nil,
beforeHook: []hookFn[User]{
func(ctx context.Context, opContext *OpContext[User], opts ...any) error {
beforeHook: []xcreator.HookFn[User]{
func(ctx context.Context, opContext *xcreator.OpContext[User], opts ...any) error {
return errors.New("before hook error")
},
},
Expand Down Expand Up @@ -597,8 +598,8 @@ func TestCreator_e2e_Many(t *testing.T) {
Age: 19,
},
},
afterHook: []hookFn[User]{
func(ctx context.Context, opContext *OpContext[User], opts ...any) error {
afterHook: []xcreator.HookFn[User]{
func(ctx context.Context, opContext *xcreator.OpContext[User], opts ...any) error {
return errors.New("after hook error")
},
},
Expand Down Expand Up @@ -628,16 +629,16 @@ func TestCreator_e2e_Many(t *testing.T) {
Age: 19,
},
},
beforeHook: []hookFn[User]{
func(ctx context.Context, opContext *OpContext[User], opts ...any) error {
beforeHook: []xcreator.HookFn[User]{
func(ctx context.Context, opContext *xcreator.OpContext[User], opts ...any) error {
if len(opContext.Docs) != 2 {
return errors.New("before hook error")
}
return nil
},
},
afterHook: []hookFn[User]{
func(ctx context.Context, opContext *OpContext[User], opts ...any) error {
afterHook: []xcreator.HookFn[User]{
func(ctx context.Context, opContext *xcreator.OpContext[User], opts ...any) error {
if opContext == nil {
return errors.New("after hook error")
}
Expand Down Expand Up @@ -683,8 +684,8 @@ func TestCreator_e2e_Many(t *testing.T) {
Age: 18,
},
},
afterHook: []hookFn[User]{
func(ctx context.Context, opContext *OpContext[User], opts ...any) error {
afterHook: []xcreator.HookFn[User]{
func(ctx context.Context, opContext *xcreator.OpContext[User], opts ...any) error {
users := opContext.Docs
if users == nil {
return errors.New("users is nil")
Expand Down Expand Up @@ -736,7 +737,7 @@ func TestCreator_e2e_Many(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
tc.before(tc.ctx, t)
for _, hook := range tc.globalHook {
creator.dbCallbacks.Register(hook.opType, hook.name, hook.fn)
creator.DBCallbacks.Register(hook.opType, hook.name, hook.fn)
}
insertManyResult, err := creator.RegisterBeforeHooks(tc.beforeHook...).
RegisterAfterHooks(tc.afterHook...).InsertMany(tc.ctx, tc.docs, tc.opts...)
Expand All @@ -749,10 +750,10 @@ func TestCreator_e2e_Many(t *testing.T) {
require.Len(t, insertManyResult.InsertedIDs, tc.wantIdsLength)
}
for _, hook := range tc.globalHook {
creator.dbCallbacks.Remove(hook.opType, hook.name)
creator.DBCallbacks.Remove(hook.opType, hook.name)
}
creator.beforeHooks = nil
creator.afterHooks = nil
creator.BeforeHooks = nil
creator.AfterHooks = nil
})
}
}
19 changes: 10 additions & 9 deletions creator/creator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package creator
package creator_test

import (
"context"
"errors"
"testing"
"time"

creator "github.com/chenmingyong0423/go-mongox/v2/creator"
mocks "github.com/chenmingyong0423/go-mongox/v2/mock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -49,24 +50,24 @@ func (tu *TestUser) DefaultUpdatedAt() {

func TestNewCreator(t *testing.T) {
mongoCollection := &mongo.Collection{}
creator := NewCreator[any](mongoCollection, nil, nil)
creator := creator.NewCreator[any](mongoCollection, nil, nil)

assert.NotNil(t, creator)
assert.Equal(t, mongoCollection, creator.collection)
assert.Equal(t, mongoCollection, creator.GetCollection())
}

func TestCreator_One(t *testing.T) {
testCases := []struct {
name string
mock func(ctx context.Context, ctl *gomock.Controller, doc *TestUser) ICreator[TestUser]
mock func(ctx context.Context, ctl *gomock.Controller, doc *TestUser) creator.ICreator[TestUser]
ctx context.Context
doc *TestUser

wantErr error
}{
{
name: "nil doc",
mock: func(ctx context.Context, ctl *gomock.Controller, doc *TestUser) ICreator[TestUser] {
mock: func(ctx context.Context, ctl *gomock.Controller, doc *TestUser) creator.ICreator[TestUser] {
mockCollection := mocks.NewMockICreator[TestUser](ctl)
mockCollection.EXPECT().InsertOne(ctx, doc).Return(nil, errors.New("nil filter")).Times(1)
return mockCollection
Expand All @@ -77,7 +78,7 @@ func TestCreator_One(t *testing.T) {
},
{
name: "success",
mock: func(ctx context.Context, ctl *gomock.Controller, doc *TestUser) ICreator[TestUser] {
mock: func(ctx context.Context, ctl *gomock.Controller, doc *TestUser) creator.ICreator[TestUser] {
mockCollection := mocks.NewMockICreator[TestUser](ctl)
mockCollection.EXPECT().InsertOne(ctx, doc).Return(&mongo.InsertOneResult{InsertedID: "?"}, nil).Times(1)
return mockCollection
Expand Down Expand Up @@ -107,7 +108,7 @@ func TestCreator_One(t *testing.T) {
func TestCreator_Many(t *testing.T) {
testCases := []struct {
name string
mock func(ctx context.Context, ctl *gomock.Controller, docs []*TestUser) ICreator[TestUser]
mock func(ctx context.Context, ctl *gomock.Controller, docs []*TestUser) creator.ICreator[TestUser]
ctx context.Context
docs []*TestUser

Expand All @@ -116,7 +117,7 @@ func TestCreator_Many(t *testing.T) {
}{
{
name: "nil docs",
mock: func(ctx context.Context, ctl *gomock.Controller, docs []*TestUser) ICreator[TestUser] {
mock: func(ctx context.Context, ctl *gomock.Controller, docs []*TestUser) creator.ICreator[TestUser] {
mockCollection := mocks.NewMockICreator[TestUser](ctl)
mockCollection.EXPECT().InsertMany(ctx, docs).Return(nil, errors.New("nil docs")).Times(1)
return mockCollection
Expand All @@ -130,7 +131,7 @@ func TestCreator_Many(t *testing.T) {
},
{
name: "success",
mock: func(ctx context.Context, ctl *gomock.Controller, docs []*TestUser) ICreator[TestUser] {
mock: func(ctx context.Context, ctl *gomock.Controller, docs []*TestUser) creator.ICreator[TestUser] {
mockCollection := mocks.NewMockICreator[TestUser](ctl)
mockCollection.EXPECT().InsertMany(ctx, docs).Return(&mongo.InsertManyResult{InsertedIDs: make([]interface{}, 2)}, nil).Times(1)
return mockCollection
Expand Down
2 changes: 1 addition & 1 deletion creator/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ type (

Result any
}
hookFn[T any] func(ctx context.Context, opContext *OpContext[T], opts ...any) error
HookFn[T any] func(ctx context.Context, opContext *OpContext[T], opts ...any) error
)

type OpContextOption[T any] func(*OpContext[T])
Expand Down
Loading
Loading