@@ -21,28 +21,8 @@ type engineContextKeyType struct{}
21
21
22
22
var engineContextKey = engineContextKeyType {}
23
23
24
- type xormContextType struct {
25
- context.Context
26
- engine Engine
27
- }
28
-
29
- var xormContext * xormContextType
30
-
31
- func newContext (ctx context.Context , e Engine ) * xormContextType {
32
- return & xormContextType {Context : ctx , engine : e }
33
- }
34
-
35
- // Value shadows Value for context.Context but allows us to get ourselves and an Engined object
36
- func (ctx * xormContextType ) Value (key any ) any {
37
- if key == engineContextKey {
38
- return ctx
39
- }
40
- return ctx .Context .Value (key )
41
- }
42
-
43
- // WithContext returns this engine tied to this context
44
- func (ctx * xormContextType ) WithContext (other context.Context ) * xormContextType {
45
- return newContext (ctx , ctx .engine .Context (other ))
24
+ func withContextEngine (ctx context.Context , e Engine ) context.Context {
25
+ return context .WithValue (ctx , engineContextKey , e )
46
26
}
47
27
48
28
var (
@@ -89,8 +69,8 @@ func contextSafetyCheck(e Engine) {
89
69
// GetEngine gets an existing db Engine/Statement or creates a new Session
90
70
func GetEngine (ctx context.Context ) (e Engine ) {
91
71
defer func () { contextSafetyCheck (e ) }()
92
- if e := getExistingEngine ( ctx ); e != nil {
93
- return e
72
+ if engine , ok := ctx . Value ( engineContextKey ).( Engine ); ok {
73
+ return engine
94
74
}
95
75
return xormEngine .Context (ctx )
96
76
}
@@ -99,17 +79,6 @@ func GetXORMEngineForTesting() *xorm.Engine {
99
79
return xormEngine
100
80
}
101
81
102
- // getExistingEngine gets an existing db Engine/Statement from this context or returns nil
103
- func getExistingEngine (ctx context.Context ) (e Engine ) {
104
- if engined , ok := ctx .(* xormContextType ); ok {
105
- return engined .engine
106
- }
107
- if engined , ok := ctx .Value (engineContextKey ).(* xormContextType ); ok {
108
- return engined .engine
109
- }
110
- return nil
111
- }
112
-
113
82
// Committer represents an interface to Commit or Close the Context
114
83
type Committer interface {
115
84
Commit () error
@@ -152,24 +121,23 @@ func (c *halfCommitter) Close() error {
152
121
// And all operations submitted by the caller stack will be rollbacked as well, not only the operations in the current function.
153
122
// d. It doesn't mean rollback is forbidden, but always do it only when there is an error, and you do want to rollback.
154
123
func TxContext (parentCtx context.Context ) (context.Context , Committer , error ) {
155
- if sess , ok := inTransaction (parentCtx ); ok {
156
- return newContext (parentCtx , sess ), & halfCommitter {committer : sess }, nil
124
+ if sess := getTransactionSession (parentCtx ); sess != nil {
125
+ return withContextEngine (parentCtx , sess ), & halfCommitter {committer : sess }, nil
157
126
}
158
127
159
128
sess := xormEngine .NewSession ()
160
129
if err := sess .Begin (); err != nil {
161
130
_ = sess .Close ()
162
131
return nil , nil , err
163
132
}
164
-
165
- return newContext (xormContext , sess ), sess , nil
133
+ return withContextEngine (parentCtx , sess ), sess , nil
166
134
}
167
135
168
136
// WithTx represents executing database operations on a transaction, if the transaction exist,
169
137
// this function will reuse it otherwise will create a new one and close it when finished.
170
138
func WithTx (parentCtx context.Context , f func (ctx context.Context ) error ) error {
171
- if sess , ok := inTransaction (parentCtx ); ok {
172
- err := f (newContext (parentCtx , sess ))
139
+ if sess := getTransactionSession (parentCtx ); sess != nil {
140
+ err := f (withContextEngine (parentCtx , sess ))
173
141
if err != nil {
174
142
// rollback immediately, in case the caller ignores returned error and tries to commit the transaction.
175
143
_ = sess .Close ()
@@ -195,7 +163,7 @@ func txWithNoCheck(parentCtx context.Context, f func(ctx context.Context) error)
195
163
return err
196
164
}
197
165
198
- if err := f (newContext (parentCtx , sess )); err != nil {
166
+ if err := f (withContextEngine (parentCtx , sess )); err != nil {
199
167
return err
200
168
}
201
169
@@ -340,25 +308,13 @@ func TableName(bean any) string {
340
308
341
309
// InTransaction returns true if the engine is in a transaction otherwise return false
342
310
func InTransaction (ctx context.Context ) bool {
343
- _ , ok := inTransaction (ctx )
344
- return ok
311
+ return getTransactionSession (ctx ) != nil
345
312
}
346
313
347
- func inTransaction (ctx context.Context ) (* xorm.Session , bool ) {
348
- e := getExistingEngine (ctx )
349
- if e == nil {
350
- return nil , false
351
- }
352
-
353
- switch t := e .(type ) {
354
- case * xorm.Engine :
355
- return nil , false
356
- case * xorm.Session :
357
- if t .IsInTx () {
358
- return t , true
359
- }
360
- return nil , false
361
- default :
362
- return nil , false
314
+ func getTransactionSession (ctx context.Context ) * xorm.Session {
315
+ e , _ := ctx .Value (engineContextKey ).(Engine )
316
+ if sess , ok := e .(* xorm.Session ); ok && sess .IsInTx () {
317
+ return sess
363
318
}
319
+ return nil
364
320
}
0 commit comments