diff --git a/proxy/proxy.go b/proxy/proxy.go index ba717f0..cb3d6c9 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -15,6 +15,7 @@ package proxy import ( + "bufio" "bytes" "context" "crypto" @@ -22,6 +23,7 @@ import ( "encoding/hex" "errors" "fmt" + "github.com/google/uuid" "io" "math" "math/big" @@ -82,6 +84,7 @@ type Config struct { // PreparedCache a cache that stores prepared queries. If not set it uses the default implementation with a max // capacity of ~100MB. PreparedCache proxycore.PreparedCache + EnableTracing bool } type sessionKey struct { @@ -619,18 +622,20 @@ func (c *client) Receive(reader io.Reader) error { func (c *client) execute(raw *frame.RawFrame, state idempotentState, isSelect bool, keyspace string, body *frame.Body) { if sess, err := c.proxy.findSession(raw.Header.Version, c.keyspace, c.compression); err == nil { + requestId, raw := c.handleRequestId(body, raw) req := &request{ - client: c, - session: sess, - state: state, - msg: body.Message, - keyspace: keyspace, - done: false, - stream: raw.Header.StreamId, - version: raw.Header.Version, - qp: c.proxy.newQueryPlan(), - frm: c.maybeOverrideUnsupportedWriteConsistency(isSelect, raw, body), - isSelect: isSelect, + client: c, + session: sess, + state: state, + msg: body.Message, + requestId: requestId, + keyspace: keyspace, + done: false, + stream: raw.Header.StreamId, + version: raw.Header.Version, + qp: c.proxy.newQueryPlan(), + frm: c.maybeOverrideUnsupportedWriteConsistency(isSelect, raw, body), + isSelect: isSelect, } req.Execute(true) } else { @@ -717,6 +722,40 @@ func (c *client) handleQuery(raw *frame.RawFrame, msg *codecs.PartialQuery, body } } +func (c *client) handleRequestId(body *frame.Body, frm *frame.RawFrame) ([]byte, *frame.RawFrame) { + if !c.proxy.config.EnableTracing { + return nil, frm + } + if body.CustomPayload == nil { + frm.Header.Flags = frm.Header.Flags.Add(primitive.HeaderFlagCustomPayload) + body.CustomPayload = make(map[string][]byte) + } + var reqId []byte + if id, ok := body.CustomPayload["request-id"]; !ok { + cid, err := uuid.New().MarshalBinary() + if err != nil { + return nil, frm + } + body.CustomPayload["request-id"] = cid + reqId = cid + var buffer bytes.Buffer + writer := bufio.NewWriter(&buffer) + err = c.codec.EncodeBody(frm.Header, body, writer) + if err != nil { + return nil, frm + } + err = writer.Flush() + if err != nil { + return nil, frm + } + frm.Body = buffer.Bytes() + } else { + reqId = id + } + c.proxy.logger.Info("received request id", zap.String("idHex", hex.EncodeToString(reqId))) + return reqId, frm +} + func (c *client) getDefaultIdempotency(customPayload map[string][]byte) idempotentState { state := notDetermined if _, ok := customPayload["graph-source"]; ok { // Graph queries default to non-idempotent unless overridden diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index c8b0c26..c0b3c61 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -17,11 +17,14 @@ package proxy import ( "context" "encoding/binary" + "encoding/hex" "errors" "fmt" + "go.uber.org/zap/zaptest/observer" "log" "math/big" "net" + "slices" "strconv" "sync" "testing" @@ -424,6 +427,8 @@ type proxyTestConfig struct { rpcAddr string peers []PeerConfig idempotentGraph bool + enableTracing bool + logger *zap.Logger } func setupProxyTestWithConfig(ctx context.Context, numNodes int, cfg *proxyTestConfig) (tester *proxyTester, proxyContactPoint string, err error) { @@ -439,6 +444,10 @@ func setupProxyTestWithConfig(ctx context.Context, numNodes int, cfg *proxyTestC if cfg == nil { cfg = &proxyTestConfig{} } + logger := cfg.logger + if logger == nil { + logger = zap.L() + } if cfg.handlers != nil { tester.cluster.Handlers = proxycore.NewMockRequestHandlers(cfg.handlers) @@ -461,7 +470,8 @@ func setupProxyTestWithConfig(ctx context.Context, numNodes int, cfg *proxyTestC RPCAddr: cfg.rpcAddr, Peers: cfg.peers, IdempotentGraph: cfg.idempotentGraph, - Logger: zap.L(), + Logger: logger, + EnableTracing: cfg.enableTracing, }) err = tester.proxy.Connect() @@ -537,3 +547,83 @@ func waitUntil(d time.Duration, check func() bool) bool { } return false } + +func TestProxy_RequestTracing(t *testing.T) { + var tests = []struct { + name string + reqId []byte + }{ + {name: "proxy_generated_id"}, + {name: "client_generated_id", reqId: []byte{1, 2, 3}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cfg := &proxyTestConfig{} + cfg.enableTracing = true + + // observe INFO logs produced by the proxy + core, logs := observer.New(zap.InfoLevel) + logger := zap.New(core) + cfg.logger = logger + + var reqs []frame.Frame + + cfg.handlers = proxycore.MockRequestHandlers{ + primitive.OpCodeQuery: func(cl *proxycore.MockClient, frm *frame.Frame) message.Message { + if msg := cl.InterceptQuery(frm.Header, frm.Body.Message.(*message.Query)); msg != nil { + return msg + } else { + reqs = append(reqs, *frm) + return &message.RowsResult{ + Metadata: &message.RowsMetadata{ + ColumnCount: 0, + }, + Data: message.RowSet{}, + } + } + }, + } + + tester, proxyContactPoint, err := setupProxyTestWithConfig(ctx, 1, cfg) + defer func() { + cancel() + tester.shutdown() + }() + require.NoError(t, err) + + cl := connectTestClient(t, ctx, proxyContactPoint) + + query := frame.NewFrame(primitive.ProtocolVersion4, -1, &message.Query{Query: idempotentQuery}) + if tt.reqId != nil { + query.SetCustomPayload(map[string][]byte{"request-id": tt.reqId}) + } + _, err = cl.QueryFrame(ctx, query) + require.NoError(t, err) + + // assert that request IDs have been logged + logsReq := logs.FilterMessage("received request id").All() + logsRes := logs.FilterMessage("received response id").All() + traceLogs := slices.Concat(logsReq, logsRes) + assert.Equal(t, 2, len(traceLogs)) + for _, l := range traceLogs { + logCtx := l.ContextMap() + require.Contains(t, logCtx, "idHex") + if tt.reqId != nil { + assert.Equal(t, hex.EncodeToString(tt.reqId), logCtx["idHex"]) + } + } + + // assert request propagated to downstream cluster + assert.Equal(t, 1, len(reqs)) + assert.NotNil(t, reqs[0].Body.CustomPayload) + proxyReqId := reqs[0].Body.CustomPayload["request-id"] + assert.NotNil(t, proxyReqId) + if tt.reqId != nil { + assert.Equal(t, tt.reqId, proxyReqId) + } + }) + } +} diff --git a/proxy/request.go b/proxy/request.go index 3f48f36..692d25d 100644 --- a/proxy/request.go +++ b/proxy/request.go @@ -15,6 +15,7 @@ package proxy import ( + "encoding/hex" "errors" "io" "reflect" @@ -44,6 +45,7 @@ type request struct { state idempotentState keyspace string msg message.Message + requestId []byte done bool retryCount int host *proxycore.Host @@ -151,6 +153,9 @@ func (r *request) OnResult(raw *frame.RawFrame) { r.mu.Lock() defer r.mu.Unlock() if !r.done { + if r.requestId != nil { + r.client.proxy.logger.Info("received response id", zap.String("idHex", hex.EncodeToString(r.requestId))) + } if raw.Header.OpCode != primitive.OpCodeError || !r.handleErrorResult(raw) { // If the error result is retried then we don't send back this response r.client.maybeStorePreparedMetadata(raw, r.isSelect, r.msg) diff --git a/proxy/run.go b/proxy/run.go index 2678cd6..49a12bf 100644 --- a/proxy/run.go +++ b/proxy/run.go @@ -54,6 +54,7 @@ type runConfig struct { Bind string `yaml:"bind" help:"Address to use to bind server" short:"a" default:":9042" env:"BIND"` Config *os.File `yaml:"-" help:"YAML configuration file" short:"f" env:"CONFIG_FILE"` // Not available in the configuration file Debug bool `yaml:"debug" help:"Show debug logging" default:"false" env:"DEBUG"` + EnableTracing bool `yaml:"enable-tracing" help:"Enable tracing of CQL requests" default:"false" env:"ENABLE_TRACING"` HealthCheck bool `yaml:"health-check" help:"Enable liveness and readiness checks" default:"false" env:"HEALTH_CHECK"` HttpBind string `yaml:"http-bind" help:"Address to use to bind HTTP server used for health checks" default:":8000" env:"HTTP_BIND"` HeartbeatInterval time.Duration `yaml:"heartbeat-interval" help:"Interval between performing heartbeats to the cluster" default:"30s" env:"HEARTBEAT_INTERVAL"` @@ -225,6 +226,7 @@ func Run(ctx context.Context, args []string) int { IdempotentGraph: cfg.IdempotentGraph, UnsupportedWriteConsistencies: cfg.UnsupportedWriteConsistencies, UnsupportedWriteConsistencyOverride: cfg.UnsupportedWriteConsistencyOverride, + EnableTracing: cfg.EnableTracing, }) cfg.Bind = maybeAddPort(cfg.Bind, "9042")