diff --git a/instrumentation/github.com/aws/aws-sdk-go-v2/otelaws/aws.go b/instrumentation/github.com/aws/aws-sdk-go-v2/otelaws/aws.go index 500e85f4d4d..331c287c5ed 100644 --- a/instrumentation/github.com/aws/aws-sdk-go-v2/otelaws/aws.go +++ b/instrumentation/github.com/aws/aws-sdk-go-v2/otelaws/aws.go @@ -17,6 +17,9 @@ import ( "go.opentelemetry.io/otel/propagation" semconv "go.opentelemetry.io/otel/semconv/v1.34.0" "go.opentelemetry.io/otel/trace" + + "github.com/aws/aws-sdk-go-v2/service/sqs" + "github.com/aws/aws-sdk-go-v2/service/sqs/types" ) const ( @@ -52,6 +55,13 @@ func (m otelMiddlewares) initializeMiddlewareAfter(stack *middleware.Stack) erro out middleware.InitializeOutput, metadata middleware.Metadata, err error, ) { serviceID := v2Middleware.GetServiceID(ctx) + if serviceID == "SQS" { + // Handle SQS message attributes for trace propagation directly + if _, ok := m.propagator.(*SQSMessageAttributePropagator); ok { + m.injectSQSTraceContext(ctx, in.Parameters) + } + } + operation := v2Middleware.GetOperationName(ctx) region := v2Middleware.GetRegion(ctx) @@ -131,6 +141,18 @@ func (m otelMiddlewares) buildAttributes(ctx context.Context, in middleware.Init return attributes } +func (m otelMiddlewares) injectSQSTraceContext(ctx context.Context, input interface{}) { + switch v := input.(type) { + case *sqs.SendMessageInput: + if v.MessageAttributes == nil { + v.MessageAttributes = make(map[string]types.MessageAttributeValue) + } + sqsCarrier := &SQSMessageAttributeCarrier{Attributes: v.MessageAttributes} + m.propagator.Inject(ctx, sqsCarrier) + default: + } +} + func spanName(serviceID, operation string) string { spanName := serviceID if operation != "" { @@ -163,3 +185,13 @@ func AppendMiddlewares(apiOptions *[]func(*middleware.Stack) error, opts ...Opti } *apiOptions = append(*apiOptions, m.initializeMiddlewareBefore, m.initializeMiddlewareAfter, m.finalizeMiddlewareAfter, m.deserializeMiddleware) } + +// ExtractSQSTraceContext extracts trace context from SQS message attributes +// and returns a new context with the extracted trace information. +func ExtractSQSTraceContext(ctx context.Context, messageAttributes map[string]types.MessageAttributeValue) context.Context { + propagator := otel.GetTextMapPropagator() + carrier := &SQSMessageAttributeCarrier{ + Attributes: messageAttributes, + } + return propagator.Extract(ctx, carrier) +} diff --git a/instrumentation/github.com/aws/aws-sdk-go-v2/otelaws/config.go b/instrumentation/github.com/aws/aws-sdk-go-v2/otelaws/config.go index 7f7f276f188..5be54518c95 100755 --- a/instrumentation/github.com/aws/aws-sdk-go-v2/otelaws/config.go +++ b/instrumentation/github.com/aws/aws-sdk-go-v2/otelaws/config.go @@ -54,3 +54,8 @@ func WithAttributeBuilder(attributeBuilders ...AttributeBuilder) Option { cfg.AttributeBuilders = append(cfg.AttributeBuilders, attributeBuilders...) }) } + +// WithSQSMessageAttributesTracePropagation adds a TextMapPropagator that propagates trace context using SQS message attributes. +func WithSQSMessageAttributesTracePropagation() Option { + return WithTextMapPropagator(NewSQSMessageAttributePropagator()) +} diff --git a/instrumentation/github.com/aws/aws-sdk-go-v2/otelaws/sqs_trace_propagation.go b/instrumentation/github.com/aws/aws-sdk-go-v2/otelaws/sqs_trace_propagation.go new file mode 100644 index 00000000000..1c308ced65c --- /dev/null +++ b/instrumentation/github.com/aws/aws-sdk-go-v2/otelaws/sqs_trace_propagation.go @@ -0,0 +1,77 @@ +package otelaws + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sqs/types" + "go.opentelemetry.io/otel/propagation" +) + +// SQSMessageAttributeCarrier implements propagation.TextMapCarrier for SQS message attributes +type SQSMessageAttributeCarrier struct { + Attributes map[string]types.MessageAttributeValue +} + +func (c *SQSMessageAttributeCarrier) Get(key string) string { + if attr, exists := c.Attributes[key]; exists && attr.StringValue != nil { + return *attr.StringValue + } + return "" +} + +func (c *SQSMessageAttributeCarrier) Set(key, value string) { + if c.Attributes == nil { + c.Attributes = make(map[string]types.MessageAttributeValue) + } + c.Attributes[key] = types.MessageAttributeValue{ + StringValue: aws.String(value), + DataType: aws.String("String"), + } +} + +func (c *SQSMessageAttributeCarrier) Keys() []string { + keys := make([]string, 0, len(c.Attributes)) + for key := range c.Attributes { + keys = append(keys, key) + } + return keys +} + +// SQSMessageAttributePropagator implements propagation.TextMapPropagator for SQS message attributes +type SQSMessageAttributePropagator struct { + propagator propagation.TextMapPropagator +} + +func NewSQSMessageAttributePropagator() *SQSMessageAttributePropagator { + return &SQSMessageAttributePropagator{ + propagator: propagation.NewCompositeTextMapPropagator( + propagation.TraceContext{}, + propagation.Baggage{}, + ), + } +} + +func (p *SQSMessageAttributePropagator) Inject(ctx context.Context, carrier propagation.TextMapCarrier) { + if sqsCarrier, ok := carrier.(*SQSMessageAttributeCarrier); ok { + p.propagator.Inject(ctx, sqsCarrier) + } +} + +func (p *SQSMessageAttributePropagator) Extract(ctx context.Context, carrier propagation.TextMapCarrier) context.Context { + if sqsCarrier, ok := carrier.(*SQSMessageAttributeCarrier); ok { + return p.propagator.Extract(ctx, sqsCarrier) + } + return ctx +} + +func (p *SQSMessageAttributePropagator) Fields() []string { + return p.propagator.Fields() +} + +// Ensure SQSMessageAttributePropagator implements propagation.TextMapPropagator +// SQSMessageAttributeCarrier implements propagation.TextMapCarrier +var ( + _ propagation.TextMapPropagator = &SQSMessageAttributePropagator{} + _ propagation.TextMapCarrier = &SQSMessageAttributeCarrier{} +)