@@ -8,17 +8,27 @@ import (
88 "testing"
99
1010 "github.com/hypertrace/goagent/instrumentation/opentelemetry/internal/tracetesting"
11- "github.com/hypertrace/goagent/sdk/filter"
11+ "github.com/hypertrace/goagent/sdk"
12+ "github.com/hypertrace/goagent/sdk/filter/result"
13+ sdkSQL "github.com/hypertrace/goagent/sdk/instrumentation/database/sql"
1214 _ "github.com/mattn/go-sqlite3"
1315 "github.com/stretchr/testify/assert"
1416 sdktrace "go.opentelemetry.io/otel/sdk/trace"
1517 apitrace "go.opentelemetry.io/otel/trace"
1618)
1719
20+ type mockFilter struct {
21+ evaluator func (span sdk.Span ) result.FilterResult
22+ }
23+
24+ func (f * mockFilter ) Evaluate (span sdk.Span ) result.FilterResult {
25+ return f .evaluator (span )
26+ }
27+
1828func createDB (t * testing.T ) (* sql.DB , func () []sdktrace.ReadOnlySpan ) {
1929 _ , flusher := tracetesting .InitTracer ()
2030
21- driverName , err := Register ("sqlite3" , filter. NoopFilter {} )
31+ driverName , err := Register ("sqlite3" , nil )
2232 if err != nil {
2333 t .Fatalf ("unable to register driver" )
2434 }
@@ -188,3 +198,57 @@ func TestTxWithRollbackSuccess(t *testing.T) {
188198
189199 db .Close ()
190200}
201+
202+ func TestFilter (t * testing.T ) {
203+ _ , flusher := tracetesting .InitTracer ()
204+
205+ driverName , err := Register ("sqlite3" , & sdkSQL.Options {
206+ Filter : & mockFilter {
207+ evaluator : func (span sdk.Span ) result.FilterResult {
208+ assert .Equal (t , span .GetAttributes ().GetValue ("span.kind" ), "client" )
209+
210+ span .SetAttribute ("span.type" , "nospan" )
211+ return result.FilterResult {}
212+ },
213+ },
214+ })
215+ if err != nil {
216+ t .Fatalf ("unable to register driver" )
217+ }
218+
219+ db , err := sql .Open (driverName , "file:test.db?cache=shared&mode=memory" )
220+ if err != nil {
221+ t .Fatal (err )
222+ }
223+
224+ rows , err := db .Query ("SELECT 1 WHERE 1 = ?" , 1 )
225+ if err != nil {
226+ t .Fatalf ("unexpected error: %s" , err .Error ())
227+ }
228+ defer rows .Close ()
229+
230+ for rows .Next () {
231+ var n int
232+ if err = rows .Scan (& n ); err != nil {
233+ t .Fatalf ("unexpected error: %s" , err .Error ())
234+ }
235+ }
236+ if err = rows .Err (); err != nil {
237+ t .Fatalf ("unexpected error: %s" , err .Error ())
238+ }
239+
240+ spans := flusher ()
241+ assert .Equal (t , 1 , len (spans ))
242+
243+ span := spans [0 ]
244+ assert .Equal (t , "db:query" , span .Name ())
245+ assert .Equal (t , apitrace .SpanKindClient , span .SpanKind ())
246+
247+ attrs := tracetesting .LookupAttributes (span .Attributes ())
248+ assert .Equal (t , "SELECT 1 WHERE 1 = ?" , attrs .Get ("db.statement" ).AsString ())
249+ assert .Equal (t , "sqlite" , attrs .Get ("db.system" ).AsString ())
250+ assert .False (t , attrs .Has ("error" ))
251+ assert .Equal (t , "nospan" , attrs .Get ("span.type" ).AsString ())
252+
253+ db .Close ()
254+ }
0 commit comments