@@ -12,6 +12,7 @@ import (
1212 "go.opentelemetry.io/otel/sdk/trace/tracetest"
1313 semconv "go.opentelemetry.io/otel/semconv/v1.30.0"
1414 "go.opentelemetry.io/otel/trace"
15+ "go.opentelemetry.io/otel/trace/noop"
1516 "gorm.io/driver/postgres"
1617 "gorm.io/driver/sqlite"
1718 "gorm.io/gorm"
@@ -167,6 +168,41 @@ func TestOtel(t *testing.T) {
167168 }
168169}
169170
171+ func TestOtelPlugin_ContextRestoration (t * testing.T ) {
172+ provider := noop .NewTracerProvider ()
173+ p := & otelPlugin {provider : provider }
174+ p .tracer = provider .Tracer ("test" )
175+
176+ origCtx := context .WithValue (context .Background (), "foo" , "bar" )
177+ db := & gorm.DB {
178+ Statement : & gorm.Statement {Context : origCtx },
179+ Config : & gorm.Config {},
180+ }
181+
182+ // before should wrap context
183+ before := p .before ("test-span" )
184+ before (db )
185+ cw , ok := db .Statement .Context .(contextWrapper )
186+ require .True (t , ok )
187+ require .Equal (t , origCtx , cw .parent )
188+
189+ // after should restore context
190+ after := p .after ()
191+ after (db )
192+ require .Equal (t , origCtx , db .Statement .Context )
193+
194+ origCtx = context .Background ()
195+ db = & gorm.DB {
196+ Statement : & gorm.Statement {Context : origCtx },
197+ Config : & gorm.Config {},
198+ }
199+
200+ // after should not panic if context is not a contextWrapper
201+ after = p .after ()
202+ require .NotPanics (t , func () { after (db ) })
203+ require .Equal (t , origCtx , db .Statement .Context )
204+ }
205+
170206func attrMap (attrs []attribute.KeyValue ) map [attribute.Key ]attribute.Value {
171207 m := make (map [attribute.Key ]attribute.Value , len (attrs ))
172208 for _ , kv := range attrs {
0 commit comments