@@ -42,6 +42,8 @@ type Updater[T any] struct {
4242 filter any
4343 updates any
4444 replacement any
45+ beforeHooks []beforeHookFn
46+ afterHooks []afterHookFn
4547}
4648
4749// Filter is used to set the filter of the query
@@ -66,9 +68,47 @@ func (u *Updater[T]) UpdatesWithOperator(operator string, value any) *Updater[T]
6668 return u
6769}
6870
71+ func (u * Updater [T ]) RegisterBeforeHooks (hooks ... beforeHookFn ) * Updater [T ] {
72+ u .beforeHooks = append (u .beforeHooks , hooks ... )
73+ return u
74+ }
75+
76+ func (u * Updater [T ]) RegisterAfterHooks (hooks ... afterHookFn ) * Updater [T ] {
77+ u .afterHooks = append (u .afterHooks , hooks ... )
78+ return u
79+ }
80+
81+ func (u * Updater [T ]) preActionHandler (ctx context.Context , globalOpContext * operation.OpContext , opContext * BeforeOpContext , opType operation.OpType ) error {
82+ err := callback .GetCallback ().Execute (ctx , globalOpContext , opType )
83+ if err != nil {
84+ return err
85+ }
86+ for _ , beforeHook := range u .beforeHooks {
87+ err = beforeHook (ctx , opContext )
88+ if err != nil {
89+ return err
90+ }
91+ }
92+ return nil
93+ }
94+
95+ func (u * Updater [T ]) postActionHandler (ctx context.Context , globalOpContext * operation.OpContext , opContext * AfterOpContext , opType operation.OpType ) error {
96+ err := callback .GetCallback ().Execute (ctx , globalOpContext , opType )
97+ if err != nil {
98+ return err
99+ }
100+ for _ , afterHook := range u .afterHooks {
101+ err = afterHook (ctx , opContext )
102+ if err != nil {
103+ return err
104+ }
105+ }
106+ return nil
107+ }
108+
69109func (u * Updater [T ]) UpdateOne (ctx context.Context , opts ... * options.UpdateOptions ) (* mongo.UpdateResult , error ) {
70- opContext := operation .NewOpContext (u .collection , operation .WithFilter (u .filter ), operation .WithUpdate (u .updates ))
71- err := callback . GetCallback (). Execute ( ctx , opContext , operation .OpTypeBeforeUpdate )
110+ globalOpContext := operation .NewOpContext (u .collection , operation .WithFilter (u .filter ), operation .WithUpdate (u .updates ))
111+ err := u . preActionHandler ( ctx , globalOpContext , NewBeforeOpContext ( u . collection , NewCondContext ( u . filter , WithUpdates ( u . updates ))) , operation .OpTypeBeforeUpdate )
72112 if err != nil {
73113 return nil , err
74114 }
@@ -78,16 +118,16 @@ func (u *Updater[T]) UpdateOne(ctx context.Context, opts ...*options.UpdateOptio
78118 return nil , err
79119 }
80120
81- err = callback . GetCallback (). Execute ( ctx , opContext , operation .OpTypeAfterUpdate )
121+ err = u . postActionHandler ( ctx , globalOpContext , NewAfterOpContext ( u . collection , NewCondContext ( u . filter , WithUpdates ( u . updates ))) , operation .OpTypeAfterUpdate )
82122 if err != nil {
83123 return nil , err
84124 }
85125 return result , nil
86126}
87127
88128func (u * Updater [T ]) UpdateMany (ctx context.Context , opts ... * options.UpdateOptions ) (* mongo.UpdateResult , error ) {
89- opContext := operation .NewOpContext (u .collection , operation .WithFilter (u .filter ), operation .WithUpdate (u .updates ))
90- err := callback . GetCallback (). Execute ( ctx , opContext , operation .OpTypeBeforeUpdate )
129+ globalOpContext := operation .NewOpContext (u .collection , operation .WithFilter (u .filter ), operation .WithUpdate (u .updates ))
130+ err := u . preActionHandler ( ctx , globalOpContext , NewBeforeOpContext ( u . collection , NewCondContext ( u . filter , WithUpdates ( u . updates ))) , operation .OpTypeBeforeUpdate )
91131 if err != nil {
92132 return nil , err
93133 }
@@ -97,7 +137,7 @@ func (u *Updater[T]) UpdateMany(ctx context.Context, opts ...*options.UpdateOpti
97137 return nil , err
98138 }
99139
100- err = callback . GetCallback (). Execute ( ctx , opContext , operation .OpTypeAfterUpdate )
140+ err = u . postActionHandler ( ctx , globalOpContext , NewAfterOpContext ( u . collection , NewCondContext ( u . filter , WithUpdates ( u . updates ))) , operation .OpTypeAfterUpdate )
101141 if err != nil {
102142 return nil , err
103143 }
@@ -111,9 +151,9 @@ func (u *Updater[T]) Upsert(ctx context.Context, opts ...*options.ReplaceOptions
111151 opts [0 ].SetUpsert (true )
112152 }
113153
114- opContext := operation .NewOpContext (u .collection , operation .WithFilter (u .filter ), operation .WithReplacement (u .replacement ))
154+ globalOpContext := operation .NewOpContext (u .collection , operation .WithFilter (u .filter ), operation .WithReplacement (u .replacement ))
115155
116- err := callback . GetCallback (). Execute ( ctx , opContext , operation .OpTypeBeforeUpsert )
156+ err := u . preActionHandler ( ctx , globalOpContext , NewBeforeOpContext ( u . collection , NewCondContext ( u . filter , WithReplacement ( u . replacement ))) , operation .OpTypeBeforeUpsert )
117157 if err != nil {
118158 return nil , err
119159 }
@@ -123,7 +163,7 @@ func (u *Updater[T]) Upsert(ctx context.Context, opts ...*options.ReplaceOptions
123163 return nil , err
124164 }
125165
126- err = callback . GetCallback (). Execute ( ctx , opContext , operation .OpTypeAfterUpsert )
166+ err = u . postActionHandler ( ctx , globalOpContext , NewAfterOpContext ( u . collection , NewCondContext ( u . filter , WithReplacement ( u . replacement ))) , operation .OpTypeAfterUpsert )
127167 if err != nil {
128168 return nil , err
129169 }
0 commit comments