diff --git a/ai.go b/ai.go index 4fd14c13..9cb52906 100644 --- a/ai.go +++ b/ai.go @@ -6037,7 +6037,7 @@ func RunActionAI(resp http.ResponseWriter, request *http.Request) { } // Indicates to output an action, and the input data could be a large blob - if len(input.Query) > 300 && !strings.Contains(input.OutputFormat, "action") && !strings.Contains(input.OutputFormat, "formatting") { + if len(input.Query) > 4000 && !strings.Contains(input.OutputFormat, "action") && !strings.Contains(input.OutputFormat, "formatting") { resp.WriteHeader(400) resp.Write([]byte(`{"success": false, "reason": "Max input length exceeded."}`)) return @@ -6338,7 +6338,7 @@ func getWorkflowSuggestionAIResponse(ctx context.Context, resp http.ResponseWrit func getSupportSuggestionAIResponse(ctx context.Context, resp http.ResponseWriter, user User, org Org, outputFormat string, input QueryInput) { log.Printf("[INFO] Getting support suggestion for query: %s for org: %s", input.Query, org.Id) // reply := runSupportRequest(ctx, input) - reply, threadId, err := runSupportLLMAssistant(ctx, input) + reply, threadId, err := runSupportLLMAssistant(ctx, input, user) if err != nil { log.Printf("[ERROR] Failed to run support LLM assistant: %s", err) resp.WriteHeader(501) @@ -10758,7 +10758,7 @@ func HandleEditWorkflowWithLLM(resp http.ResponseWriter, request *http.Request) resp.Write(workflowJson) } -func runSupportLLMAssistant(ctx context.Context, input QueryInput) (string, string, error) { +func runSupportLLMAssistant(ctx context.Context, input QueryInput, user User) (string, string, error) { apiKey := os.Getenv("OPENAI_API_KEY") if apiKey == "" || assistantId == "" || docsVectorStoreID == "" { @@ -10778,36 +10778,29 @@ func runSupportLLMAssistant(ctx context.Context, input QueryInput) (string, stri isValidThread := false if strings.TrimSpace(input.ThreadId) != "" { - log.Printf("[DEBUG] Checking existing thread for the org %s", input.OrgId) cacheKey := fmt.Sprintf("support_assistant_thread_%s", input.ThreadId) cachedData, err := GetCache(ctx, cacheKey) - if err != nil { - log.Printf("[WARNING] Failed to get cache for thread %s: %s", threadID, err) - } - if cachedData != nil { + if err != nil { + // Thread not found in cache - will create new thread + } else if cachedData != nil { orgId := "" if byteSlice, ok := cachedData.([]byte); ok { orgId = string(byteSlice) } - if len(orgId) > 0 && orgId == input.OrgId { - log.Printf("[INFO] Found existing thread %s for org %s", input.ThreadId, input.OrgId) - threadID = input.ThreadId - isValidThread = true - value := []byte(input.OrgId) - // Refresh the cache TTL - err = SetCache(ctx, cacheKey, value, 1440) - if err != nil { - log.Printf("[WARNING] Failed to refresh cache for thread %s: %s", threadID, err) + if len(orgId) > 0 { + if orgId == input.OrgId { + threadID = input.ThreadId + isValidThread = true + } else { + return "", "", errors.New("thread belongs to different organization") } - } } } if isValidThread { - log.Printf("[DEBUG] Adding new message to existing thread %s", threadID) _, err := client.CreateMessage( ctx, threadID, @@ -10844,8 +10837,7 @@ func runSupportLLMAssistant(ctx context.Context, input QueryInput) (string, stri cacheKey := fmt.Sprintf("support_assistant_thread_%s", threadID) value := []byte(input.OrgId) - // Cache the thread ID for future use - err = SetCache(ctx, cacheKey, value, 1440) + err = SetCache(ctx, cacheKey, value, 86400) if err != nil { log.Printf("[WARNING] Failed to set cache for thread %s: %s", threadID, err) } @@ -10962,3 +10954,197 @@ Based on these rules and the provided documents, please answer the question:` } } } + +func getSupportThreadConversation(ctx context.Context, threadID string, user User) (ThreadConversationResponse, error) { + response := ThreadConversationResponse{ + Success: false, + ThreadID: threadID, + Messages: []ConversationMessage{}, + } + + threadOrgID := "" + cacheKey := fmt.Sprintf("support_assistant_thread_%s", threadID) + + if user.SupportAccess { + cachedData, err := GetCache(ctx, cacheKey) + if err == nil && cachedData != nil { + if byteSlice, ok := cachedData.([]byte); ok { + threadOrgID = string(byteSlice) + } + } + response.ThreadOrgID = threadOrgID + if user.ActiveOrg.Id == threadOrgID { + response.IsActiveOrg = true + } + } else { + cachedData, err := GetCache(ctx, cacheKey) + if err != nil || cachedData == nil { + log.Printf("[WARNING] Thread %s not found for user %s", threadID, user.Username) + return response, errors.New("thread not found or access denied") + } + + byteSlice, ok := cachedData.([]byte) + if !ok { + log.Printf("[ERROR] Invalid cache data for thread %s", threadID) + return response, errors.New("thread not found or access denied") + } + threadOrgID = string(byteSlice) + + userInOrg := false + for _, orgID := range user.Orgs { + if orgID == threadOrgID { + userInOrg = true + break + } + } + + if !userInOrg { + log.Printf("[WARNING] User %s unauthorized for thread %s (org: %s)", user.Username, threadID, threadOrgID) + return response, errors.New("unauthorized: user not member of thread organization") + } + + response.ThreadOrgID = threadOrgID + if user.ActiveOrg.Id == threadOrgID { + response.IsActiveOrg = true + } + } + + apiKey := os.Getenv("AI_API_KEY") + if apiKey == "" { + apiKey = os.Getenv("OPENAI_API_KEY") + } + if apiKey == "" { + return response, errors.New("OPENAI_API_KEY must be set") + } + + config := openai.DefaultConfig(apiKey) + config.AssistantVersion = "v2" + client := openai.NewClientWithConfig(config) + + limit := 100 + order := "asc" + messages, err := client.ListMessage(ctx, threadID, &limit, &order, nil, nil, nil) + if err != nil { + log.Printf("[ERROR] Failed to get messages for thread %s: %s", threadID, err) + return response, fmt.Errorf("failed to retrieve thread messages: %w", err) + } + + conversationMessages := make([]ConversationMessage, 0, len(messages.Messages)) + for _, message := range messages.Messages { + if len(message.Content) > 0 && message.Content[0].Type == "text" && message.Content[0].Text != nil { + cleanContent := message.Content[0].Text.Value + re := regexp.MustCompile(`【.*?】`) + cleanContent = re.ReplaceAllString(cleanContent, "") + + conversationMessages = append(conversationMessages, ConversationMessage{ + Role: string(message.Role), + Content: cleanContent, + Timestamp: time.Unix(int64(message.CreatedAt), 0), + }) + } + } + + response.Success = true + response.Messages = conversationMessages + return response, nil +} + +func HandleGetSupportThreadConversation(resp http.ResponseWriter, request *http.Request) { + cors := HandleCors(resp, request) + if cors { + return + } + + ctx := GetContext(request) + user, err := HandleApiAuthentication(resp, request) + if err != nil { + log.Printf("[AUDIT] Api authentication failed in get support thread conversation: %s", err) + resp.WriteHeader(401) + resp.Write([]byte(`{"success": false, "message": "Authentication failed"}`)) + return + } + + body, err := ioutil.ReadAll(request.Body) + if err != nil { + log.Printf("[WARNING] Failed to read request body in get support thread conversation: %s", err) + resp.WriteHeader(400) + resp.Write([]byte(`{"success": false, "message": "Failed to read request body"}`)) + return + } + + var threadRequest ThreadAccessRequest + err = json.Unmarshal(body, &threadRequest) + if err != nil { + log.Printf("[WARNING] Failed to unmarshal thread request in get support thread conversation: %s", err) + resp.WriteHeader(400) + resp.Write([]byte(`{"success": false, "message": "Invalid request format"}`)) + return + } + + if strings.TrimSpace(threadRequest.ThreadID) == "" { + resp.WriteHeader(400) + resp.Write([]byte(`{"success": false, "message": "Thread ID is required"}`)) + return + } + + log.Printf("[INFO] Getting thread conversation for thread %s by user %s (%s)", threadRequest.ThreadID, user.Username, user.Id) + + response, err := getSupportThreadConversation(ctx, threadRequest.ThreadID, user) + if err != nil { + log.Printf("[WARNING] Failed to get thread conversation for thread %s by user %s: %s", threadRequest.ThreadID, user.Username, err) + + output, marshalErr := json.Marshal(response) + if marshalErr != nil { + log.Printf("[ERROR] Failed to marshal error response: %s", marshalErr) + resp.WriteHeader(500) + resp.Write([]byte(`{"success": false, "message": "Internal server error"}`)) + return + } + + if strings.Contains(err.Error(), "unauthorized") || strings.Contains(err.Error(), "access denied") { + resp.WriteHeader(403) + } else if strings.Contains(err.Error(), "not found") { + resp.WriteHeader(404) + } else { + resp.WriteHeader(500) + } + + resp.Write(output) + return + } + + output, err := json.Marshal(response) + if err != nil { + log.Printf("[ERROR] Failed to marshal response for thread %s: %s", threadRequest.ThreadID, err) + resp.WriteHeader(500) + resp.Write([]byte(`{"success": false, "message": "Failed to marshal response"}`)) + return + } + + log.Printf("[INFO] Successfully retrieved %d messages for thread %s for user %s", len(response.Messages), threadRequest.ThreadID, user.Username) + resp.WriteHeader(200) + resp.Write(output) +} + +func validateChatContext(ctx context.Context, threadID string, user User) error { + if user.SupportAccess { + return nil + } + + cacheKey := fmt.Sprintf("support_assistant_thread_%s", threadID) + cachedData, err := GetCache(ctx, cacheKey) + if err != nil { + return errors.New("thread not found") + } + + if cachedData != nil { + if byteSlice, ok := cachedData.([]byte); ok { + threadOrgID := string(byteSlice) + if threadOrgID != user.ActiveOrg.Id { + return fmt.Errorf("cannot send message: thread belongs to different organization. Please switch to the correct organization first") + } + } + } + + return nil +} diff --git a/structs.go b/structs.go index 311bd240..c9f98cf1 100755 --- a/structs.go +++ b/structs.go @@ -4747,3 +4747,22 @@ type AuditLogCollector struct { StopChan chan bool mu sync.Mutex } + +// Thread conversation access control structs +type ThreadAccessRequest struct { + ThreadID string `json:"thread_id"` +} + +type ThreadConversationResponse struct { + Success bool `json:"success"` + ThreadID string `json:"thread_id"` + ThreadOrgID string `json:"thread_org_id"` // Org where thread lives (for switching orgs) + Messages []ConversationMessage `json:"messages"` + IsActiveOrg bool `json:"is_active_org"` // Whether this is user's active org +} + +type ConversationMessage struct { + Role string `json:"role"` // "user" or "assistant" + Content string `json:"content"` + Timestamp time.Time `json:"timestamp"` +}