Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions api/handler/internal_service_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/gin-gonic/gin"
"opencsg.com/csghub-server/builder/proxy"
"opencsg.com/csghub-server/common/utils/trace"
)

type InternalServiceProxyHandler struct {
Expand All @@ -25,6 +26,9 @@ func (h *InternalServiceProxyHandler) Proxy(ctx *gin.Context) {
// Log the request URL and header
slog.Debug("http request", slog.Any("request", ctx.Request.URL), slog.Any("header", ctx.Request.Header))

// Propagate trace ID to backend service
trace.PropagateTrace(ctx.Request.Context(), ctx.Request.Header)

// Serve the request using the router
h.rp.ServeHTTP(ctx.Writer, ctx.Request, "", "")
}
Expand All @@ -37,6 +41,10 @@ func (h *InternalServiceProxyHandler) Proxy(ctx *gin.Context) {
func (h *InternalServiceProxyHandler) ProxyToApi(api string, originParams ...string) gin.HandlerFunc {
return func(ctx *gin.Context) {
slog.Info("proxy user request", slog.Any("request", ctx.Request.URL), slog.Any("header", ctx.Request.Header))

// Propagate trace ID to backend service
trace.PropagateTrace(ctx.Request.Context(), ctx.Request.Header)

finalApi := api
if len(originParams) > 0 {
var params []any
Expand Down
2 changes: 0 additions & 2 deletions api/middleware/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ func Log() gin.HandlerFunc {
slog.Any("auth_type", httpbase.GetAuthType(ctx)),
slog.String("url", ctx.Request.URL.RequestURI()),
slog.String("full_path", ctx.FullPath()),
slog.String("req_header_range", ctx.GetHeader("Range")),
slog.String("rsp_content_range", ctx.Writer.Header().Get("Content-Range")),
)
}
}
4 changes: 4 additions & 0 deletions builder/proxy/reverse_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/url"

"github.com/openai/openai-go/v3"
"opencsg.com/csghub-server/common/utils/trace"
)

type ReverseProxy interface {
Expand Down Expand Up @@ -74,6 +75,9 @@ func (rp *reverseProxyImpl) ServeHTTP(w http.ResponseWriter, r *http.Request, ap
resp.Header.Del("Access-Control-Allow-Headers")
resp.Header.Del("Access-Control-Allow-Methods")
resp.Header.Del("Access-Control-Allow-Origin")
// remove duplicate X-Request-Id header from downstream response
// because it is already set by the gateway middleware
resp.Header.Del(trace.HeaderRequestID)

return nil
}
Expand Down
26 changes: 4 additions & 22 deletions builder/rpc/http_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ import (
slogmulti "github.com/samber/slog-multi"
"go.opentelemetry.io/contrib/bridges/otelslog"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"opencsg.com/csghub-server/api/httpbase"
"opencsg.com/csghub-server/common/config"
"opencsg.com/csghub-server/common/errorx"
)

type HttpDoer interface {
Expand Down Expand Up @@ -112,7 +110,7 @@ func (c *HttpClient) Get(ctx context.Context, path string, outObj interface{}) e
fullPath := fmt.Sprintf("%s%s", c.endpoint, path)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullPath, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", errorx.ErrInternalServerError)
return fmt.Errorf("failed to create request: %w", err)
}
for _, opt := range c.authOpts {
opt.Set(req)
Expand All @@ -125,17 +123,11 @@ func (c *HttpClient) Get(ctx context.Context, path string, outObj interface{}) e
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
var errResp httpbase.R
jsonErr := json.NewDecoder(resp.Body).Decode(&errResp)
if jsonErr == nil {
customErr := errorx.ParseError(errResp.Msg, errorx.ErrRemoteServiceFail, errResp.Context)
return customErr
}
return fmt.Errorf("failed to get response, path:%s, status:%d", path, resp.StatusCode)
}
err = json.NewDecoder(resp.Body).Decode(outObj)
if err != nil {
return fmt.Errorf("failed to decode resp body in HttpClient.Get, err:%w", errorx.ErrInternalServerError)
return fmt.Errorf("failed to decode resp body in HttpClient.Get, err:%w", err)
}
return nil
}
Expand Down Expand Up @@ -170,19 +162,13 @@ func (c *HttpClient) Post(ctx context.Context, path string, data interface{}, ou
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
var errResp httpbase.R
jsonErr := json.NewDecoder(resp.Body).Decode(&errResp)
if jsonErr == nil {
customErr := errorx.ParseError(errResp.Msg, errorx.ErrRemoteServiceFail, errResp.Context)
return customErr
}
return fmt.Errorf("failed to get response, path:%s, status:%d", path, resp.StatusCode)
}

if outObj != nil {
err = json.NewDecoder(resp.Body).Decode(outObj)
if err != nil {
return fmt.Errorf("failed to decode resp body in HttpClient.Post, err:%w", errorx.ErrInternalServerError)
return fmt.Errorf("failed to decode resp body in HttpClient.Post, err:%w", err)
}
}

Expand All @@ -192,11 +178,7 @@ func (c *HttpClient) Post(ctx context.Context, path string, data interface{}, ou
func (c *HttpClient) Do(req *http.Request) (resp *http.Response, err error) {
ctx := req.Context()
fullPath := req.URL.String()
traceID, traceParent, _ := trace.GetOrGenTraceIDFromContext(ctx)
if traceParent != "" {
req.Header.Set(trace.HeaderTraceparent, traceParent)
}

traceID := trace.PropagateTrace(ctx, req.Header)
startTime := time.Now()
retryTime := time.Now()
err = retry.Do(
Expand Down
13 changes: 13 additions & 0 deletions common/utils/trace/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/rand"
"encoding/hex"
"fmt"
"net/http"
"strings"

"github.com/gin-gonic/gin"
Expand Down Expand Up @@ -189,3 +190,15 @@ func GetOrGenTraceIDFromContext(ctx context.Context) (traceID, traceParent strin

return traceID, traceParent, true
}

// PropagateTrace propagates the trace ID and traceparent to the http header.
func PropagateTrace(ctx context.Context, header http.Header) string {
traceID, traceParent, _ := GetOrGenTraceIDFromContext(ctx)
if traceParent != "" {
header.Set(HeaderTraceparent, traceParent)
}
if traceID != "" {
header.Set(HeaderRequestID, traceID)
}
return traceID
}
Loading