Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,16 @@ type ClientOptions struct {
// PropagateTraceparent is used to control whether the W3C Trace Context HTTP traceparent header
// is propagated on outgoing http requests.
PropagateTraceparent bool
// StrictTraceContinuation is used to control trace continuation from 3rd party services that happen to be
// instrumented by Sentry.
//
// Enabling the option means that the SDK will require the org ids from baggage to match for continuing the trace.
StrictTraceContinuation bool
// OrgID configures the orgID used for trace propagation and features like StrictTraceContinuation.
//
// In most cases the orgID is already parsed from the DSN. This option should be used when non-standard Sentry DSNs
// are used, such as self-hosted or when using a local Relay.
OrgID uint64
// List of regexp strings that will be used to match against event's message
// and if applicable, caught errors type and value.
// If the match is found, then a whole event will be dropped.
Expand Down Expand Up @@ -404,7 +414,9 @@ func NewClient(options ClientOptions) (*Client, error) {
client.batchMeter = newMetricBatchProcessor(&client)
client.batchMeter.Start()
}

if options.OrgID != 0 && client.dsn != nil {
client.dsn.SetOrgID(options.OrgID)
}
client.setupIntegrations()

return &client, nil
Expand Down
20 changes: 20 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,26 @@ func TestSampleRate(t *testing.T) {
})
}
}
func TestClient_ParseOrgID(t *testing.T) {
c, err := NewClient(ClientOptions{
Dsn: "https://example@o1.ingest.us.sentry.io/1337",
})
if err != nil {
t.Fatal(err)
}
assert.Equal(t, uint64(1), c.dsn.GetOrgID(), "Custom org id should override the DSN parsed one")
}

func TestClientOptions_OrgIDShouldOverrideParsed(t *testing.T) {
c, err := NewClient(ClientOptions{
Dsn: "https://example@o1.ingest.us.sentry.io/1337",
OrgID: 2,
})
if err != nil {
t.Fatal(err)
}
assert.Equal(t, uint64(2), c.dsn.GetOrgID(), "Custom org id should override the DSN parsed one")
}
Comment on lines +903 to +922
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add another test for self-hosted DSN?

An example would be: https://example@sentry.example.com/1337

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You would need to use the OrgID option anyways, since we can't parse an id from the DSN, so TestClientOptions_OrgIDShouldOverrideParsed should verify both, but I can add a small one just to be sure.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep. But since it's self-hosted, the default OrgId will be 0 though. It should be considered as non-empty value.


func BenchmarkProcessEvent(b *testing.B) {
c, err := NewClient(ClientOptions{
Expand Down
8 changes: 7 additions & 1 deletion dynamic_sampling_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ func DynamicSamplingContextFromTransaction(span *Span) DynamicSamplingContext {
if publicKey := dsn.GetPublicKey(); publicKey != "" {
entries["public_key"] = publicKey
}
if orgID := dsn.GetOrgID(); orgID != 0 {
entries["org_id"] = strconv.FormatUint(orgID, 10)
}
}
if release := client.options.Release; release != "" {
entries["release"] = release
Expand Down Expand Up @@ -113,7 +116,7 @@ func (d DynamicSamplingContext) String() string {
return baggage.String()
}

// Constructs a new DynamicSamplingContext using a scope and client. Accessing
// DynamicSamplingContextFromScope Constructs a new DynamicSamplingContext using a scope and client. Accessing
// fields on the scope are not thread safe, and this function should only be
// called within scope methods.
func DynamicSamplingContextFromScope(scope *Scope, client *Client) DynamicSamplingContext {
Expand All @@ -139,6 +142,9 @@ func DynamicSamplingContextFromScope(scope *Scope, client *Client) DynamicSampli
if publicKey := dsn.GetPublicKey(); publicKey != "" {
entries["public_key"] = publicKey
}
if orgID := dsn.GetOrgID(); orgID != 0 {
entries["org_id"] = strconv.FormatUint(orgID, 10)
}
}
if release := client.options.Release; release != "" {
entries["release"] = release
Expand Down
25 changes: 25 additions & 0 deletions internal/protocol/dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type Dsn struct {
port int
path string
projectID string
orgID uint64
}

// NewDsn creates a Dsn by parsing rawURL. Most users will never call this
Expand Down Expand Up @@ -90,6 +91,17 @@ func NewDsn(rawURL string) (*Dsn, error) {
return nil, &DsnParseError{"empty host"}
}

// OrgID (optional)
var orgID uint64
parts := strings.Split(host, ".")
orgPart := parts[0]
if len(orgPart) >= 2 && orgPart[0] == 'o' {
parsedOrgID, err := strconv.ParseUint(orgPart[1:], 10, 64)
if err == nil {
orgID = parsedOrgID
}
}

// Port
var port int
if p := parsedURL.Port(); p != "" {
Expand Down Expand Up @@ -126,6 +138,7 @@ func NewDsn(rawURL string) (*Dsn, error) {
port: port,
path: path,
projectID: projectID,
orgID: orgID,
}, nil
}

Expand Down Expand Up @@ -182,6 +195,18 @@ func (dsn Dsn) GetProjectID() string {
return dsn.projectID
}

// GetOrgID returns the orgID that was parsed from the DSN.
func (dsn Dsn) GetOrgID() uint64 {
return dsn.orgID
}

// SetOrgID sets the orgID used for trace continuation.
//
// This function is used for overriding the orgID parsed from the DSN.
func (dsn *Dsn) SetOrgID(orgID uint64) {
dsn.orgID = orgID
}

// GetAPIURL returns the URL of the envelope endpoint of the project
// associated with the DSN.
func (dsn Dsn) GetAPIURL() *url.URL {
Expand Down
81 changes: 70 additions & 11 deletions tracing.go
Original file line number Diff line number Diff line change
Expand Up @@ -953,8 +953,15 @@ func WithSpanOrigin(origin SpanOrigin) SpanOption {
func ContinueTrace(hub *Hub, traceparent, baggage string) SpanOption {
scope := hub.Scope()
propagationContext, _ := PropagationContextFromHeaders(traceparent, baggage)
scope.SetPropagationContext(propagationContext)
client := hub.Client()

if !shouldContinueTrace(client, propagationContext.DynamicSamplingContext) {
propagationContext = NewPropagationContext()
traceparent = ""
baggage = ""
}

scope.SetPropagationContext(propagationContext)
return ContinueFromHeaders(traceparent, baggage)
}

Expand All @@ -973,19 +980,35 @@ func ContinueFromRequest(r *http.Request) SpanOption {
// an existing TraceID and propagates the Dynamic Sampling context.
func ContinueFromHeaders(trace, baggage string) SpanOption {
return func(s *Span) {
if trace != "" {
s.updateFromSentryTrace([]byte(trace))
if trace == "" {
return
}

if baggage != "" {
s.updateFromBaggage([]byte(baggage))
// Parse baggage first to get org_id for comparison
var dsc DynamicSamplingContext
if baggage != "" {
parsed, err := DynamicSamplingContextFromHeader([]byte(baggage))
if err == nil {
dsc = parsed
}
}

// In case a sentry-trace header is present but there are no sentry-related
// values in the baggage, create an empty, frozen DynamicSamplingContext.
if !s.dynamicSamplingContext.HasEntries() {
s.dynamicSamplingContext = DynamicSamplingContext{
Frozen: true,
}
client := hubFromContext(s.ctx).Client()
if !shouldContinueTrace(client, dsc) {
return // leave span unchanged → behaves as head of trace
}

s.updateFromSentryTrace([]byte(trace))

if baggage != "" {
s.updateFromBaggage([]byte(baggage))
}

// In case a sentry-trace header is present but there are no sentry-related
// values in the baggage, create an empty, frozen DynamicSamplingContext.
if !s.dynamicSamplingContext.HasEntries() {
s.dynamicSamplingContext = DynamicSamplingContext{
Frozen: true,
}
}
}
Expand All @@ -998,6 +1021,10 @@ func ContinueFromTrace(trace string) SpanOption {
if trace == "" {
return
}
client := hubFromContext(s.ctx).Client()
if !shouldContinueTrace(client, DynamicSamplingContext{}) {
return
}
s.updateFromSentryTrace([]byte(trace))
}
}
Expand Down Expand Up @@ -1077,3 +1104,35 @@ func HTTPtoSpanStatus(code int) SpanStatus {
}
return SpanStatusUnknown
}

func shouldContinueTrace(client *Client, dsc DynamicSamplingContext) bool {
if client == nil {
return true
}

var sdkOrgID uint64
if client.dsn != nil {
sdkOrgID = client.dsn.GetOrgID()
}

baggageOrgStr := dsc.Entries["org_id"]
baggageOrgID := uint64(0)
if baggageOrgStr != "" {
baggageOrgID, _ = strconv.ParseUint(baggageOrgStr, 10, 64)
}

// we reject non-matching orgs regardless of strict mode
if sdkOrgID != 0 && baggageOrgID != 0 && sdkOrgID != baggageOrgID {
return false
}

// If strict mode is on, both must be present and match
if client.options.StrictTraceContinuation {
if sdkOrgID == 0 && baggageOrgID == 0 {
return true
}
return sdkOrgID == baggageOrgID
}

return true
}
74 changes: 71 additions & 3 deletions tracing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ func TestContinueSpanFromRequest(t *testing.T) {
sampled := sampled
t.Run(sampled.String(), func(t *testing.T) {
var s Span
s.ctx = context.Background()
hkey := http.CanonicalHeaderKey("sentry-trace")
hval := (&Span{
TraceID: traceID,
Expand Down Expand Up @@ -585,12 +586,13 @@ func TestContinueTransactionFromHeaders(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Span{}
s.ctx = context.Background()
spanOption := ContinueFromHeaders(tt.traceStr, tt.baggageStr)
spanOption(s)

if diff := cmp.Diff(tt.wantSpan, s, cmp.Options{
cmp.AllowUnexported(Span{}),
cmpopts.IgnoreFields(Span{}, "mu", "finishOnce"),
cmpopts.IgnoreFields(Span{}, "ctx", "mu", "finishOnce"),
}); diff != "" {
t.Fatalf("Expected no difference on spans, got: %s", diff)
}
Expand All @@ -605,13 +607,14 @@ func TestContinueSpanFromTrace(t *testing.T) {
for _, sampled := range []Sampled{SampledTrue, SampledFalse, SampledUndefined} {
sampled := sampled
t.Run(sampled.String(), func(t *testing.T) {
var s Span
s := &Span{}
s.ctx = context.Background()
trace := (&Span{
TraceID: traceID,
SpanID: spanID,
Sampled: sampled,
}).ToSentryTrace()
ContinueFromTrace(trace)(&s)
ContinueFromTrace(trace)(s)
if s.TraceID != traceID {
t.Errorf("got %q, want %q", s.TraceID, traceID)
}
Expand Down Expand Up @@ -1287,3 +1290,68 @@ func TestSpanScopeManagement(t *testing.T) {
t.Errorf("expected SpanID %s, got %s", transaction.SpanID, spanID)
}
}

func TestStrictTraceContinuation(t *testing.T) {
incomingTraceID := TraceIDFromHex("bc6d53f15eb88f4320054569b8c553d4")
sentryTrace := "bc6d53f15eb88f4320054569b8c553d4-b72fa28504b07285-1"

baggageWithOrg := func(orgID string) string {
return "sentry-org_id=" + orgID + ",sentry-trace_id=bc6d53f15eb88f4320054569b8c553d4"
}
baggageWithoutOrg := "sentry-trace_id=bc6d53f15eb88f4320054569b8c553d4"

tests := []struct {
name string
baggageOrgID string
sdkOrgID uint64
strict bool
wantContinued bool
}{
{"strict=false, baggage=1, sdk=1", "1", 1, false, true},
{"strict=false, baggage=none, sdk=1", "", 1, false, true},
{"strict=false, baggage=1, sdk=none", "1", 0, false, true},
{"strict=false, baggage=none, sdk=none", "", 0, false, true},
{"strict=false, baggage=1, sdk=2", "1", 2, false, false},

{"strict=true, baggage=1, sdk=1", "1", 1, true, true},
{"strict=true, baggage=none, sdk=1", "", 1, true, false},
{"strict=true, baggage=1, sdk=none", "1", 0, true, false},
{"strict=true, baggage=none, sdk=none", "", 0, true, true},
{"strict=true, baggage=1, sdk=2", "1", 2, true, false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
transport := &MockTransport{}
ctx := NewTestContext(ClientOptions{
Dsn: testDsn,
EnableTracing: true,
TracesSampleRate: 1.0,
Transport: transport,
StrictTraceContinuation: tt.strict,
OrgID: tt.sdkOrgID,
})

baggage := baggageWithoutOrg
if tt.baggageOrgID != "" {
baggage = baggageWithOrg(tt.baggageOrgID)
}

hub := GetHubFromContext(ctx)
transaction := StartTransaction(ctx, "test",
ContinueTrace(hub, sentryTrace, baggage),
)
transaction.Finish()

if tt.wantContinued {
if transaction.TraceID != incomingTraceID {
t.Errorf("expected trace to be continued, got new TraceID %s", transaction.TraceID)
}
} else {
if transaction.TraceID == incomingTraceID {
t.Errorf("expected new trace, but got continued TraceID %s", transaction.TraceID)
}
}
})
}
}
Loading