@@ -21,28 +21,8 @@ type engineContextKeyType struct{}
2121
2222var engineContextKey = engineContextKeyType {}
2323
24- type xormContextType struct {
25- context.Context
26- engine Engine
27- }
28-
29- var xormContext * xormContextType
30-
31- func newContext (ctx context.Context , e Engine ) * xormContextType {
32- return & xormContextType {Context : ctx , engine : e }
33- }
34-
35- // Value shadows Value for context.Context but allows us to get ourselves and an Engined object
36- func (ctx * xormContextType ) Value (key any ) any {
37- if key == engineContextKey {
38- return ctx
39- }
40- return ctx .Context .Value (key )
41- }
42-
43- // WithContext returns this engine tied to this context
44- func (ctx * xormContextType ) WithContext (other context.Context ) * xormContextType {
45- return newContext (ctx , ctx .engine .Context (other ))
24+ func withContextEngine (ctx context.Context , e Engine ) context.Context {
25+ return context .WithValue (ctx , engineContextKey , e )
4626}
4727
4828var (
@@ -89,8 +69,8 @@ func contextSafetyCheck(e Engine) {
8969// GetEngine gets an existing db Engine/Statement or creates a new Session
9070func GetEngine (ctx context.Context ) (e Engine ) {
9171 defer func () { contextSafetyCheck (e ) }()
92- if e := getExistingEngine ( ctx ); e != nil {
93- return e
72+ if engine , ok := ctx . Value ( engineContextKey ).( Engine ); ok {
73+ return engine
9474 }
9575 return xormEngine .Context (ctx )
9676}
@@ -99,17 +79,6 @@ func GetXORMEngineForTesting() *xorm.Engine {
9979 return xormEngine
10080}
10181
102- // getExistingEngine gets an existing db Engine/Statement from this context or returns nil
103- func getExistingEngine (ctx context.Context ) (e Engine ) {
104- if engined , ok := ctx .(* xormContextType ); ok {
105- return engined .engine
106- }
107- if engined , ok := ctx .Value (engineContextKey ).(* xormContextType ); ok {
108- return engined .engine
109- }
110- return nil
111- }
112-
11382// Committer represents an interface to Commit or Close the Context
11483type Committer interface {
11584 Commit () error
@@ -152,24 +121,23 @@ func (c *halfCommitter) Close() error {
152121// And all operations submitted by the caller stack will be rollbacked as well, not only the operations in the current function.
153122// d. It doesn't mean rollback is forbidden, but always do it only when there is an error, and you do want to rollback.
154123func TxContext (parentCtx context.Context ) (context.Context , Committer , error ) {
155- if sess , ok := inTransaction (parentCtx ); ok {
156- return newContext (parentCtx , sess ), & halfCommitter {committer : sess }, nil
124+ if sess := getTransactionSession (parentCtx ); sess != nil {
125+ return withContextEngine (parentCtx , sess ), & halfCommitter {committer : sess }, nil
157126 }
158127
159128 sess := xormEngine .NewSession ()
160129 if err := sess .Begin (); err != nil {
161130 _ = sess .Close ()
162131 return nil , nil , err
163132 }
164-
165- return newContext (xormContext , sess ), sess , nil
133+ return withContextEngine (parentCtx , sess ), sess , nil
166134}
167135
168136// WithTx represents executing database operations on a transaction, if the transaction exist,
169137// this function will reuse it otherwise will create a new one and close it when finished.
170138func WithTx (parentCtx context.Context , f func (ctx context.Context ) error ) error {
171- if sess , ok := inTransaction (parentCtx ); ok {
172- err := f (newContext (parentCtx , sess ))
139+ if sess := getTransactionSession (parentCtx ); sess != nil {
140+ err := f (withContextEngine (parentCtx , sess ))
173141 if err != nil {
174142 // rollback immediately, in case the caller ignores returned error and tries to commit the transaction.
175143 _ = sess .Close ()
@@ -195,7 +163,7 @@ func txWithNoCheck(parentCtx context.Context, f func(ctx context.Context) error)
195163 return err
196164 }
197165
198- if err := f (newContext (parentCtx , sess )); err != nil {
166+ if err := f (withContextEngine (parentCtx , sess )); err != nil {
199167 return err
200168 }
201169
@@ -340,25 +308,13 @@ func TableName(bean any) string {
340308
341309// InTransaction returns true if the engine is in a transaction otherwise return false
342310func InTransaction (ctx context.Context ) bool {
343- _ , ok := inTransaction (ctx )
344- return ok
311+ return getTransactionSession (ctx ) != nil
345312}
346313
347- func inTransaction (ctx context.Context ) (* xorm.Session , bool ) {
348- e := getExistingEngine (ctx )
349- if e == nil {
350- return nil , false
351- }
352-
353- switch t := e .(type ) {
354- case * xorm.Engine :
355- return nil , false
356- case * xorm.Session :
357- if t .IsInTx () {
358- return t , true
359- }
360- return nil , false
361- default :
362- return nil , false
314+ func getTransactionSession (ctx context.Context ) * xorm.Session {
315+ e , _ := ctx .Value (engineContextKey ).(Engine )
316+ if sess , ok := e .(* xorm.Session ); ok && sess .IsInTx () {
317+ return sess
363318 }
319+ return nil
364320}
0 commit comments