Skip to content

Commit c5d161d

Browse files
Restore previous context in after hooks (#50)
* Try recovering previous context * Clean up * Run gofmt * Add context restoration tests
1 parent c42d34c commit c5d161d

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

tracing/tracing.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package tracing
22

33
import (
4+
"context"
45
"database/sql"
56
"database/sql/driver"
67
"fmt"
@@ -108,10 +109,16 @@ func (p otelPlugin) Initialize(db *gorm.DB) (err error) {
108109
return firstErr
109110
}
110111

112+
type contextWrapper struct {
113+
context.Context
114+
parent context.Context
115+
}
116+
111117
func (p *otelPlugin) before(spanName string) gormHookFunc {
112118
return func(tx *gorm.DB) {
119+
parentCtx := tx.Statement.Context
113120
ctx, span := p.tracer.Start(tx.Statement.Context, spanName, trace.WithSpanKind(trace.SpanKindClient))
114-
tx.Statement.Context = ctx
121+
tx.Statement.Context = contextWrapper{ctx, parentCtx}
115122

116123
if !p.excludeServerAddress {
117124
// `server.address` is required in the latest semconv
@@ -141,6 +148,13 @@ func (p *otelPlugin) before(spanName string) gormHookFunc {
141148

142149
func (p *otelPlugin) after() gormHookFunc {
143150
return func(tx *gorm.DB) {
151+
defer func() {
152+
if c, ok := tx.Statement.Context.(contextWrapper); ok {
153+
// recover previous context
154+
tx.Statement.Context = c.parent
155+
}
156+
}()
157+
144158
span := trace.SpanFromContext(tx.Statement.Context)
145159
if !span.IsRecording() {
146160
return

tracing/tracing_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"go.opentelemetry.io/otel/sdk/trace/tracetest"
1313
semconv "go.opentelemetry.io/otel/semconv/v1.30.0"
1414
"go.opentelemetry.io/otel/trace"
15+
"go.opentelemetry.io/otel/trace/noop"
1516
"gorm.io/driver/postgres"
1617
"gorm.io/driver/sqlite"
1718
"gorm.io/gorm"
@@ -167,6 +168,41 @@ func TestOtel(t *testing.T) {
167168
}
168169
}
169170

171+
func TestOtelPlugin_ContextRestoration(t *testing.T) {
172+
provider := noop.NewTracerProvider()
173+
p := &otelPlugin{provider: provider}
174+
p.tracer = provider.Tracer("test")
175+
176+
origCtx := context.WithValue(context.Background(), "foo", "bar")
177+
db := &gorm.DB{
178+
Statement: &gorm.Statement{Context: origCtx},
179+
Config: &gorm.Config{},
180+
}
181+
182+
// before should wrap context
183+
before := p.before("test-span")
184+
before(db)
185+
cw, ok := db.Statement.Context.(contextWrapper)
186+
require.True(t, ok)
187+
require.Equal(t, origCtx, cw.parent)
188+
189+
// after should restore context
190+
after := p.after()
191+
after(db)
192+
require.Equal(t, origCtx, db.Statement.Context)
193+
194+
origCtx = context.Background()
195+
db = &gorm.DB{
196+
Statement: &gorm.Statement{Context: origCtx},
197+
Config: &gorm.Config{},
198+
}
199+
200+
// after should not panic if context is not a contextWrapper
201+
after = p.after()
202+
require.NotPanics(t, func() { after(db) })
203+
require.Equal(t, origCtx, db.Statement.Context)
204+
}
205+
170206
func attrMap(attrs []attribute.KeyValue) map[attribute.Key]attribute.Value {
171207
m := make(map[attribute.Key]attribute.Value, len(attrs))
172208
for _, kv := range attrs {

0 commit comments

Comments
 (0)