Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.back.koreaTravelGuide.common.config

import com.back.koreaTravelGuide.domain.ai.aiChat.tool.WeatherTool
import org.springframework.ai.chat.client.ChatClient
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor
import org.springframework.ai.chat.memory.ChatMemory
Expand Down Expand Up @@ -34,9 +35,11 @@ class AiConfig {
fun chatClient(
chatModel: ChatModel,
chatMemory: ChatMemory,
weatherTool: WeatherTool,
): ChatClient {
return ChatClient.builder(chatModel)
.defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build())
.defaultTools(weatherTool)
.build()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import com.back.koreaTravelGuide.domain.ai.aiChat.entity.AiChatSession
import com.back.koreaTravelGuide.domain.ai.aiChat.entity.SenderType
import com.back.koreaTravelGuide.domain.ai.aiChat.repository.AiChatMessageRepository
import com.back.koreaTravelGuide.domain.ai.aiChat.repository.AiChatSessionRepository
import org.slf4j.LoggerFactory
import org.springframework.ai.chat.client.ChatClient
import org.springframework.ai.chat.memory.ChatMemory
import org.springframework.stereotype.Service
Expand All @@ -18,8 +17,6 @@ class AiChatService(
private val aiChatSessionRepository: AiChatSessionRepository,
private val chatClient: ChatClient,
) {
private val logger = LoggerFactory.getLogger(AiChatService::class.java)

fun getSessions(userId: Long): List<AiChatSession> {
return aiChatSessionRepository.findByUserIdOrderByCreatedAtDesc(userId)
}
Expand Down Expand Up @@ -77,7 +74,6 @@ class AiChatService(
.call()
.content() ?: AI_ERROR_FALLBACK
} catch (e: Exception) {
logger.error("AI 응답 생성 실패: {}", e.message, e)
AI_ERROR_FALLBACK
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,63 +1,33 @@
package com.back.koreaTravelGuide.domain.ai.aiChat.tool

// TODO: AI 날씨 도구 - Spring AI @Tool 어노테이션으로 AI가 호출할 수 있는 날씨 기능
import com.back.koreaTravelGuide.domain.ai.weather.dto.MidForecastDto
import com.back.koreaTravelGuide.domain.ai.weather.dto.TemperatureAndLandForecastDto
import com.back.koreaTravelGuide.domain.ai.weather.service.WeatherServiceCore
import com.back.koreaTravelGuide.domain.ai.weather.service.WeatherService
import com.back.koreaTravelGuide.domain.ai.weather.service.tools.Tools
import org.springframework.ai.tool.annotation.Tool
import org.springframework.ai.tool.annotation.ToolParam
import org.springframework.stereotype.Service
import java.time.ZoneId
import java.time.ZonedDateTime
import java.time.format.DateTimeFormatter
import org.springframework.stereotype.Component

@Service
@Component
class WeatherTool(
private val weatherServiceCore: WeatherServiceCore,
private val weatherService: WeatherService,
private val tools: Tools,
) {
@Tool(description = "현재 한국 시간을 조회합니다.")
fun getCurrentTime(): String {
val now = ZonedDateTime.now(ZoneId.of("Asia/Seoul"))
return "현재 한국 표준시(KST): ${now.format(DateTimeFormatter.ofPattern("yyyy년 MM월 dd일 HH시 mm분"))}"
}
@Tool(description = "전국 중기예보를 조회합니다")
fun getWeatherForecast(): String {
val actualBaseTime = tools.validBaseTime(null)
val forecasts = weatherService.fetchMidForecast(actualBaseTime)

@Tool(description = "전국 중기전망 텍스트를 조회해 여행하기 좋은 지역 후보를 파악합니다. 먼저 호출하여 비교할 지역 코드를 추려 주세요.")
fun queryMidTermNarrative(
@ToolParam(description = "발표 시각 (YYYYMMDDHHMM). 미지정 시 최근 발표시각 사용.", required = false) baseTime: String?,
): List<MidForecastDto>? {
return weatherServiceCore.getWeatherForecast(
baseTime = baseTime,
)
return forecasts?.toString() ?: "중기예보 데이터를 가져올 수 없습니다."
}

@Tool(description = "중기 기온과 강수 확률 지표를 지역별로 조회합니다. 첫 번째 툴에서 제안한 지역 코드로 비교 분석에 사용하세요.")
fun queryMidTermAndLandForecastMetrics(
@ToolParam(description = "지역 이름 (예: 서울, 인천)", required = true) location: String?,
@ToolParam(description = "중기예보 regId (예: [\"11B10101\", \"11H20301\"]).", required = true) regionCode: String?,
// @ToolParam(description = "중기예보 regId 배열 (예: [\"11B10101\", \"11H20301\"]).", required = true) regionCodes: List<String>,
@ToolParam(description = "발표 시각 (YYYYMMDDHHMM). 미지정 시 최근 발표시각 사용.", required = false) baseTime: String?,
// @ToolParam(description = "확인할 일 수 offset 목록 (4~10). 비워 두면 4~10일 모두 반환.", required = false) days: List<Int>?,
): List<TemperatureAndLandForecastDto>? {
return weatherServiceCore.getTemperatureAndLandForecast(
location = location,
regionCode = regionCode,
baseTime = baseTime,
)
}
@Tool(description = "특정 지역의 상세 기온 및 날씨 예보를 조회합니다")
fun getRegionalWeatherDetails(
@ToolParam(description = "지역명 (예: 서울, 부산, 대전, 제주 등)", required = true)
location: String,
): String {
val regionCode = tools.getRegionCodeFromLocation(location)
val actualBaseTime = tools.validBaseTime(null)
val forecasts = weatherService.fetchTemperatureAndLandForecast(regionCode, actualBaseTime)

// @Deprecated(
// message = "AI 툴 분리 이후에는 queryMidTermNarrative/queryMidTermMetrics를 사용하세요.",
// level = DeprecationLevel.WARNING,
// )
// fun getWeatherForecast(
// location: String?,
// regionCode: String?,
// baseTime: String?,
// ): WeatherResponse {
// return weatherService.getWeatherForecast(
// location = location,
// regionCode = regionCode,
// baseTime = baseTime,
// )
// }
return forecasts?.toString() ?: "$location 지역의 상세 날씨 정보를 가져올 수 없습니다."
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
package com.back.koreaTravelGuide.domain.ai.aiChat.controller

import com.back.koreaTravelGuide.domain.ai.aiChat.dto.AiChatRequest
import com.fasterxml.jackson.databind.ObjectMapper
import io.github.cdimascio.dotenv.dotenv
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestInstance
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc
import org.springframework.boot.test.context.SpringBootTest
import org.springframework.http.MediaType
import org.springframework.security.test.context.support.WithMockUser
import org.springframework.test.context.ActiveProfiles
import org.springframework.test.web.servlet.MockMvc
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post
import org.springframework.transaction.annotation.Transactional

/**
* AI 채팅 컨트롤러 간단 테스트
* 응답 구조 확인 및 기본 동작 테스트
*/
@SpringBootTest
@AutoConfigureMockMvc
@ActiveProfiles("test")
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@Transactional
class AiChatControllerTest {
companion object {
@JvmStatic
@BeforeAll
fun loadEnv() {
val dotenv = dotenv { ignoreIfMissing = true }
dotenv.entries().forEach { entry ->
System.setProperty(entry.key, entry.value)
}
}
}

@Autowired
private lateinit var mockMvc: MockMvc

private val objectMapper = ObjectMapper()
private val userId = 1L

@Test
@WithMockUser
fun `AI 채팅 기본 동작 테스트`() {
println("========================================")
println("🤖 AI 채팅 기본 동작 테스트")
println("========================================")

// 1. 세션 생성
println("1️⃣ 새 채팅방 생성")
val createResult =
mockMvc.perform(
post("/api/aichat/sessions")
.param("userId", userId.toString()),
).andReturn()

println("📦 세션 생성 응답 상태: ${createResult.response.status}")
println("📦 세션 생성 응답 내용: ${createResult.response.contentAsString}")

if (createResult.response.status != 200) {
println("❌ 세션 생성 실패 - 테스트 중단")
return
}

// JSON 파싱해서 sessionId 추출
val createResponse = objectMapper.readTree(createResult.response.contentAsString)
val sessionId = createResponse.get("data").get("sessionId").asLong()
println("✅ 세션 생성 완료: sessionId=$sessionId")

// 2. 간단한 메시지 전송
println("2️⃣ AI에게 간단한 질문")
val chatRequest = AiChatRequest("안녕하세요!")

val messageResult =
mockMvc.perform(
post("/api/aichat/sessions/$sessionId/messages")
.param("userId", userId.toString())
.contentType(MediaType.APPLICATION_JSON)
.content(objectMapper.writeValueAsString(chatRequest)),
).andReturn()

println("📦 메시지 응답 상태: ${messageResult.response.status}")
println("📦 메시지 응답 내용: ${messageResult.response.contentAsString}")

if (messageResult.response.status == 200) {
println("✅ 메시지 전송 성공")
} else {
println("❌ 메시지 전송 실패")
}

// 3. 세션 목록 조회
println("3️⃣ 채팅방 목록 조회")
val sessionsResult =
mockMvc.perform(
get("/api/aichat/sessions")
.param("userId", userId.toString()),
).andReturn()

println("📦 세션 목록 응답 상태: ${sessionsResult.response.status}")
if (sessionsResult.response.status == 200) {
println("✅ 세션 목록 조회 성공")
} else {
println("❌ 세션 목록 조회 실패")
}

println("🎉 기본 동작 테스트 완료!")
}

@Test
@WithMockUser
fun `날씨 질문 테스트`() {
println("========================================")
println("🌤️ 날씨 질문 테스트")
println("========================================")

// 세션 생성
val createResult =
mockMvc.perform(
post("/api/aichat/sessions")
.param("userId", userId.toString()),
).andReturn()

if (createResult.response.status != 200) {
println("❌ 세션 생성 실패")
return
}

val sessionId =
objectMapper.readTree(createResult.response.contentAsString)
.get("data").get("sessionId").asLong()

// 날씨 질문
val weatherQuestions =
listOf(
"서울 날씨 어떤가요?",
"비 올까요?",
)

weatherQuestions.forEachIndexed { index, question ->
println("💬 질문 ${index + 1}: $question")

val chatRequest = AiChatRequest(question)
val result =
mockMvc.perform(
post("/api/aichat/sessions/$sessionId/messages")
.param("userId", userId.toString())
.contentType(MediaType.APPLICATION_JSON)
.content(objectMapper.writeValueAsString(chatRequest)),
).andReturn()

if (result.response.status == 200) {
println("✅ 응답 성공 (${result.response.contentLength} bytes)")
} else {
println("❌ 응답 실패: ${result.response.status}")
}
}

println("🎉 날씨 질문 테스트 완료!")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,9 @@ class TourApiClientTest {
@Value("\${tour.api.key}")
private lateinit var serviceKey: String


@DisplayName("fetchTourInfo - 실제 관광청 API 호출 (데이터 기대)")
@Test
fun fetchTourInfoTest() {

val params =
TourSearchParams(
numOfRows = 1,
Expand All @@ -48,7 +46,6 @@ class TourApiClientTest {
@DisplayName("fetchTourInfo - 실제 관광청 API 장애 시 빈 결과 확인")
@Test
fun fetchTourInfoEmptyTest() {

val params =
TourSearchParams(
numOfRows = 1,
Expand Down
34 changes: 34 additions & 0 deletions src/test/resources/application-test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
spring:
security:
oauth2:
client:
registration:
google:
client-id: test
client-secret: test
scope: profile, email
naver:
client-id: test
client-secret: test
authorization-grant-type: authorization_code
redirect-uri: "{baseUrl}/login/oauth2/code/{registrationId}"
scope: name, email
client-name: Naver
kakao:
client-id: test
client-secret: test
redirect-uri: "{baseUrl}/login/oauth2/code/{registrationId}"
authorization-grant-type: authorization_code
scope: profile_nickname, account_email
client-name: Kakao
provider:
naver:
authorization-uri: https://nid.naver.com/oauth2.0/authorize
token-uri: https://nid.naver.com/oauth2.0/token
user-info-uri: https://openapi.naver.com/v1/nid/me
user-name-attribute: response
kakao:
authorization-uri: https://kauth.kakao.com/oauth/authorize
token-uri: https://kauth.kakao.com/oauth/token
user-info-uri: https://kapi.kakao.com/v2/user/me
user-name-attribute: id