Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions instrumentation/github.com/aws/aws-sdk-go-v2/otelaws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ import (
"go.opentelemetry.io/otel/propagation"
semconv "go.opentelemetry.io/otel/semconv/v1.34.0"
"go.opentelemetry.io/otel/trace"

"github.com/aws/aws-sdk-go-v2/service/sqs"
"github.com/aws/aws-sdk-go-v2/service/sqs/types"
)

const (
Expand Down Expand Up @@ -52,6 +55,13 @@ func (m otelMiddlewares) initializeMiddlewareAfter(stack *middleware.Stack) erro
out middleware.InitializeOutput, metadata middleware.Metadata, err error,
) {
serviceID := v2Middleware.GetServiceID(ctx)
if serviceID == "SQS" {
// Handle SQS message attributes for trace propagation directly
if _, ok := m.propagator.(*SQSMessageAttributePropagator); ok {
m.injectSQSTraceContext(ctx, in.Parameters)
}
}

operation := v2Middleware.GetOperationName(ctx)
region := v2Middleware.GetRegion(ctx)

Expand Down Expand Up @@ -131,6 +141,18 @@ func (m otelMiddlewares) buildAttributes(ctx context.Context, in middleware.Init
return attributes
}

func (m otelMiddlewares) injectSQSTraceContext(ctx context.Context, input interface{}) {
switch v := input.(type) {
case *sqs.SendMessageInput:
if v.MessageAttributes == nil {
v.MessageAttributes = make(map[string]types.MessageAttributeValue)
}
sqsCarrier := &SQSMessageAttributeCarrier{Attributes: v.MessageAttributes}
m.propagator.Inject(ctx, sqsCarrier)
default:
}
}

func spanName(serviceID, operation string) string {
spanName := serviceID
if operation != "" {
Expand Down Expand Up @@ -163,3 +185,13 @@ func AppendMiddlewares(apiOptions *[]func(*middleware.Stack) error, opts ...Opti
}
*apiOptions = append(*apiOptions, m.initializeMiddlewareBefore, m.initializeMiddlewareAfter, m.finalizeMiddlewareAfter, m.deserializeMiddleware)
}

// ExtractSQSTraceContext extracts trace context from SQS message attributes
// and returns a new context with the extracted trace information.
func ExtractSQSTraceContext(ctx context.Context, messageAttributes map[string]types.MessageAttributeValue) context.Context {
propagator := otel.GetTextMapPropagator()
carrier := &SQSMessageAttributeCarrier{
Attributes: messageAttributes,
}
return propagator.Extract(ctx, carrier)
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,8 @@ func WithAttributeBuilder(attributeBuilders ...AttributeBuilder) Option {
cfg.AttributeBuilders = append(cfg.AttributeBuilders, attributeBuilders...)
})
}

// WithSQSMessageAttributesTracePropagation adds a TextMapPropagator that propagates trace context using SQS message attributes.
func WithSQSMessageAttributesTracePropagation() Option {
return WithTextMapPropagator(NewSQSMessageAttributePropagator())
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package otelaws

import (
"context"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sqs/types"
"go.opentelemetry.io/otel/propagation"
)

// SQSMessageAttributeCarrier implements propagation.TextMapCarrier for SQS message attributes
type SQSMessageAttributeCarrier struct {
Attributes map[string]types.MessageAttributeValue
}

func (c *SQSMessageAttributeCarrier) Get(key string) string {
if attr, exists := c.Attributes[key]; exists && attr.StringValue != nil {
return *attr.StringValue
}
return ""
}

func (c *SQSMessageAttributeCarrier) Set(key, value string) {
if c.Attributes == nil {
c.Attributes = make(map[string]types.MessageAttributeValue)
}
c.Attributes[key] = types.MessageAttributeValue{
StringValue: aws.String(value),
DataType: aws.String("String"),
}
}

func (c *SQSMessageAttributeCarrier) Keys() []string {
keys := make([]string, 0, len(c.Attributes))
for key := range c.Attributes {
keys = append(keys, key)
}
return keys
}

// SQSMessageAttributePropagator implements propagation.TextMapPropagator for SQS message attributes
type SQSMessageAttributePropagator struct {
propagator propagation.TextMapPropagator
}

func NewSQSMessageAttributePropagator() *SQSMessageAttributePropagator {
return &SQSMessageAttributePropagator{
propagator: propagation.NewCompositeTextMapPropagator(
propagation.TraceContext{},
propagation.Baggage{},
),
}
}

func (p *SQSMessageAttributePropagator) Inject(ctx context.Context, carrier propagation.TextMapCarrier) {
if sqsCarrier, ok := carrier.(*SQSMessageAttributeCarrier); ok {
p.propagator.Inject(ctx, sqsCarrier)
}
}

func (p *SQSMessageAttributePropagator) Extract(ctx context.Context, carrier propagation.TextMapCarrier) context.Context {
if sqsCarrier, ok := carrier.(*SQSMessageAttributeCarrier); ok {
return p.propagator.Extract(ctx, sqsCarrier)
}
return ctx
}

func (p *SQSMessageAttributePropagator) Fields() []string {
return p.propagator.Fields()
}

// Ensure SQSMessageAttributePropagator implements propagation.TextMapPropagator
// SQSMessageAttributeCarrier implements propagation.TextMapCarrier
var (
_ propagation.TextMapPropagator = &SQSMessageAttributePropagator{}
_ propagation.TextMapCarrier = &SQSMessageAttributeCarrier{}
)