Skip to content
Merged
228 changes: 207 additions & 21 deletions ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -6037,7 +6037,7 @@ func RunActionAI(resp http.ResponseWriter, request *http.Request) {
}

// Indicates to output an action, and the input data could be a large blob
if len(input.Query) > 300 && !strings.Contains(input.OutputFormat, "action") && !strings.Contains(input.OutputFormat, "formatting") {
if len(input.Query) > 4000 && !strings.Contains(input.OutputFormat, "action") && !strings.Contains(input.OutputFormat, "formatting") {
resp.WriteHeader(400)
resp.Write([]byte(`{"success": false, "reason": "Max input length exceeded."}`))
return
Expand Down Expand Up @@ -6338,7 +6338,7 @@ func getWorkflowSuggestionAIResponse(ctx context.Context, resp http.ResponseWrit
func getSupportSuggestionAIResponse(ctx context.Context, resp http.ResponseWriter, user User, org Org, outputFormat string, input QueryInput) {
log.Printf("[INFO] Getting support suggestion for query: %s for org: %s", input.Query, org.Id)
// reply := runSupportRequest(ctx, input)
reply, threadId, err := runSupportLLMAssistant(ctx, input)
reply, threadId, err := runSupportLLMAssistant(ctx, input, user)
if err != nil {
log.Printf("[ERROR] Failed to run support LLM assistant: %s", err)
resp.WriteHeader(501)
Expand Down Expand Up @@ -10758,7 +10758,7 @@ func HandleEditWorkflowWithLLM(resp http.ResponseWriter, request *http.Request)
resp.Write(workflowJson)
}

func runSupportLLMAssistant(ctx context.Context, input QueryInput) (string, string, error) {
func runSupportLLMAssistant(ctx context.Context, input QueryInput, user User) (string, string, error) {

apiKey := os.Getenv("OPENAI_API_KEY")
if apiKey == "" || assistantId == "" || docsVectorStoreID == "" {
Expand All @@ -10778,36 +10778,29 @@ func runSupportLLMAssistant(ctx context.Context, input QueryInput) (string, stri
isValidThread := false

if strings.TrimSpace(input.ThreadId) != "" {
log.Printf("[DEBUG] Checking existing thread for the org %s", input.OrgId)
cacheKey := fmt.Sprintf("support_assistant_thread_%s", input.ThreadId)
cachedData, err := GetCache(ctx, cacheKey)
if err != nil {
log.Printf("[WARNING] Failed to get cache for thread %s: %s", threadID, err)
}

if cachedData != nil {
if err != nil {
// Thread not found in cache - will create new thread
} else if cachedData != nil {
orgId := ""
if byteSlice, ok := cachedData.([]byte); ok {
orgId = string(byteSlice)
}

if len(orgId) > 0 && orgId == input.OrgId {
log.Printf("[INFO] Found existing thread %s for org %s", input.ThreadId, input.OrgId)
threadID = input.ThreadId
isValidThread = true
value := []byte(input.OrgId)
// Refresh the cache TTL
err = SetCache(ctx, cacheKey, value, 1440)
if err != nil {
log.Printf("[WARNING] Failed to refresh cache for thread %s: %s", threadID, err)
if len(orgId) > 0 {
if orgId == input.OrgId {
threadID = input.ThreadId
isValidThread = true
} else {
return "", "", errors.New("thread belongs to different organization")
}

}
}
}

if isValidThread {
log.Printf("[DEBUG] Adding new message to existing thread %s", threadID)
_, err := client.CreateMessage(
ctx,
threadID,
Expand Down Expand Up @@ -10844,8 +10837,7 @@ func runSupportLLMAssistant(ctx context.Context, input QueryInput) (string, stri
cacheKey := fmt.Sprintf("support_assistant_thread_%s", threadID)
value := []byte(input.OrgId)

// Cache the thread ID for future use
err = SetCache(ctx, cacheKey, value, 1440)
err = SetCache(ctx, cacheKey, value, 86400)
if err != nil {
log.Printf("[WARNING] Failed to set cache for thread %s: %s", threadID, err)
}
Expand Down Expand Up @@ -10962,3 +10954,197 @@ Based on these rules and the provided documents, please answer the question:`
}
}
}

func getSupportThreadConversation(ctx context.Context, threadID string, user User) (ThreadConversationResponse, error) {
response := ThreadConversationResponse{
Success: false,
ThreadID: threadID,
Messages: []ConversationMessage{},
}

threadOrgID := ""
cacheKey := fmt.Sprintf("support_assistant_thread_%s", threadID)

if user.SupportAccess {
cachedData, err := GetCache(ctx, cacheKey)
if err == nil && cachedData != nil {
if byteSlice, ok := cachedData.([]byte); ok {
threadOrgID = string(byteSlice)
}
}
response.ThreadOrgID = threadOrgID
if user.ActiveOrg.Id == threadOrgID {
response.IsActiveOrg = true
}
} else {
cachedData, err := GetCache(ctx, cacheKey)
if err != nil || cachedData == nil {
log.Printf("[WARNING] Thread %s not found for user %s", threadID, user.Username)
return response, errors.New("thread not found or access denied")
}

byteSlice, ok := cachedData.([]byte)
if !ok {
log.Printf("[ERROR] Invalid cache data for thread %s", threadID)
return response, errors.New("thread not found or access denied")
}
threadOrgID = string(byteSlice)

userInOrg := false
for _, orgID := range user.Orgs {
if orgID == threadOrgID {
userInOrg = true
break
}
}

if !userInOrg {
log.Printf("[WARNING] User %s unauthorized for thread %s (org: %s)", user.Username, threadID, threadOrgID)
return response, errors.New("unauthorized: user not member of thread organization")
}

response.ThreadOrgID = threadOrgID
if user.ActiveOrg.Id == threadOrgID {
response.IsActiveOrg = true
}
}

apiKey := os.Getenv("AI_API_KEY")
if apiKey == "" {
apiKey = os.Getenv("OPENAI_API_KEY")
}
if apiKey == "" {
return response, errors.New("OPENAI_API_KEY must be set")
}

config := openai.DefaultConfig(apiKey)
config.AssistantVersion = "v2"
client := openai.NewClientWithConfig(config)

limit := 100
order := "asc"
messages, err := client.ListMessage(ctx, threadID, &limit, &order, nil, nil, nil)
if err != nil {
log.Printf("[ERROR] Failed to get messages for thread %s: %s", threadID, err)
return response, fmt.Errorf("failed to retrieve thread messages: %w", err)
}

conversationMessages := make([]ConversationMessage, 0, len(messages.Messages))
for _, message := range messages.Messages {
if len(message.Content) > 0 && message.Content[0].Type == "text" && message.Content[0].Text != nil {
cleanContent := message.Content[0].Text.Value
re := regexp.MustCompile(`【.*?】`)
cleanContent = re.ReplaceAllString(cleanContent, "")

conversationMessages = append(conversationMessages, ConversationMessage{
Role: string(message.Role),
Content: cleanContent,
Timestamp: time.Unix(int64(message.CreatedAt), 0),
})
}
}

response.Success = true
response.Messages = conversationMessages
return response, nil
}

func HandleGetSupportThreadConversation(resp http.ResponseWriter, request *http.Request) {
cors := HandleCors(resp, request)
if cors {
return
}

ctx := GetContext(request)
user, err := HandleApiAuthentication(resp, request)
if err != nil {
log.Printf("[AUDIT] Api authentication failed in get support thread conversation: %s", err)
resp.WriteHeader(401)
resp.Write([]byte(`{"success": false, "message": "Authentication failed"}`))
return
}

body, err := ioutil.ReadAll(request.Body)
if err != nil {
log.Printf("[WARNING] Failed to read request body in get support thread conversation: %s", err)
resp.WriteHeader(400)
resp.Write([]byte(`{"success": false, "message": "Failed to read request body"}`))
return
}

var threadRequest ThreadAccessRequest
err = json.Unmarshal(body, &threadRequest)
if err != nil {
log.Printf("[WARNING] Failed to unmarshal thread request in get support thread conversation: %s", err)
resp.WriteHeader(400)
resp.Write([]byte(`{"success": false, "message": "Invalid request format"}`))
return
}

if strings.TrimSpace(threadRequest.ThreadID) == "" {
resp.WriteHeader(400)
resp.Write([]byte(`{"success": false, "message": "Thread ID is required"}`))
return
}

log.Printf("[INFO] Getting thread conversation for thread %s by user %s (%s)", threadRequest.ThreadID, user.Username, user.Id)

response, err := getSupportThreadConversation(ctx, threadRequest.ThreadID, user)
if err != nil {
log.Printf("[WARNING] Failed to get thread conversation for thread %s by user %s: %s", threadRequest.ThreadID, user.Username, err)

output, marshalErr := json.Marshal(response)
if marshalErr != nil {
log.Printf("[ERROR] Failed to marshal error response: %s", marshalErr)
resp.WriteHeader(500)
resp.Write([]byte(`{"success": false, "message": "Internal server error"}`))
return
}

if strings.Contains(err.Error(), "unauthorized") || strings.Contains(err.Error(), "access denied") {
resp.WriteHeader(403)
} else if strings.Contains(err.Error(), "not found") {
resp.WriteHeader(404)
} else {
resp.WriteHeader(500)
}

resp.Write(output)
return
}

output, err := json.Marshal(response)
if err != nil {
log.Printf("[ERROR] Failed to marshal response for thread %s: %s", threadRequest.ThreadID, err)
resp.WriteHeader(500)
resp.Write([]byte(`{"success": false, "message": "Failed to marshal response"}`))
return
}

log.Printf("[INFO] Successfully retrieved %d messages for thread %s for user %s", len(response.Messages), threadRequest.ThreadID, user.Username)
resp.WriteHeader(200)
resp.Write(output)
}

func validateChatContext(ctx context.Context, threadID string, user User) error {
if user.SupportAccess {
return nil
}

cacheKey := fmt.Sprintf("support_assistant_thread_%s", threadID)
cachedData, err := GetCache(ctx, cacheKey)
if err != nil {
return errors.New("thread not found")
}

if cachedData != nil {
if byteSlice, ok := cachedData.([]byte); ok {
threadOrgID := string(byteSlice)
if threadOrgID != user.ActiveOrg.Id {
return fmt.Errorf("cannot send message: thread belongs to different organization. Please switch to the correct organization first")
}
}
}

return nil
}
19 changes: 19 additions & 0 deletions structs.go
Original file line number Diff line number Diff line change
Expand Up @@ -4747,3 +4747,22 @@ type AuditLogCollector struct {
StopChan chan bool
mu sync.Mutex
}

// Thread conversation access control structs
type ThreadAccessRequest struct {
ThreadID string `json:"thread_id"`
}

type ThreadConversationResponse struct {
Success bool `json:"success"`
ThreadID string `json:"thread_id"`
ThreadOrgID string `json:"thread_org_id"` // Org where thread lives (for switching orgs)
Messages []ConversationMessage `json:"messages"`
IsActiveOrg bool `json:"is_active_org"` // Whether this is user's active org
}

type ConversationMessage struct {
Role string `json:"role"` // "user" or "assistant"
Content string `json:"content"`
Timestamp time.Time `json:"timestamp"`
}
Loading