Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,24 +33,10 @@ public ResponseEntity<RsData<ChatResponseDto>> sendMessage(@Valid @RequestBody C
}
}

@GetMapping("/history/{sessionId}")
public ResponseEntity<RsData<List<ChatConversation>>> getChatHistory(@PathVariable String sessionId) {
@GetMapping("/history/user/{userId}")
public ResponseEntity<RsData<List<ChatConversation>>> getUserChatHistory(@PathVariable Long userId) {
try {
List<ChatConversation> history = chatbotService.getChatHistory(sessionId);
return ResponseEntity.ok(RsData.successOf(history));
} catch (Exception e) {
log.error("채팅 기록 조회 중 오류 발생: ", e);
return ResponseEntity.internalServerError()
.body(RsData.failOf("서버 오류가 발생했습니다."));
}
}

@GetMapping("/history/user/{userId}/session/{sessionId}")
public ResponseEntity<RsData<List<ChatConversation>>> getUserChatHistory(
@PathVariable Long userId,
@PathVariable String sessionId) {
try {
List<ChatConversation> history = chatbotService.getUserChatHistory(userId, sessionId);
List<ChatConversation> history = chatbotService.getUserChatHistory(userId);
return ResponseEntity.ok(RsData.successOf(history));
} catch (Exception e) {
log.error("사용자 채팅 기록 조회 중 오류 발생: ", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ public class ChatRequestDto {
@NotBlank(message = "메시지는 필수입니다.")
private String message;

private String sessionId;

private Long userId;
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@
public class ChatResponseDto {

private String response;
private String sessionId;
private LocalDateTime timestamp;

public ChatResponseDto(String response, String sessionId) {
public ChatResponseDto(String response) {
this.response = response;
this.sessionId = sessionId;
this.timestamp = LocalDateTime.now();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ public class ChatConversation {
@Column(columnDefinition = "TEXT")
private String botResponse;

private String sessionId;

private LocalDateTime createdAt;

@PrePersist
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
@Repository
public interface ChatConversationRepository extends JpaRepository<ChatConversation, Long> {

List<ChatConversation> findBySessionIdOrderByCreatedAtAsc(String sessionId);

Page<ChatConversation> findByUserIdOrderByCreatedAtDesc(Long userId, Pageable pageable);

List<ChatConversation> findByUserIdAndSessionIdOrderByCreatedAtAsc(Long userId, String sessionId);
List<ChatConversation> findTop5ByUserIdOrderByCreatedAtDesc(Long userId);
}
101 changes: 35 additions & 66 deletions src/main/java/com/back/domain/chatbot/service/ChatbotService.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,24 @@
import com.back.domain.chatbot.dto.ChatResponseDto;
import com.back.domain.chatbot.entity.ChatConversation;
import com.back.domain.chatbot.repository.ChatConversationRepository;
import jakarta.annotation.PostConstruct;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.memory.InMemoryChatMemory;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.data.domain.Pageable;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.StreamUtils;

import jakarta.annotation.PostConstruct;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.Collections;
import java.util.List;

@Service
@RequiredArgsConstructor
Expand All @@ -33,8 +31,6 @@ public class ChatbotService {
private final ChatModel chatModel;
private final ChatConversationRepository chatConversationRepository;

// 세션별 메모리 관리 (Thread-Safe)
private final ConcurrentHashMap<String, InMemoryChatMemory> sessionMemories = new ConcurrentHashMap<>();

@Value("classpath:prompts/chatbot-system-prompt.txt")
private Resource systemPromptResource;
Expand Down Expand Up @@ -80,23 +76,24 @@ public void init() throws IOException {

@Transactional
public ChatResponseDto sendMessage(ChatRequestDto requestDto) {
String sessionId = ensureSessionId(requestDto.getSessionId());

try {
// 메시지 타입 감지
MessageType messageType = detectMessageType(requestDto.getMessage());

// 세션별 메모리 가져오기
InMemoryChatMemory chatMemory = getOrCreateSessionMemory(sessionId);
// 최근 대화 기록 조회 (최신 5개)
List<ChatConversation> recentChats =
chatConversationRepository.findTop5ByUserIdOrderByCreatedAtDesc(requestDto.getUserId());

// 대화 히스토리를 시간순으로 정렬 (오래된 것부터)
Collections.reverse(recentChats);

// 이전 대화 기록 로드
loadConversationHistory(sessionId, chatMemory);
// 대화 컨텍스트 생성
String conversationContext = buildConversationContext(recentChats);

// ChatClient 빌더 생성
var promptBuilder = chatClient.prompt()
.system(buildSystemMessage(messageType))
.user(buildUserMessage(requestDto.getMessage(), messageType))
.advisors(new MessageChatMemoryAdvisor(chatMemory));
.system(buildSystemMessage(messageType) + conversationContext)
.user(buildUserMessage(requestDto.getMessage(), messageType));

// RAG 기능은 향후 구현 예정 (Vector DB 설정 필요)

Expand All @@ -109,42 +106,31 @@ public ChatResponseDto sendMessage(ChatRequestDto requestDto) {
// 응답 후처리
response = postProcessResponse(response, messageType);

// 대화 저장
saveConversation(requestDto, response, sessionId);
// 대화 저장 (sessionId 없이)
saveConversation(requestDto, response);

return new ChatResponseDto(response, sessionId);
return new ChatResponseDto(response);

} catch (Exception e) {
log.error("채팅 응답 생성 중 오류 발생: ", e);
return handleError(sessionId, e);
return handleError(e);
}
}

private String ensureSessionId(String sessionId) {
return (sessionId == null || sessionId.isEmpty())
? UUID.randomUUID().toString()
: sessionId;
}

private InMemoryChatMemory getOrCreateSessionMemory(String sessionId) {
return sessionMemories.computeIfAbsent(
sessionId,
k -> new InMemoryChatMemory()
);
}
private String buildConversationContext(List<ChatConversation> recentChats) {
if (recentChats.isEmpty()) {
return "";
}

private void loadConversationHistory(String sessionId, InMemoryChatMemory chatMemory) {
List<ChatConversation> conversations =
chatConversationRepository.findBySessionIdOrderByCreatedAtAsc(sessionId);

// 최근 N개의 대화만 메모리에 로드
String sessionIdForMemory = sessionId;
conversations.stream()
.skip(Math.max(0, conversations.size() - maxConversationCount))
.forEach(conv -> {
chatMemory.add(sessionIdForMemory, new UserMessage(conv.getUserMessage()));
chatMemory.add(sessionIdForMemory, new AssistantMessage(conv.getBotResponse()));
});
StringBuilder context = new StringBuilder("\n\n【최근 대화 기록】\n");
for (ChatConversation chat : recentChats) {
context.append("사용자: ").append(chat.getUserMessage()).append("\n");
context.append("봇: ").append(chat.getBotResponse()).append("\n\n");
}
context.append("위 대화를 참고하여 자연스럽게 이어지는 답변을 해주세요.\n");

return context.toString();
}

private String buildSystemMessage(MessageType type) {
Expand Down Expand Up @@ -208,19 +194,18 @@ private String postProcessResponse(String response, MessageType type) {
return response;
}

private void saveConversation(ChatRequestDto requestDto, String response, String sessionId) {
private void saveConversation(ChatRequestDto requestDto, String response) {
ChatConversation conversation = ChatConversation.builder()
.userId(requestDto.getUserId())
.userMessage(requestDto.getMessage())
.botResponse(response)
.sessionId(sessionId)
.createdAt(LocalDateTime.now())
.build();

chatConversationRepository.save(conversation);
}

private ChatResponseDto handleError(String sessionId, Exception e) {
private ChatResponseDto handleError(Exception e) {
String errorMessage = "죄송합니다. 잠시 후 다시 시도해주세요.";

if (e.getMessage().contains("rate limit")) {
Expand All @@ -229,7 +214,7 @@ private ChatResponseDto handleError(String sessionId, Exception e) {
errorMessage = "응답 시간이 초과되었습니다. 다시 시도해주세요.";
}

return new ChatResponseDto(errorMessage, sessionId);
return new ChatResponseDto(errorMessage);
}

public enum MessageType {
Expand All @@ -254,25 +239,9 @@ private MessageType detectMessageType(String message) {
}

@Transactional(readOnly = true)
public List<ChatConversation> getChatHistory(String sessionId) {
return chatConversationRepository.findBySessionIdOrderByCreatedAtAsc(sessionId);
}

@Transactional(readOnly = true)
public List<ChatConversation> getUserChatHistory(Long userId, String sessionId) {
return chatConversationRepository.findByUserIdAndSessionIdOrderByCreatedAtAsc(userId, sessionId);
public List<ChatConversation> getUserChatHistory(Long userId) {
return chatConversationRepository.findByUserIdOrderByCreatedAtDesc(userId, Pageable.unpaged()).getContent();
}

// 정기적인 메모리 정리 (스케줄러로 호출)
public void cleanupInactiveSessions() {
long thirtyMinutesAgo = System.currentTimeMillis() - (30 * 60 * 1000);

sessionMemories.entrySet().removeIf(entry -> {
// 실제로는 마지막 사용 시간을 추적해야 함
return false;
});

log.info("세션 메모리 정리 완료. 현재 활성 세션: {}", sessionMemories.size());
}
}