diff --git a/pkg/object/httpserver/mux.go b/pkg/object/httpserver/mux.go index ef107b19eb..bc356de143 100644 --- a/pkg/object/httpserver/mux.go +++ b/pkg/object/httpserver/mux.go @@ -115,6 +115,9 @@ var ( forbidden = &cachedRoute{code: http.StatusForbidden} methodNotAllowed = &cachedRoute{code: http.StatusMethodNotAllowed} badRequest = &cachedRoute{code: http.StatusBadRequest} + + accessLogVarReg = regexp.MustCompile(`\{\{([a-zA-z]*)\}\}`) + accessLogEscapeReg = regexp.MustCompile(`(\[|\])`) ) func (mi *muxInstance) getRouteFromCache(req *httpprot.Request) *cachedRoute { @@ -699,10 +702,8 @@ func newAccessLogFormatter(format string) *accessLogFormatter { if format == "" { format = defaultAccessLogFormat } - varReg := regexp.MustCompile(`\{\{([a-zA-z]*)\}\}`) - expr := varReg.ReplaceAllString(format, "{{.$1}}") - escapeReg := regexp.MustCompile(`(\[|\])`) - expr = escapeReg.ReplaceAllString(expr, "{{`$1`}}") + expr := accessLogVarReg.ReplaceAllString(format, "{{.$1}}") + expr = accessLogEscapeReg.ReplaceAllString(expr, "{{`$1`}}") tpl := template.Must(template.New("").Parse(expr)) return &accessLogFormatter{template: tpl} } diff --git a/pkg/object/httpserver/mux_test.go b/pkg/object/httpserver/mux_test.go index 422f2869a3..71557cd167 100644 --- a/pkg/object/httpserver/mux_test.go +++ b/pkg/object/httpserver/mux_test.go @@ -667,3 +667,12 @@ func TestPrintHeader(t *testing.T) { t.Fail() } } + +func BenchmarkNewAccessLogFormatter(b *testing.B) { + format := "{{Method}} {{URI}} [{{ReqSize}}]" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = newAccessLogFormatter(format) + } +} diff --git a/pkg/object/mqttproxy/broker.go b/pkg/object/mqttproxy/broker.go index 525551c424..37981cf357 100644 --- a/pkg/object/mqttproxy/broker.go +++ b/pkg/object/mqttproxy/broker.go @@ -25,6 +25,7 @@ import ( "math/rand" "net" "net/http" + "sort" "strconv" "strings" "sync" @@ -62,6 +63,9 @@ type ( connectionLimiter *Limiter memberURL func(string, string) (map[string]string, error) + sessionCache map[string]*SessionInfo + sortedSessionKeys []string + // done is the channel for shutdowning this proxy. done chan struct{} closeFlag int32 @@ -202,6 +206,22 @@ func (b *Broker) connectWatcher() { logger.SpanErrorf(nil, "get all session prefix failed, %v", err) } + b.Lock() + b.sessionCache = make(map[string]*SessionInfo) + b.sortedSessionKeys = make([]string, 0, len(sessions)) + for k, v := range sessions { + session := &Session{} + if err := session.decode(v); err != nil { + logger.Warnf("ignored decode session info %s failed: %s", v, err) + continue + } + clientID := strings.TrimPrefix(k, sessionStoreKey("")) + b.sessionCache[clientID] = session.info + b.sortedSessionKeys = append(b.sortedSessionKeys, clientID) + } + sort.Strings(b.sortedSessionKeys) + b.Unlock() + if b.spec.BrokerMode { // make sessions a watcher event watcherEvent := make(map[string]*string) @@ -245,6 +265,33 @@ func (b *Broker) watch(ch <-chan map[string]*string, closeFunc func()) { } func (b *Broker) processWatcherEvent(event map[string]*string, sync bool) { + b.Lock() + cacheChanged := false + for k, v := range event { + clientID := strings.TrimPrefix(k, sessionStoreKey("")) + if v != nil { + session := &Session{} + if err := session.decode(*v); err == nil { + b.sessionCache[clientID] = session.info + cacheChanged = true + } + } else { + if _, ok := b.sessionCache[clientID]; ok { + delete(b.sessionCache, clientID) + cacheChanged = true + } + } + } + if cacheChanged { + keys := make([]string, 0, len(b.sessionCache)) + for k := range b.sessionCache { + keys = append(keys, k) + } + sort.Strings(keys) + b.sortedSessionKeys = keys + } + b.Unlock() + sessMap := make(map[string]*SessionInfo) for k, v := range event { clientID := strings.TrimPrefix(k, sessionStoreKey("")) @@ -752,15 +799,9 @@ func (b *Broker) httpGetAllSessionHandler(w http.ResponseWriter, r *http.Request } } - allSession, err := b.sessMgr.store.getPrefix(sessionStoreKey(""), false) - logger.SpanDebugf(span, "httpGetAllSessionHandler current total %v sessions, query %v, topic %v", len(allSession), []int{page, pageSize}, topic) - if err != nil { - logger.SpanErrorf(span, "get all sessions with prefix %v failed, %v", sessionStoreKey(""), err) - api.HandleAPIError(w, r, http.StatusInternalServerError, fmt.Errorf("get all sessions failed, %v", err)) - return - } - - res := b.queryAllSessions(allSession, len(query) != 0, page, pageSize, topic) + b.RLock() + res := b.queryAllSessions(len(query) != 0, page, pageSize, topic) + b.RUnlock() jsonData, err := codectool.MarshalJSON(res) if err != nil { @@ -775,41 +816,62 @@ func (b *Broker) httpGetAllSessionHandler(w http.ResponseWriter, r *http.Request } } -func (b *Broker) queryAllSessions(allSession map[string]string, query bool, page, pageSize int, topic string) *HTTPSessions { +func (b *Broker) queryAllSessions(query bool, page, pageSize int, topic string) *HTTPSessions { res := &HTTPSessions{} if !query { - for k := range allSession { + for _, k := range b.sortedSessionKeys { httpSession := &HTTPSession{ - SessionID: strings.TrimPrefix(k, sessionStoreKey("")), + SessionID: k, } res.Sessions = append(res.Sessions, httpSession) } return res } - index := 0 start := page*pageSize - pageSize end := page * pageSize - for _, v := range allSession { - if index >= start && index < end { - session := &Session{} - session.info = &SessionInfo{} - session.decode(v) - for k := range session.info.Topics { + total := len(b.sortedSessionKeys) + + if start >= total { + return res + } + if end > total { + end = total + } + + candidates := b.sortedSessionKeys[start:end] + for _, clientID := range candidates { + info, ok := b.sessionCache[clientID] + if !ok { + continue + } + + // If topic is specified, we filter by it. + // Note: The logic here follows the original implementation where pagination + // happens BEFORE filtering. This means we take the page of sessions, + // and THEN check if they match the topic. + matched := false + matchedTopic := "" + + if topic == "" { + matched = true + } else { + for k := range info.Topics { if strings.Contains(k, topic) { - httpSession := &HTTPSession{ - SessionID: session.info.ClientID, - Topic: k, - } - res.Sessions = append(res.Sessions, httpSession) + matched = true + matchedTopic = k break } } } - if index > end { - break + + if matched { + httpSession := &HTTPSession{ + SessionID: info.ClientID, + Topic: matchedTopic, + } + res.Sessions = append(res.Sessions, httpSession) } - index++ } return res } diff --git a/pkg/object/mqttproxy/session.go b/pkg/object/mqttproxy/session.go index d5b64113b2..81982adcb2 100644 --- a/pkg/object/mqttproxy/session.go +++ b/pkg/object/mqttproxy/session.go @@ -296,16 +296,16 @@ func (s *Session) backgroundSessionTask() { select { case <-s.done: return - case <-ticker.C: + case t := <-ticker.C: s.store() - if time.Now().After(resendTime) { + if t.After(resendTime) { s.doResend() - resendTime = time.Now().Add(s.retryInterval) + resendTime = t.Add(s.retryInterval) + } + if t.After(debugLogTime) { + logger.SpanDebugf(nil, "session %v resend", s.info.ClientID) + debugLogTime = t.Add(time.Minute) } - } - if time.Now().After(debugLogTime) { - logger.SpanDebugf(nil, "session %v resend", s.info.ClientID) - debugLogTime = time.Now().Add(time.Minute) } } }