| 
 | 1 | +package grpcclient  | 
 | 2 | + | 
 | 3 | +import (  | 
 | 4 | +	"context"  | 
 | 5 | +	"errors"  | 
 | 6 | +	"testing"  | 
 | 7 | + | 
 | 8 | +	otgrpc "github.com/opentracing-contrib/go-grpc"  | 
 | 9 | +	"github.com/opentracing/opentracing-go"  | 
 | 10 | +	"github.com/opentracing/opentracing-go/mocktracer"  | 
 | 11 | +	"github.com/stretchr/testify/require"  | 
 | 12 | +	"google.golang.org/grpc"  | 
 | 13 | +	"google.golang.org/grpc/metadata"  | 
 | 14 | +)  | 
 | 15 | + | 
 | 16 | +type mockClientStream struct {  | 
 | 17 | +	recvErr error  | 
 | 18 | +}  | 
 | 19 | + | 
 | 20 | +func (m *mockClientStream) RecvMsg(msg interface{}) error {  | 
 | 21 | +	return m.recvErr  | 
 | 22 | +}  | 
 | 23 | + | 
 | 24 | +func (m *mockClientStream) Header() (metadata.MD, error) {  | 
 | 25 | +	return nil, nil  | 
 | 26 | +}  | 
 | 27 | + | 
 | 28 | +func (m *mockClientStream) Trailer() metadata.MD {  | 
 | 29 | +	return nil  | 
 | 30 | +}  | 
 | 31 | + | 
 | 32 | +func (m *mockClientStream) CloseSend() error {  | 
 | 33 | +	return nil  | 
 | 34 | +}  | 
 | 35 | + | 
 | 36 | +func (m *mockClientStream) Context() context.Context {  | 
 | 37 | +	return context.Background()  | 
 | 38 | +}  | 
 | 39 | + | 
 | 40 | +func (m *mockClientStream) SendMsg(interface{}) error {  | 
 | 41 | +	return nil  | 
 | 42 | +}  | 
 | 43 | + | 
 | 44 | +func TestUnwrapErrorStreamClientInterceptor(t *testing.T) {  | 
 | 45 | +	// Create a mock tracer  | 
 | 46 | +	tracer := mocktracer.New()  | 
 | 47 | +	opentracing.SetGlobalTracer(tracer)  | 
 | 48 | + | 
 | 49 | +	originalErr := errors.New("original error")  | 
 | 50 | +	// Create a mock stream that returns the original error  | 
 | 51 | +	mockStream := &mockClientStream{  | 
 | 52 | +		recvErr: originalErr,  | 
 | 53 | +	}  | 
 | 54 | + | 
 | 55 | +	// Create a mock streamer that returns our mock stream  | 
 | 56 | +	mockStreamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {  | 
 | 57 | +		return mockStream, nil  | 
 | 58 | +	}  | 
 | 59 | + | 
 | 60 | +	// Create the interceptor chain  | 
 | 61 | +	otStreamInterceptor := otgrpc.OpenTracingStreamClientInterceptor(tracer)  | 
 | 62 | +	interceptors := []grpc.StreamClientInterceptor{  | 
 | 63 | +		unwrapErrorStreamClientInterceptor(),  | 
 | 64 | +		otStreamInterceptor,  | 
 | 65 | +	}  | 
 | 66 | + | 
 | 67 | +	// Chain the interceptors  | 
 | 68 | +	chainedStreamer := mockStreamer  | 
 | 69 | +	for i := len(interceptors) - 1; i >= 0; i-- {  | 
 | 70 | +		chainedStreamer = func(interceptor grpc.StreamClientInterceptor, next grpc.Streamer) grpc.Streamer {  | 
 | 71 | +			return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {  | 
 | 72 | +				return interceptor(ctx, desc, cc, method, next, opts...)  | 
 | 73 | +			}  | 
 | 74 | +		}(interceptors[i], chainedStreamer)  | 
 | 75 | +	}  | 
 | 76 | + | 
 | 77 | +	// Call the chained streamer  | 
 | 78 | +	ctx := context.Background()  | 
 | 79 | +	stream, err := chainedStreamer(ctx, &grpc.StreamDesc{}, nil, "test")  | 
 | 80 | +	require.NoError(t, err)  | 
 | 81 | +	var msg interface{}  | 
 | 82 | +	err = stream.RecvMsg(&msg)  | 
 | 83 | +	require.Error(t, err)  | 
 | 84 | +	require.EqualError(t, err, originalErr.Error())  | 
 | 85 | + | 
 | 86 | +	// Only wrap OpenTracingStreamClientInterceptor.  | 
 | 87 | +	chainedStreamerWithoutUnwrapErr := func(interceptor grpc.StreamClientInterceptor, next grpc.Streamer) grpc.Streamer {  | 
 | 88 | +		return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {  | 
 | 89 | +			return interceptor(ctx, desc, cc, method, next, opts...)  | 
 | 90 | +		}  | 
 | 91 | +	}(otStreamInterceptor, mockStreamer)  | 
 | 92 | +	stream, err = chainedStreamerWithoutUnwrapErr(ctx, &grpc.StreamDesc{}, nil, "test")  | 
 | 93 | +	require.NoError(t, err)  | 
 | 94 | +	err = stream.RecvMsg(&msg)  | 
 | 95 | +	require.Error(t, err)  | 
 | 96 | +	// Error is wrapped by OpenTracingStreamClientInterceptor and not unwrapped.  | 
 | 97 | +	require.Contains(t, err.Error(), "failed to receive message")  | 
 | 98 | +}  | 
0 commit comments