diff --git a/api/handler/internal_service_proxy.go b/api/handler/internal_service_proxy.go index 80693521e..f8520ceeb 100644 --- a/api/handler/internal_service_proxy.go +++ b/api/handler/internal_service_proxy.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" "opencsg.com/csghub-server/builder/proxy" + "opencsg.com/csghub-server/common/utils/trace" ) type InternalServiceProxyHandler struct { @@ -25,6 +26,9 @@ func (h *InternalServiceProxyHandler) Proxy(ctx *gin.Context) { // Log the request URL and header slog.Debug("http request", slog.Any("request", ctx.Request.URL), slog.Any("header", ctx.Request.Header)) + // Propagate trace ID to backend service + trace.PropagateTrace(ctx.Request.Context(), ctx.Request.Header) + // Serve the request using the router h.rp.ServeHTTP(ctx.Writer, ctx.Request, "", "") } @@ -37,6 +41,10 @@ func (h *InternalServiceProxyHandler) Proxy(ctx *gin.Context) { func (h *InternalServiceProxyHandler) ProxyToApi(api string, originParams ...string) gin.HandlerFunc { return func(ctx *gin.Context) { slog.Info("proxy user request", slog.Any("request", ctx.Request.URL), slog.Any("header", ctx.Request.Header)) + + // Propagate trace ID to backend service + trace.PropagateTrace(ctx.Request.Context(), ctx.Request.Header) + finalApi := api if len(originParams) > 0 { var params []any diff --git a/api/middleware/log.go b/api/middleware/log.go index 7f4ab09bb..dc6ea108b 100644 --- a/api/middleware/log.go +++ b/api/middleware/log.go @@ -30,8 +30,6 @@ func Log() gin.HandlerFunc { slog.Any("auth_type", httpbase.GetAuthType(ctx)), slog.String("url", ctx.Request.URL.RequestURI()), slog.String("full_path", ctx.FullPath()), - slog.String("req_header_range", ctx.GetHeader("Range")), - slog.String("rsp_content_range", ctx.Writer.Header().Get("Content-Range")), ) } } diff --git a/builder/proxy/reverse_proxy.go b/builder/proxy/reverse_proxy.go index 6b5281693..b91af2bb4 100644 --- a/builder/proxy/reverse_proxy.go +++ b/builder/proxy/reverse_proxy.go @@ -7,6 +7,7 @@ import ( "net/url" "github.com/openai/openai-go/v3" + "opencsg.com/csghub-server/common/utils/trace" ) type ReverseProxy interface { @@ -74,6 +75,9 @@ func (rp *reverseProxyImpl) ServeHTTP(w http.ResponseWriter, r *http.Request, ap resp.Header.Del("Access-Control-Allow-Headers") resp.Header.Del("Access-Control-Allow-Methods") resp.Header.Del("Access-Control-Allow-Origin") + // remove duplicate X-Request-Id header from downstream response + // because it is already set by the gateway middleware + resp.Header.Del(trace.HeaderRequestID) return nil } diff --git a/builder/rpc/http_client.go b/builder/rpc/http_client.go index 63b7e3575..263eb1d2f 100644 --- a/builder/rpc/http_client.go +++ b/builder/rpc/http_client.go @@ -19,9 +19,7 @@ import ( slogmulti "github.com/samber/slog-multi" "go.opentelemetry.io/contrib/bridges/otelslog" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" - "opencsg.com/csghub-server/api/httpbase" "opencsg.com/csghub-server/common/config" - "opencsg.com/csghub-server/common/errorx" ) type HttpDoer interface { @@ -112,7 +110,7 @@ func (c *HttpClient) Get(ctx context.Context, path string, outObj interface{}) e fullPath := fmt.Sprintf("%s%s", c.endpoint, path) req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullPath, nil) if err != nil { - return fmt.Errorf("failed to create request: %w", errorx.ErrInternalServerError) + return fmt.Errorf("failed to create request: %w", err) } for _, opt := range c.authOpts { opt.Set(req) @@ -125,17 +123,11 @@ func (c *HttpClient) Get(ctx context.Context, path string, outObj interface{}) e defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - var errResp httpbase.R - jsonErr := json.NewDecoder(resp.Body).Decode(&errResp) - if jsonErr == nil { - customErr := errorx.ParseError(errResp.Msg, errorx.ErrRemoteServiceFail, errResp.Context) - return customErr - } return fmt.Errorf("failed to get response, path:%s, status:%d", path, resp.StatusCode) } err = json.NewDecoder(resp.Body).Decode(outObj) if err != nil { - return fmt.Errorf("failed to decode resp body in HttpClient.Get, err:%w", errorx.ErrInternalServerError) + return fmt.Errorf("failed to decode resp body in HttpClient.Get, err:%w", err) } return nil } @@ -170,19 +162,13 @@ func (c *HttpClient) Post(ctx context.Context, path string, data interface{}, ou defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - var errResp httpbase.R - jsonErr := json.NewDecoder(resp.Body).Decode(&errResp) - if jsonErr == nil { - customErr := errorx.ParseError(errResp.Msg, errorx.ErrRemoteServiceFail, errResp.Context) - return customErr - } return fmt.Errorf("failed to get response, path:%s, status:%d", path, resp.StatusCode) } if outObj != nil { err = json.NewDecoder(resp.Body).Decode(outObj) if err != nil { - return fmt.Errorf("failed to decode resp body in HttpClient.Post, err:%w", errorx.ErrInternalServerError) + return fmt.Errorf("failed to decode resp body in HttpClient.Post, err:%w", err) } } @@ -192,11 +178,7 @@ func (c *HttpClient) Post(ctx context.Context, path string, data interface{}, ou func (c *HttpClient) Do(req *http.Request) (resp *http.Response, err error) { ctx := req.Context() fullPath := req.URL.String() - traceID, traceParent, _ := trace.GetOrGenTraceIDFromContext(ctx) - if traceParent != "" { - req.Header.Set(trace.HeaderTraceparent, traceParent) - } - + traceID := trace.PropagateTrace(ctx, req.Header) startTime := time.Now() retryTime := time.Now() err = retry.Do( diff --git a/common/utils/trace/trace.go b/common/utils/trace/trace.go index 323907b6e..887e374bc 100644 --- a/common/utils/trace/trace.go +++ b/common/utils/trace/trace.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "net/http" "strings" "github.com/gin-gonic/gin" @@ -189,3 +190,15 @@ func GetOrGenTraceIDFromContext(ctx context.Context) (traceID, traceParent strin return traceID, traceParent, true } + +// PropagateTrace propagates the trace ID and traceparent to the http header. +func PropagateTrace(ctx context.Context, header http.Header) string { + traceID, traceParent, _ := GetOrGenTraceIDFromContext(ctx) + if traceParent != "" { + header.Set(HeaderTraceparent, traceParent) + } + if traceID != "" { + header.Set(HeaderRequestID, traceID) + } + return traceID +}