diff --git a/.gitignore b/.gitignore index bb2747e..7b7be94 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,12 @@ db_dev.trace.db **/generated/ **/build/generated/ src/main/generated/ -back/src/main/generated/ \ No newline at end of file +back/src/main/generated/ + +# AI 생성 이미지 저장 경로 +uploads/ +./uploads/ + +# 테스트 이미지 경로 (LocalStorageServiceTest) +test-uploads/ +./test-uploads/ \ No newline at end of file diff --git a/back/.gitignore b/back/.gitignore index 8ad286c..21cee85 100644 --- a/back/.gitignore +++ b/back/.gitignore @@ -38,3 +38,11 @@ out/ ### Environment Variables ### .env + +# AI 생성 이미지 저장 경로 +uploads/ +./uploads/ + +# 테스트 이미지 경로 (LocalStorageServiceTest) +test-uploads/ +./test-uploads/ diff --git a/back/build.gradle.kts b/back/build.gradle.kts index 258e9c4..dbec77d 100644 --- a/back/build.gradle.kts +++ b/back/build.gradle.kts @@ -79,6 +79,9 @@ dependencies { implementation("org.springframework.boot:spring-boot-starter-webflux") implementation("com.fasterxml.jackson.module:jackson-module-kotlin") + // AWS SDK for S3 + implementation("software.amazon.awssdk:s3:2.20.+") + // macOS Netty 네이티브 DNS 리졸버 (WebFlux 필요) val isMacOS: Boolean = System.getProperty("os.name").startsWith("Mac OS X") val architecture = System.getProperty("os.arch").lowercase() diff --git a/back/src/main/java/com/back/domain/scenario/controller/ScenarioController.java b/back/src/main/java/com/back/domain/scenario/controller/ScenarioController.java index 13646bf..3f83477 100644 --- a/back/src/main/java/com/back/domain/scenario/controller/ScenarioController.java +++ b/back/src/main/java/com/back/domain/scenario/controller/ScenarioController.java @@ -4,6 +4,8 @@ import com.back.domain.scenario.dto.*; import com.back.domain.scenario.service.ScenarioService; import com.back.global.common.PageResponse; +import com.back.global.exception.ApiException; +import com.back.global.exception.ErrorCode; import com.back.global.security.CustomUserDetails; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; @@ -30,12 +32,11 @@ public class ScenarioController { /** * 인증된 사용자의 ID를 안전하게 추출합니다. - * 테스트 환경에서 userDetails가 null일 수 있으므로 기본값을 제공합니다. + * 인증되지 않은 사용자는 예외를 발생시킵니다. */ private Long getUserId(CustomUserDetails userDetails) { - if (userDetails == null || userDetails.getUser() == null) { - // 테스트 환경이나 인증이 비활성화된 환경에서는 기본 사용자 ID 사용 - return 1L; + if (userDetails == null || userDetails.getUser() == null || userDetails.getUser().getId() == null) { + throw new ApiException(ErrorCode.HANDLE_ACCESS_DENIED, "인증이 필요한 서비스입니다."); } return userDetails.getUser().getId(); } @@ -49,6 +50,14 @@ public ResponseEntity createScenario( ) { Long userId = getUserId(userDetails); + // lastDecision 기본 검증: userId 일치 확인 + if (lastDecision != null && lastDecision.userId() != null) { + if (!lastDecision.userId().equals(userId)) { + throw new ApiException(ErrorCode.HANDLE_ACCESS_DENIED, + "lastDecision의 userId가 인증된 사용자와 일치하지 않습니다."); + } + } + ScenarioStatusResponse scenarioCreateResponse = scenarioService.createScenario(userId,request, lastDecision); return ResponseEntity.status(HttpStatus.CREATED).body(scenarioCreateResponse); diff --git a/back/src/main/java/com/back/domain/scenario/dto/AiScenarioGenerationResult.java b/back/src/main/java/com/back/domain/scenario/dto/AiScenarioGenerationResult.java new file mode 100644 index 0000000..2937620 --- /dev/null +++ b/back/src/main/java/com/back/domain/scenario/dto/AiScenarioGenerationResult.java @@ -0,0 +1,30 @@ +package com.back.domain.scenario.dto; + +import com.back.global.ai.dto.result.BaseScenarioResult; +import com.back.global.ai.dto.result.DecisionScenarioResult; +import lombok.Getter; + +/** + * AI 시나리오 생성 결과를 담는 래퍼 클래스. + * 트랜잭션 분리를 위해 베이스 시나리오와 결정 시나리오 결과를 구분하여 전달합니다. + */ +@Getter +public class AiScenarioGenerationResult { + private final boolean isBaseScenario; + private final BaseScenarioResult baseResult; + private final DecisionScenarioResult decisionResult; + + // 베이스 시나리오용 생성자 + public AiScenarioGenerationResult(BaseScenarioResult baseResult) { + this.isBaseScenario = true; + this.baseResult = baseResult; + this.decisionResult = null; + } + + // 결정 시나리오용 생성자 + public AiScenarioGenerationResult(DecisionScenarioResult decisionResult) { + this.isBaseScenario = false; + this.baseResult = null; + this.decisionResult = decisionResult; + } +} diff --git a/back/src/main/java/com/back/domain/scenario/entity/Scenario.java b/back/src/main/java/com/back/domain/scenario/entity/Scenario.java index 7af2a48..47fbf3b 100644 --- a/back/src/main/java/com/back/domain/scenario/entity/Scenario.java +++ b/back/src/main/java/com/back/domain/scenario/entity/Scenario.java @@ -15,7 +15,15 @@ * AI를 통해 생성된 시나리오의 상세 정보와 처리 상태를 저장합니다. */ @Entity -@Table(name = "scenarios") +@Table(name = "scenarios", + indexes = { + @Index(name = "idx_scenario_user_status", columnList = "user_id, status, created_date"), + @Index(name = "idx_scenario_baseline", columnList = "base_line_id") + }, + uniqueConstraints = { + @UniqueConstraint(name = "uk_scenario_decision_line", columnNames = "decision_line_id") + } +) @Getter @Setter @NoArgsConstructor @@ -30,7 +38,7 @@ public class Scenario extends BaseEntity { // 시나리오 생성의 기반이 된 선택 경로 @OneToOne(fetch = FetchType.LAZY) - @JoinColumn(name = "decision_line_id", unique = true) + @JoinColumn(name = "decision_line_id") private DecisionLine decisionLine; // 시나리오가 속한 베이스라인 (하나의 BaseLine에 여러 Scenario 가능) diff --git a/back/src/main/java/com/back/domain/scenario/repository/SceneTypeRepository.java b/back/src/main/java/com/back/domain/scenario/repository/SceneTypeRepository.java index 5d70c70..5c18aef 100644 --- a/back/src/main/java/com/back/domain/scenario/repository/SceneTypeRepository.java +++ b/back/src/main/java/com/back/domain/scenario/repository/SceneTypeRepository.java @@ -20,5 +20,9 @@ public interface SceneTypeRepository extends JpaRepository { @Query("SELECT st FROM SceneType st WHERE st.scenario.id IN :scenarioIds") List findByScenarioIdIn(@Param("scenarioIds") List scenarioIds); + // 여러 시나리오의 지표들을 배치 조회 (시나리오 ID, 타입 순서대로 정렬) + @Query("SELECT st FROM SceneType st WHERE st.scenario.id IN :scenarioIds ORDER BY st.scenario.id ASC, st.type ASC") + List findByScenarioIdInOrderByScenarioIdAscTypeAsc(@Param("scenarioIds") List scenarioIds); + List findByScenarioId(Long scenarioId); } \ No newline at end of file diff --git a/back/src/main/java/com/back/domain/scenario/service/ScenarioService.java b/back/src/main/java/com/back/domain/scenario/service/ScenarioService.java index e46ed03..655ff61 100644 --- a/back/src/main/java/com/back/domain/scenario/service/ScenarioService.java +++ b/back/src/main/java/com/back/domain/scenario/service/ScenarioService.java @@ -10,6 +10,7 @@ import com.back.domain.scenario.entity.Scenario; import com.back.domain.scenario.entity.ScenarioStatus; import com.back.domain.scenario.entity.SceneCompare; +import com.back.domain.scenario.entity.SceneType; import com.back.domain.scenario.repository.ScenarioRepository; import com.back.domain.scenario.repository.SceneCompareRepository; import com.back.domain.scenario.repository.SceneTypeRepository; @@ -19,10 +20,10 @@ import com.back.global.common.PageResponse; import com.back.global.exception.ApiException; import com.back.global.exception.ErrorCode; +import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import jakarta.annotation.Nullable; -import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.dao.DataIntegrityViolationException; @@ -68,94 +69,145 @@ public class ScenarioService { // 노드 서비스 추가(시나리오 생성과 동시에 마지막 노드 처리용) private final DecisionFlowService decisionFlowService; - // 시나리오 생성 - @Transactional + /** + * 시나리오 생성 요청 처리. + * 트랜잭션을 최소화하기 위해 검증 → 생성 → 비동기 트리거 순서로 분리. + */ public ScenarioStatusResponse createScenario(Long userId, ScenarioCreateRequest request, @Nullable DecisionNodeNextRequest lastDecision) { + // 1. 검증 및 기존 시나리오 확인 (읽기 전용) + ScenarioValidationResult validationResult = validateScenarioCreation(userId, request, lastDecision); + + if (validationResult.existingScenario != null) { + // 기존 시나리오가 있으면 상태에 따라 처리 + return handleExistingScenario(validationResult.existingScenario); + } + + // 2. 시나리오 생성 (짧은 트랜잭션) + Long scenarioId = createScenarioInTransaction( + userId, + request, + lastDecision, + validationResult.decisionLine + ); + + // 3. 비동기 AI 처리 트리거 (트랜잭션 외부) + processScenarioGenerationAsync(scenarioId); + + return new ScenarioStatusResponse( + scenarioId, + ScenarioStatus.PENDING, + "시나리오 생성이 시작되었습니다." + ); + } + + // 시나리오 생성 요청 검증 + private ScenarioValidationResult validateScenarioCreation( + Long userId, + ScenarioCreateRequest request, + @Nullable DecisionNodeNextRequest lastDecision) { + // DecisionLine 존재 여부 확인 DecisionLine decisionLine = decisionLineRepository.findById(request.decisionLineId()) .orElseThrow(() -> new ApiException(ErrorCode.DECISION_LINE_NOT_FOUND)); - // 권한 검증 (DecisionLine 소유자와 요청자 일치 여부) + // 권한 검증 if (!decisionLine.getUser().getId().equals(userId)) { throw new ApiException(ErrorCode.HANDLE_ACCESS_DENIED); } - // 원자적 조회 + 상태 확인 + ensureOwnerEditable(userId, decisionLine); + + // lastDecision 검증 + if (lastDecision != null) { + ensureSameLine(decisionLine, lastDecision); + } + + // 기존 시나리오 확인 (Unique Constraint로 동시성 제어) Optional existingScenario = scenarioRepository .findByDecisionLineId(request.decisionLineId()); - if (existingScenario.isPresent()) { - Scenario existing = existingScenario.get(); - - // PENDING/PROCESSING 상태면 중복 생성 방지 - if (existing.getStatus() == ScenarioStatus.PENDING || - existing.getStatus() == ScenarioStatus.PROCESSING) { - throw new ApiException(ErrorCode.SCENARIO_ALREADY_IN_PROGRESS, - "해당 선택 경로의 시나리오가 이미 생성 중입니다."); - } + return new ScenarioValidationResult(decisionLine, existingScenario.orElse(null)); + } - // FAILED 상태면 재시도 로직 - if (existing.getStatus() == ScenarioStatus.FAILED) { - return handleFailedScenarioRetry(existing); - } + /** + * 기존 시나리오 처리 로직 + */ + private ScenarioStatusResponse handleExistingScenario(Scenario existing) { + ScenarioStatus status = existing.getStatus(); - // COMPLETED 상태면 기존 시나리오 반환 - if (existing.getStatus() == ScenarioStatus.COMPLETED) { - return new ScenarioStatusResponse( - existing.getId(), - existing.getStatus(), - "이미 완료된 시나리오가 존재합니다." - ); - } + // PENDING/PROCESSING 상태면 중복 생성 방지 + if (status == ScenarioStatus.PENDING || status == ScenarioStatus.PROCESSING) { + throw new ApiException(ErrorCode.SCENARIO_ALREADY_IN_PROGRESS, + "해당 선택 경로의 시나리오가 이미 생성 중입니다."); } - ensureOwnerEditable(userId, decisionLine); - - if (lastDecision != null) { - ensureSameLine(decisionLine, lastDecision); - decisionFlowService.createDecisionNodeNext(lastDecision); + // FAILED 상태면 재시도 로직 + if (status == ScenarioStatus.FAILED) { + return handleFailedScenarioRetry(existing); } - // 라인 완료 처리(외부 완료 API 제거 시 내부에서만 호출) - try { decisionLine.complete(); } catch (RuntimeException e) { - throw new ApiException(ErrorCode.INVALID_INPUT_VALUE, e.getMessage()); - } + // COMPLETED 상태면 기존 시나리오 반환 + return new ScenarioStatusResponse( + existing.getId(), + existing.getStatus(), + "이미 완료된 시나리오가 존재합니다." + ); + } + + /** + * 시나리오 생성 트랜잭션 (최소한의 DB 작업만 수행) + */ + @Transactional + protected Long createScenarioInTransaction( + Long userId, + ScenarioCreateRequest request, + @Nullable DecisionNodeNextRequest lastDecision, + DecisionLine decisionLine) { - // 새 시나리오 생성 (DataIntegrityViolationException 처리) try { - // DecisionLine에서 BaseLine 가져오기 - BaseLine baseLine = decisionLine.getBaseLine(); + // lastDecision 처리 (필요 시) + if (lastDecision != null) { + decisionFlowService.createDecisionNodeNext(lastDecision); + } + // DecisionLine 완료 처리 + try { + decisionLine.complete(); + } catch (RuntimeException e) { + throw new ApiException(ErrorCode.INVALID_INPUT_VALUE, e.getMessage()); + } + + // 시나리오 생성 + BaseLine baseLine = decisionLine.getBaseLine(); Scenario scenario = Scenario.builder() .user(decisionLine.getUser()) .decisionLine(decisionLine) - .baseLine(baseLine) // DecisionLine의 BaseLine 연결 + .baseLine(baseLine) .status(ScenarioStatus.PENDING) .build(); Scenario savedScenario = scenarioRepository.save(scenario); - processScenarioGenerationAsync(savedScenario.getId()); - return new ScenarioStatusResponse( - savedScenario.getId(), - savedScenario.getStatus(), - "시나리오 생성이 시작되었습니다." - ); + return savedScenario.getId(); } catch (DataIntegrityViolationException e) { - // 동시성으로 인한 중복 생성 시 기존 시나리오 조회 후 반환 + // 동시성으로 인한 중복 생성 시 기존 시나리오 ID 반환 return scenarioRepository.findByDecisionLineId(request.decisionLineId()) - .map(existing -> new ScenarioStatusResponse( - existing.getId(), - existing.getStatus(), - "기존 시나리오를 반환합니다." - )) + .map(Scenario::getId) .orElseThrow(() -> new ApiException(ErrorCode.SCENARIO_CREATION_FAILED)); } } + /** + * 검증 결과를 담는 내부 클래스 + */ + private record ScenarioValidationResult( + DecisionLine decisionLine, + Scenario existingScenario + ) {} + // 가장 많이 사용하는 함수 호출 위 한줄 요약: 시나리오 요청에서 lineId 필수 추출 private Long requireLineId(ScenarioCreateRequest scenario) { if (scenario == null || scenario.decisionLineId() == null) { @@ -212,24 +264,43 @@ private void ensureSameLine(DecisionLine line, DecisionNodeNextRequest lastDecis } - // FAILED 시나리오 재시도 로직 분리 + /** + * FAILED 시나리오 재시도 로직. + * 트랜잭션과 비동기 처리를 분리하여 커넥션 풀 효율성 향상. + */ private ScenarioStatusResponse handleFailedScenarioRetry(Scenario failedScenario) { - failedScenario.setStatus(ScenarioStatus.PENDING); - failedScenario.setErrorMessage(null); - failedScenario.setUpdatedDate(LocalDateTime.now()); + // 1. 상태 업데이트 (트랜잭션) + Long scenarioId = retryScenarioInTransaction(failedScenario.getId()); - Scenario savedScenario = scenarioRepository.save(failedScenario); - processScenarioGenerationAsync(savedScenario.getId()); + // 2. 비동기 AI 처리 트리거 (트랜잭션 외부) + processScenarioGenerationAsync(scenarioId); return new ScenarioStatusResponse( - savedScenario.getId(), - savedScenario.getStatus(), + scenarioId, + ScenarioStatus.PENDING, "시나리오 재생성이 시작되었습니다." ); } + /** + * FAILED 시나리오를 PENDING으로 되돌리는 트랜잭션 + */ + @Transactional + protected Long retryScenarioInTransaction(Long scenarioId) { + Scenario scenario = scenarioRepository.findById(scenarioId) + .orElseThrow(() -> new ApiException(ErrorCode.SCENARIO_NOT_FOUND)); + + scenario.setStatus(ScenarioStatus.PENDING); + scenario.setErrorMessage(null); + scenario.setUpdatedDate(LocalDateTime.now()); + + scenarioRepository.save(scenario); + + return scenario.getId(); + } + // 비동기 방식으로 AI 시나리오 생성 - @Async + @Async("aiTaskExecutor") public void processScenarioGenerationAsync(Long scenarioId) { try { // 1. 상태를 PROCESSING으로 업데이트 (별도 트랜잭션) @@ -383,8 +454,9 @@ private Map parseTimelineTitles(String timelineTitles) { // JSON 문자열을 Map으로 파싱 return objectMapper.readValue(timelineTitles, new TypeReference>() {}); - } catch (Exception e) { + } catch (JsonProcessingException e) { // JSON 파싱 실패 시 예외 처리 + log.error("Failed to parse timeline JSON: {}", e.getMessage()); throw new ApiException(ErrorCode.SCENARIO_TIMELINE_NOT_FOUND); } } @@ -392,17 +464,43 @@ private Map parseTimelineTitles(String timelineTitles) { // 시나리오 비교 분석 @Transactional(readOnly = true) public ScenarioCompareResponse compareScenarios(Long baseId, Long compareId, Long userId) { - // 권한 검증 및 시나리오 조회 - Scenario baseScenario = scenarioRepository.findByIdAndUserId(baseId, userId) + // 1. 두 시나리오를 배치 조회 (권한 검증 포함) + List scenarios = scenarioRepository.findAllById(List.of(baseId, compareId)); + + // 존재 여부 및 권한 검증 + if (scenarios.size() != 2) { + throw new ApiException(ErrorCode.SCENARIO_NOT_FOUND); + } + + Scenario baseScenario = scenarios.stream() + .filter(s -> s.getId().equals(baseId)) + .findFirst() .orElseThrow(() -> new ApiException(ErrorCode.SCENARIO_NOT_FOUND)); - Scenario compareScenario = scenarioRepository.findByIdAndUserId(compareId, userId) + + Scenario compareScenario = scenarios.stream() + .filter(s -> s.getId().equals(compareId)) + .findFirst() .orElseThrow(() -> new ApiException(ErrorCode.SCENARIO_NOT_FOUND)); - // 지표 조회 - var baseTypes = sceneTypeRepository.findByScenarioIdOrderByTypeAsc(baseId); - var compareTypes = sceneTypeRepository.findByScenarioIdOrderByTypeAsc(compareId); + // 권한 검증 + if (!baseScenario.getUser().getId().equals(userId) || + !compareScenario.getUser().getId().equals(userId)) { + throw new ApiException(ErrorCode.HANDLE_ACCESS_DENIED); + } + + // 2. 두 시나리오의 지표를 배치 조회 + List allSceneTypes = sceneTypeRepository.findByScenarioIdInOrderByScenarioIdAscTypeAsc( + List.of(baseId, compareId)); - // 비교 분석 결과 조회 + var baseTypes = allSceneTypes.stream() + .filter(st -> st.getScenario().getId().equals(baseId)) + .toList(); + + var compareTypes = allSceneTypes.stream() + .filter(st -> st.getScenario().getId().equals(compareId)) + .toList(); + + // 3. 비교 분석 결과 조회 List compareResults = sceneCompareRepository.findByScenarioIdOrderByResultType(compareId); if (compareResults.isEmpty()) { throw new ApiException(ErrorCode.SCENE_COMPARE_NOT_FOUND); @@ -461,26 +559,4 @@ private String convertCategoryToKorean(NodeCategory category) { }; } - // AI 생성 결과를 담는 래퍼 클래스 (트랜잭션 분리용) - @Getter - static class AiScenarioGenerationResult { - private final boolean isBaseScenario; - private final BaseScenarioResult baseResult; - private final DecisionScenarioResult decisionResult; - - // 베이스 시나리오용 생성자 - public AiScenarioGenerationResult(BaseScenarioResult baseResult) { - this.isBaseScenario = true; - this.baseResult = baseResult; - this.decisionResult = null; - } - - // 결정 시나리오용 생성자 - public AiScenarioGenerationResult(DecisionScenarioResult decisionResult) { - this.isBaseScenario = false; - this.baseResult = null; - this.decisionResult = decisionResult; - } - - } } \ No newline at end of file diff --git a/back/src/main/java/com/back/domain/scenario/service/ScenarioTransactionService.java b/back/src/main/java/com/back/domain/scenario/service/ScenarioTransactionService.java index 22512a8..337f841 100644 --- a/back/src/main/java/com/back/domain/scenario/service/ScenarioTransactionService.java +++ b/back/src/main/java/com/back/domain/scenario/service/ScenarioTransactionService.java @@ -2,6 +2,7 @@ import com.back.domain.node.entity.BaseLine; import com.back.domain.node.repository.BaseLineRepository; +import com.back.domain.scenario.dto.AiScenarioGenerationResult; import com.back.domain.scenario.entity.*; import com.back.domain.scenario.repository.ScenarioRepository; import com.back.domain.scenario.repository.SceneCompareRepository; @@ -11,6 +12,7 @@ import com.back.global.ai.service.AiService; import com.back.global.exception.ApiException; import com.back.global.exception.ErrorCode; +import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; @@ -61,7 +63,7 @@ public void updateScenarioStatus(Long scenarioId, ScenarioStatus status, String // AI 결과 저장 전용 트랜잭션 메서드 @Transactional(propagation = Propagation.REQUIRES_NEW) - public void saveAiResult(Long scenarioId, ScenarioService.AiScenarioGenerationResult result) { + public void saveAiResult(Long scenarioId, AiScenarioGenerationResult result) { Scenario scenario = scenarioRepository.findById(scenarioId) .orElseThrow(() -> new ApiException(ErrorCode.SCENARIO_NOT_FOUND)); @@ -125,7 +127,7 @@ private void handleTimelineTitles(Scenario scenario, Map timelin try { String timelineTitlesJson = objectMapper.writeValueAsString(timelineTitles); scenario.setTimelineTitles(timelineTitlesJson); - } catch (Exception e) { + } catch (JsonProcessingException e) { log.error("Failed to serialize timeline titles for scenario {}: {}", scenario.getId(), e.getMessage()); scenario.setTimelineTitles("{}"); @@ -135,7 +137,14 @@ private void handleTimelineTitles(Scenario scenario, Map timelin private void handleImageGeneration(Scenario scenario, String imagePrompt) { try { if (imagePrompt != null && !imagePrompt.trim().isEmpty()) { - String imageUrl = aiService.generateImage(imagePrompt).join(); + String imageUrl = aiService.generateImage(imagePrompt) + .orTimeout(60, java.util.concurrent.TimeUnit.SECONDS) + .exceptionally(ex -> { + log.warn("Image generation timeout or error for scenario {}: {}", + scenario.getId(), ex.getMessage()); + return null; + }) + .join(); if ("placeholder-image-url".equals(imageUrl) || imageUrl == null || imageUrl.trim().isEmpty()) { scenario.setImg(null); diff --git a/back/src/main/java/com/back/global/ai/client/image/StableDiffusionImageClient.java b/back/src/main/java/com/back/global/ai/client/image/StableDiffusionImageClient.java new file mode 100644 index 0000000..e17db4a --- /dev/null +++ b/back/src/main/java/com/back/global/ai/client/image/StableDiffusionImageClient.java @@ -0,0 +1,153 @@ +package com.back.global.ai.client.image; + +import com.back.global.ai.config.ImageAiConfig; +import com.back.global.ai.exception.AiServiceException; +import com.back.global.exception.ErrorCode; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.http.HttpHeaders; +import org.springframework.http.client.MultipartBodyBuilder; +import org.springframework.stereotype.Component; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Mono; + +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +/** + * Stable Diffusion 3.5 Large Turbo 이미지 생성 클라이언트 + * Stability AI API를 사용하여 고품질 이미지를 생성합니다. + */ +@Slf4j +@Component +@RequiredArgsConstructor +@ConditionalOnProperty(prefix = "ai.image", name = "enabled", havingValue = "true") +public class StableDiffusionImageClient implements ImageAiClient { + + private final ImageAiConfig imageAiConfig; + private final WebClient webClient; + private final ObjectMapper objectMapper = new ObjectMapper(); + + @Override + public CompletableFuture generateImage(String prompt) { + return generateImage(prompt, Map.of()); + } + + @Override + public CompletableFuture generateImage(String prompt, Map options) { + log.info("Generating image with Stable Diffusion 3.5 Large Turbo. Prompt: {}", prompt); + + // Multipart 요청 바디 구성 + MultipartBodyBuilder bodyBuilder = buildMultipartRequestBody(prompt, options); + + return webClient.post() + .uri(imageAiConfig.getBaseUrl() + "/v2beta/stable-image/generate/sd3") + .header(HttpHeaders.AUTHORIZATION, "Bearer " + imageAiConfig.getApiKey()) + .header(HttpHeaders.ACCEPT, "application/json") + .body(BodyInserters.fromMultipartData(bodyBuilder.build())) + .retrieve() + .bodyToMono(String.class) + .timeout(Duration.ofSeconds(imageAiConfig.getTimeoutSeconds())) + .doOnError(error -> log.error("Stable Diffusion API error: {}", error.getMessage())) + .retryWhen(reactor.util.retry.Retry.fixedDelay( + imageAiConfig.getMaxRetries(), + Duration.ofSeconds(imageAiConfig.getRetryDelaySeconds()) + )) + .flatMap(this::extractImageData) + .toFuture(); + } + + @Override + public boolean isEnabled() { + return imageAiConfig.isEnabled(); + } + + /** + * Stable Diffusion API Multipart 요청 바디를 구성합니다. + * SD 3.5 Large Turbo는 multipart/form-data 형식을 사용합니다. + */ + private MultipartBodyBuilder buildMultipartRequestBody(String prompt, Map options) { + MultipartBodyBuilder builder = new MultipartBodyBuilder(); + + // 만화풍 스타일 추가 + String enhancedPrompt = prompt + ", cartoon style, animated illustration, vibrant colors, clean lines"; + builder.part("prompt", enhancedPrompt); + + // 모델 지정 (SD 3.5 Large Turbo) + builder.part("model", "sd3.5-large-turbo"); + + // 출력 형식 + builder.part("output_format", "jpeg"); + + // 옵션에서 값 추출 (기본값 사용) + builder.part("aspect_ratio", options.getOrDefault("aspect_ratio", "1:1")); + + if (options.containsKey("seed")) { + builder.part("seed", options.get("seed").toString()); + } + + // 네거티브 프롬프트 (품질 향상 + 실사 스타일 배제) + String negativePrompt = options.containsKey("negative_prompt") + ? options.get("negative_prompt").toString() + : "blurry, low quality, distorted, deformed, realistic, photo, photography"; + builder.part("negative_prompt", negativePrompt); + + return builder; + } + + /** + * API 응답에서 이미지 데이터를 추출합니다. + * SD 3.5 Large Turbo의 응답 구조: { "artifacts": [{ "base64": "...", "finishReason": "SUCCESS" }] } + * + * @param response API 응답 JSON + * @return 이미지 Base64 데이터 + */ + private Mono extractImageData(String response) { + try { + JsonNode rootNode = objectMapper.readTree(response); + + // Stability AI 공식 응답 구조: { "artifacts": [{ "base64": "..." }] } + if (rootNode.has("artifacts") && rootNode.get("artifacts").isArray()) { + JsonNode firstArtifact = rootNode.get("artifacts").get(0); + + // finishReason 검증 + if (firstArtifact.has("finishReason")) { + String finishReason = firstArtifact.get("finishReason").asText(); + if (!"SUCCESS".equals(finishReason)) { + log.error("Image generation failed with reason: {}", finishReason); + return Mono.error(new AiServiceException( + ErrorCode.AI_GENERATION_FAILED, + "Image generation failed: " + finishReason + )); + } + } + + // Base64 데이터 추출 + if (firstArtifact.has("base64")) { + String base64Data = firstArtifact.get("base64").asText(); + log.info("Image generated successfully. Base64 length: {}", base64Data.length()); + return Mono.just(base64Data); + } + } + + // 응답 구조가 예상과 다를 경우 + log.error("Unexpected Stable Diffusion API response structure: {}", response); + return Mono.error(new AiServiceException( + ErrorCode.AI_GENERATION_FAILED, + "Failed to extract image data from API response" + )); + + } catch (Exception e) { + log.error("Error parsing Stable Diffusion API response: {}", e.getMessage()); + return Mono.error(new AiServiceException( + ErrorCode.AI_GENERATION_FAILED, + "Failed to parse image generation response: " + e.getMessage() + )); + } + } +} diff --git a/back/src/main/java/com/back/global/ai/client/text/GeminiTextClient.java b/back/src/main/java/com/back/global/ai/client/text/GeminiTextClient.java index d9cd42f..9b88ac3 100644 --- a/back/src/main/java/com/back/global/ai/client/text/GeminiTextClient.java +++ b/back/src/main/java/com/back/global/ai/client/text/GeminiTextClient.java @@ -27,7 +27,6 @@ */ @Component @Slf4j -// TODO: AI 예외 처리 api구체화, DTO 구조 구체화, API키 추가 public class GeminiTextClient implements TextAiClient { private final WebClient webClient; @@ -57,7 +56,7 @@ public CompletableFuture generateText(AiRequest aiRequest) { .bodyToMono(GeminiResponse.class) .map(this::extractContent) .timeout(Duration.ofSeconds(textAiConfig.getTimeoutSeconds())) - .retryWhen(Retry.backoff(textAiConfig.getMaxRetries(), Duration.ofSeconds(2))) + .retryWhen(Retry.backoff(textAiConfig.getMaxRetries(), Duration.ofSeconds(textAiConfig.getRetryDelaySeconds()))) .doOnError(error -> log.error("Gemini API call failed: {}", error.getMessage())) .toFuture(); } diff --git a/back/src/main/java/com/back/global/ai/config/ImageAiConfig.java b/back/src/main/java/com/back/global/ai/config/ImageAiConfig.java index 7fe14ae..753908c 100644 --- a/back/src/main/java/com/back/global/ai/config/ImageAiConfig.java +++ b/back/src/main/java/com/back/global/ai/config/ImageAiConfig.java @@ -1,21 +1,57 @@ package com.back.global.ai.config; -import lombok.Data; +import lombok.Getter; +import lombok.Setter; import org.springframework.boot.context.properties.ConfigurationProperties; -import org.springframework.context.annotation.Configuration; +import org.springframework.stereotype.Component; /** - * 이미지 생성 AI 서비스 설정 클래스 - * 이미지 생성 AI의 기본 설정값들을 관리합니다. + * 이미지 생성 AI 설정 프로퍼티 + * application.yml의 ai.image 설정을 바인딩합니다. */ -@Configuration +@Getter +@Setter +@Component @ConfigurationProperties(prefix = "ai.image") -@Data public class ImageAiConfig { + + // 이미지 AI 기능 활성화 여부 private boolean enabled = false; - private String provider = "placeholder"; + + // 이미지 AI 제공자 + private String provider = "stable-diffusion"; + + private String apiKey; + + private String baseUrl = "https://api.stability.ai"; + private int timeoutSeconds = 60; + private int maxRetries = 3; - // TODO: 추후 설정 추가 예정 -} \ No newline at end of file + private int retryDelaySeconds = 2; // 재시도 간격 (초) + + // 이미지 저장 방식 (s3, local 등) + private String storageType = "local"; + + // AWS S3 버킷 이름 (storageType이 s3인 경우) + private String s3BucketName; + + // AWS S3 리전 (storageType이 s3인 경우) + private String s3Region; + + // 로컬 파일 저장 경로 (storageType="local"인 경우 사용) + // 기본값: "./uploads/images" + private String localStoragePath = "./uploads/images"; + + // 기본값: "http://localhost:8080/images" + private String localBaseUrl = "http://localhost:8080/images"; + + public boolean isS3Enabled() { + return "s3".equalsIgnoreCase(storageType); + } + + public boolean isLocalEnabled() { + return "local".equalsIgnoreCase(storageType); + } +} diff --git a/back/src/main/java/com/back/global/ai/config/TextAiConfig.java b/back/src/main/java/com/back/global/ai/config/TextAiConfig.java index d2bacef..cfdaae8 100644 --- a/back/src/main/java/com/back/global/ai/config/TextAiConfig.java +++ b/back/src/main/java/com/back/global/ai/config/TextAiConfig.java @@ -24,6 +24,7 @@ public class TextAiConfig { String model = "gemini-2.5-flash"; // 추후 변경 가능 int timeoutSeconds = 30; int maxRetries = 3; + int retryDelaySeconds = 2; // 재시도 간격 (초) /** * Gemini API 전용 WebClient Bean 생성 diff --git a/back/src/main/java/com/back/global/ai/exception/AiServiceException.java b/back/src/main/java/com/back/global/ai/exception/AiServiceException.java index 9385449..57d2442 100644 --- a/back/src/main/java/com/back/global/ai/exception/AiServiceException.java +++ b/back/src/main/java/com/back/global/ai/exception/AiServiceException.java @@ -1,25 +1,20 @@ package com.back.global.ai.exception; +import com.back.global.exception.ApiException; import com.back.global.exception.ErrorCode; /** * AI 서비스 관련 예외의 기본 클래스. + * ApiException을 상속받아 GlobalExceptionHandler에서 일관된 예외 처리를 보장합니다. * 모든 AI 관련 예외는 이 클래스를 상속받아 구현됩니다. */ -public class AiServiceException extends RuntimeException { - private final ErrorCode errorCode; +public class AiServiceException extends ApiException { public AiServiceException(ErrorCode errorCode) { - super(errorCode.getMessage()); - this.errorCode = errorCode; + super(errorCode); } public AiServiceException(ErrorCode errorCode, String customMessage) { - super(customMessage); - this.errorCode = errorCode; - } - - public ErrorCode getErrorCode() { - return errorCode; + super(errorCode, customMessage); } } diff --git a/back/src/main/java/com/back/global/ai/prompt/BaseScenarioPrompt.java b/back/src/main/java/com/back/global/ai/prompt/BaseScenarioPrompt.java index 0eed705..16d1b11 100644 --- a/back/src/main/java/com/back/global/ai/prompt/BaseScenarioPrompt.java +++ b/back/src/main/java/com/back/global/ai/prompt/BaseScenarioPrompt.java @@ -155,4 +155,22 @@ public static String generatePrompt(BaseLine baseLine) { .replace("{baseNodes}", baseNodesInfo.toString()) .replace("{timelineYears}", timelineYears.toString()); } + + /** + * 예상 토큰 수를 계산한다. (로깅 목적) + * 베이스 시나리오는 중간 크기의 응답을 요구한다. + * + * @param baseLine 토큰 수 계산할 베이스라인 + * @return 예상 토큰 수 + */ + public static int estimateTokens(BaseLine baseLine) { + int baseTokens = 800; // 기본 프롬프트 토큰 수 (사용자 정보 포함) + + if (baseLine != null && baseLine.getBaseNodes() != null) { + // BaseNode당 약 50토큰 (카테고리, 나이, 상황, 결정 포함) + baseTokens += baseLine.getBaseNodes().size() * 50; + } + + return baseTokens; + } } \ No newline at end of file diff --git a/back/src/main/java/com/back/global/ai/prompt/DecisionScenarioPrompt.java b/back/src/main/java/com/back/global/ai/prompt/DecisionScenarioPrompt.java index 06dec48..8f9b5e0 100644 --- a/back/src/main/java/com/back/global/ai/prompt/DecisionScenarioPrompt.java +++ b/back/src/main/java/com/back/global/ai/prompt/DecisionScenarioPrompt.java @@ -210,4 +210,23 @@ private static int getScoreByType(List sceneTypes, String typeName) { .findFirst() .orElse(50); } + + /** + * 예상 토큰 수를 계산한다. (로깅 목적) + * 결정 시나리오는 베이스 시나리오 정보와 선택 경로를 모두 포함하여 가장 큰 프롬프트이다. + * + * @param decisionLine 선택 경로 + * @param baseScenario 베이스 시나리오 + * @return 예상 토큰 수 + */ + public static int estimateTokens(DecisionLine decisionLine, Scenario baseScenario) { + int baseTokens = 1200; // 기본 프롬프트 토큰 수 (사용자 정보 + 베이스 시나리오 정보 포함) + + if (decisionLine != null && decisionLine.getDecisionNodes() != null) { + // DecisionNode당 약 80토큰 (상황 + 결정 상세 포함) + baseTokens += decisionLine.getDecisionNodes().size() * 80; + } + + return baseTokens; + } } \ No newline at end of file diff --git a/back/src/main/java/com/back/global/ai/prompt/SituationPrompt.java b/back/src/main/java/com/back/global/ai/prompt/SituationPrompt.java index 4847a2f..7ffa630 100644 --- a/back/src/main/java/com/back/global/ai/prompt/SituationPrompt.java +++ b/back/src/main/java/com/back/global/ai/prompt/SituationPrompt.java @@ -3,18 +3,21 @@ import com.back.domain.node.entity.DecisionNode; import com.back.global.ai.exception.AiServiceException; import com.back.global.exception.ErrorCode; +import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import lombok.extern.slf4j.Slf4j; + import java.util.List; /** * 상황 생성을 위한 프롬프트 템플릿 * Trees 도메인에서 사용되며, 이전 선택들의 나비효과로 새로운 상황을 생성한다. + * ObjectMapper는 Spring Context에서 주입받아 사용한다. */ +@Slf4j public class SituationPrompt { - private static final ObjectMapper objectMapper = new ObjectMapper(); - private static final String PROMPT_TEMPLATE = """ 당신은 인생 시뮬레이션 전문가입니다. 이전 선택들의 나비효과로 발생한 새로운 상황을 생성하세요. @@ -118,7 +121,7 @@ public static String generatePrompt(List previousNodes, String tim decision = decision.substring(0, 12) + "..."; } - choicesInfo.append(String.format("%d세 %s", node.getAgeYear(), decision)); + choicesInfo.append(node.getAgeYear()).append("세 ").append(decision); if (i < previousNodes.size() - 1) { choicesInfo.append(" → "); } @@ -136,14 +139,13 @@ public static String generatePrompt(List previousNodes, String tim DecisionNode node = previousNodes.get(i); int actualYear = birthYear + node.getAgeYear() - 1; // 실제 연도 계산 - choicesInfo.append(String.format( - "%d단계 (%d세, %d년):\n상황: %s\n선택: %s\n\n", - i + 1, - node.getAgeYear(), - actualYear, - node.getSituation() != null ? node.getSituation() : "상황 정보 없음", - node.getDecision() != null ? node.getDecision() : "선택 정보 없음" - )); + choicesInfo.append(i + 1).append("단계 (") + .append(node.getAgeYear()).append("세, ") + .append(actualYear).append("년):\n상황: ") + .append(node.getSituation() != null ? node.getSituation() : "상황 정보 없음") + .append("\n선택: ") + .append(node.getDecision() != null ? node.getDecision() : "선택 정보 없음") + .append("\n\n"); } String timeContextValue; @@ -247,10 +249,12 @@ public static int estimateTokens(List previousNodes) { * JSON 형식 응답에서 situation 필드를 파싱한다. * * @param aiResponse AI의 전체 응답 + * @param objectMapper JSON 파싱용 ObjectMapper (Spring Context에서 주입) * @return 상황 텍스트만 추출된 결과 */ - public static String extractSituation(String aiResponse) { + public static String extractSituation(String aiResponse, ObjectMapper objectMapper) { if (aiResponse == null || aiResponse.trim().isEmpty()) { + log.warn("AI response is null or empty"); return "상황 생성에 실패했습니다."; } @@ -265,8 +269,9 @@ public static String extractSituation(String aiResponse) { return situation; } } - } catch (Exception e) { + } catch (JsonProcessingException e) { // JSON 파싱 실패 시 기존 방식으로 fallback + log.warn("Failed to parse AI response as JSON, falling back to text extraction: {}", e.getMessage()); } // 기존 방식으로 fallback @@ -285,9 +290,10 @@ public static String extractSituation(String aiResponse) { * JSON 형식 응답에서 recommendedOption 필드를 파싱한다. * * @param aiResponse AI의 전체 응답 + * @param objectMapper JSON 파싱용 ObjectMapper (Spring Context에서 주입) * @return 추천 선택지 텍스트, 추출 실패 시 null */ - public static String extractRecommendedOption(String aiResponse) { + public static String extractRecommendedOption(String aiResponse, ObjectMapper objectMapper) { if (aiResponse == null || aiResponse.trim().isEmpty()) { return null; } @@ -301,8 +307,9 @@ public static String extractRecommendedOption(String aiResponse) { String option = rootNode.get("recommendedOption").asText(); return (option != null && !option.trim().isEmpty()) ? option : null; } - } catch (Exception e) { + } catch (JsonProcessingException e) { // JSON 파싱 실패 시 null 반환 + log.warn("Failed to parse AI response for recommendedOption: {}", e.getMessage()); return null; } diff --git a/back/src/main/java/com/back/global/ai/service/AiServiceImpl.java b/back/src/main/java/com/back/global/ai/service/AiServiceImpl.java index 4acf6d0..4c17f82 100644 --- a/back/src/main/java/com/back/global/ai/service/AiServiceImpl.java +++ b/back/src/main/java/com/back/global/ai/service/AiServiceImpl.java @@ -6,6 +6,7 @@ import com.back.domain.scenario.entity.Scenario; import com.back.domain.scenario.entity.SceneType; import com.back.domain.scenario.repository.SceneTypeRepository; +import com.back.global.ai.client.image.ImageAiClient; import com.back.global.ai.client.text.TextAiClient; import com.back.global.ai.config.BaseScenarioAiProperties; import com.back.global.ai.config.DecisionScenarioAiProperties; @@ -14,6 +15,7 @@ import com.back.global.ai.dto.result.BaseScenarioResult; import com.back.global.ai.dto.result.DecisionScenarioResult; import com.back.global.ai.exception.AiParsingException; +import com.back.global.ai.exception.AiServiceException; import com.back.global.ai.prompt.BaseScenarioPrompt; import com.back.global.ai.prompt.DecisionScenarioPrompt; import com.back.global.ai.prompt.SituationPrompt; @@ -42,12 +44,14 @@ public class AiServiceImpl implements AiService { private final SituationAiProperties situationAiProperties; private final BaseScenarioAiProperties baseScenarioAiProperties; private final DecisionScenarioAiProperties decisionScenarioAiProperties; + private final ImageAiClient imageAiClient; + private final com.back.global.storage.StorageService storageService; @Override public CompletableFuture generateBaseScenario(BaseLine baseLine) { if (baseLine == null) { return CompletableFuture.failedFuture( - new AiParsingException("BaseLine cannot be null")); + new AiServiceException(com.back.global.exception.ErrorCode.AI_INVALID_REQUEST, "BaseLine cannot be null")); } log.info("Generating base scenario for BaseLine ID: {}", baseLine.getId()); @@ -70,11 +74,24 @@ public CompletableFuture generateBaseScenario(BaseLine baseL baseLine.getId(), e.getMessage(), e); throw new AiParsingException("Failed to parse BaseScenario response: " + e.getMessage()); } + }) + .exceptionally(e -> { + log.error("AI generation failed for BaseLine ID: {}, error: {}", + baseLine.getId(), e.getMessage(), e); + // AiParsingException은 그대로 전파, 나머지만 AiServiceException으로 감쌈 + Throwable cause = e.getCause() != null ? e.getCause() : e; + if (cause instanceof AiParsingException) { + throw (AiParsingException) cause; + } + throw new AiServiceException(com.back.global.exception.ErrorCode.AI_GENERATION_FAILED, + "Failed to generate base scenario: " + e.getMessage()); }); } catch (Exception e) { log.error("Error in generateBaseScenario for BaseLine ID: {}, error: {}", baseLine.getId(), e.getMessage(), e); - return CompletableFuture.failedFuture(e); + return CompletableFuture.failedFuture( + new AiServiceException(com.back.global.exception.ErrorCode.AI_GENERATION_FAILED, + "Unexpected error in base scenario generation: " + e.getMessage())); } } @@ -82,11 +99,11 @@ public CompletableFuture generateBaseScenario(BaseLine baseL public CompletableFuture generateDecisionScenario(DecisionLine decisionLine, Scenario baseScenario) { if (decisionLine == null) { return CompletableFuture.failedFuture( - new AiParsingException("DecisionLine cannot be null")); + new AiServiceException(com.back.global.exception.ErrorCode.AI_INVALID_REQUEST, "DecisionLine cannot be null")); } if (baseScenario == null) { return CompletableFuture.failedFuture( - new AiParsingException("BaseScenario cannot be null")); + new AiServiceException(com.back.global.exception.ErrorCode.AI_INVALID_REQUEST, "BaseScenario cannot be null")); } log.info("Generating Decision scenario for DecisionLine ID: {}", decisionLine.getId()); @@ -110,11 +127,24 @@ public CompletableFuture generateDecisionScenario(Decisi decisionLine.getId(), e.getMessage(), e); throw new AiParsingException("Failed to parse DecisionScenario response: " + e.getMessage()); } + }) + .exceptionally(e -> { + log.error("AI generation failed for DecisionLine ID: {}, error: {}", + decisionLine.getId(), e.getMessage(), e); + // AiParsingException은 그대로 전파, 나머지만 AiServiceException으로 감쌈 + Throwable cause = e.getCause() != null ? e.getCause() : e; + if (cause instanceof AiParsingException) { + throw (AiParsingException) cause; + } + throw new AiServiceException(com.back.global.exception.ErrorCode.AI_GENERATION_FAILED, + "Failed to generate decision scenario: " + e.getMessage()); }); } catch (Exception e) { log.error("Error in generateDecisionScenario for DecisionLine ID: {}, error: {}", decisionLine.getId(), e.getMessage(), e); - return CompletableFuture.failedFuture(e); + return CompletableFuture.failedFuture( + new AiServiceException(com.back.global.exception.ErrorCode.AI_GENERATION_FAILED, + "Unexpected error in decision scenario generation: " + e.getMessage())); } } @@ -137,13 +167,15 @@ public CompletableFuture generateSituation(List previousNo // Input validation if (previousNodes == null || previousNodes.isEmpty()) { return CompletableFuture.failedFuture( - new AiParsingException("Previous nodes cannot be null or empty for situation generation")); + new AiServiceException(com.back.global.exception.ErrorCode.AI_INVALID_REQUEST, + "Previous nodes cannot be null or empty for situation generation")); } // Validate node data quality if (!SituationPrompt.validatePreviousNodes(previousNodes)) { return CompletableFuture.failedFuture( - new AiParsingException("Previous nodes contain invalid data (missing situation or decision)")); + new AiServiceException(com.back.global.exception.ErrorCode.AI_INVALID_REQUEST, + "Previous nodes contain invalid data (missing situation or decision)")); } try { @@ -161,10 +193,10 @@ public CompletableFuture generateSituation(List previousNo aiResponse.length()); // Step 3: JSON에서 상황 텍스트만 추출 - String situation = SituationPrompt.extractSituation(aiResponse); + String situation = SituationPrompt.extractSituation(aiResponse, objectMapper); // 추천 선택지는 로깅만 (Trees 도메인에서 별도 처리) - String recommendedOption = SituationPrompt.extractRecommendedOption(aiResponse); + String recommendedOption = SituationPrompt.extractRecommendedOption(aiResponse, objectMapper); if (recommendedOption != null) { log.debug("AI also provided recommended option: {}", recommendedOption); } @@ -175,11 +207,63 @@ public CompletableFuture generateSituation(List previousNo e.getMessage(), e); throw new AiParsingException("Failed to parse situation response: " + e.getMessage()); } + }) + .exceptionally(e -> { + log.error("AI generation failed for situation, error: {}", + e.getMessage(), e); + // AiParsingException은 그대로 전파, 나머지만 AiServiceException으로 감쌈 + Throwable cause = e.getCause() != null ? e.getCause() : e; + if (cause instanceof AiParsingException) { + throw (AiParsingException) cause; + } + throw new AiServiceException(com.back.global.exception.ErrorCode.AI_GENERATION_FAILED, + "Failed to generate situation: " + e.getMessage()); }); } catch (Exception e) { log.error("Error in generateSituation for {} nodes, error: {}", previousNodes.size(), e.getMessage(), e); - return CompletableFuture.failedFuture(e); + return CompletableFuture.failedFuture( + new AiServiceException(com.back.global.exception.ErrorCode.AI_GENERATION_FAILED, + "Unexpected error in situation generation: " + e.getMessage())); + } + } + + @Override + public CompletableFuture generateImage(String prompt) { + if (!imageAiClient.isEnabled()) { + log.warn("Image AI is disabled. Returning placeholder."); + return CompletableFuture.completedFuture("placeholder-image-url"); + } + + if (prompt == null || prompt.trim().isEmpty()) { + log.warn("Image prompt is empty. Returning placeholder."); + return CompletableFuture.completedFuture("placeholder-image-url"); + } + + log.info("Generating image with prompt: {} (Storage: {})", prompt, storageService.getStorageType()); + + try { + // Stable Diffusion API 호출 → Base64 이미지 생성 + return imageAiClient.generateImage(prompt) + .thenCompose(base64Data -> { + // Base64 데이터를 스토리지에 업로드 → URL 반환 + if (base64Data == null || base64Data.isEmpty() || "placeholder-image-url".equals(base64Data)) { + log.warn("Empty or placeholder Base64 data received from image AI"); + return CompletableFuture.completedFuture("placeholder-image-url"); + } + + log.info("Image generated successfully (Base64 length: {}), uploading to {} storage...", + base64Data.length(), storageService.getStorageType()); + + return storageService.uploadBase64Image(base64Data); + }) + .exceptionally(e -> { + log.warn("Failed to generate or upload image, returning placeholder: {}", e.getMessage()); + return "placeholder-image-url"; + }); + } catch (Exception e) { + log.warn("Error in generateImage, returning placeholder: {}", e.getMessage()); + return CompletableFuture.completedFuture("placeholder-image-url"); } } } diff --git a/back/src/main/java/com/back/global/ai/vector/AIVectorServiceImpl.java b/back/src/main/java/com/back/global/ai/vector/AIVectorServiceImpl.java index da0db92..1beab61 100644 --- a/back/src/main/java/com/back/global/ai/vector/AIVectorServiceImpl.java +++ b/back/src/main/java/com/back/global/ai/vector/AIVectorServiceImpl.java @@ -11,6 +11,7 @@ import com.back.global.ai.config.SituationAiProperties; import com.back.global.ai.dto.AiRequest; import com.back.global.ai.prompt.SituationPrompt; +import com.fasterxml.jackson.databind.ObjectMapper; import lombok.RequiredArgsConstructor; import org.springframework.context.annotation.Profile; import org.springframework.stereotype.Service; @@ -26,6 +27,7 @@ public class AIVectorServiceImpl implements AIVectorService { private final TextAiClient textAiClient; private final AIVectorServiceSupportDomain support; private final SituationAiProperties props; + private final ObjectMapper objectMapper; // 프로퍼티 바인딩 필드 private int topK = 5; @@ -56,8 +58,8 @@ public AiNextHint generateNextHint(Long userId, Long decisionLineId, List uploadBase64Image(String base64Data) { + return CompletableFuture.supplyAsync(() -> { + try { + // Input validation + if (base64Data == null || base64Data.isEmpty()) { + throw new ApiException(ErrorCode.STORAGE_INVALID_FILE, "Base64 data cannot be null or empty"); + } + + // Base64 디코딩 + byte[] imageBytes = Base64.getDecoder().decode(base64Data); + log.debug("Decoded Base64 image, size: {} bytes", imageBytes.length); + + Path uploadDir = Paths.get(imageAiConfig.getLocalStoragePath()); + if (!Files.exists(uploadDir)) { + Files.createDirectories(uploadDir); + log.info("Created upload directory: {}", uploadDir.toAbsolutePath()); + } + + String fileName = "scenario-" + UUID.randomUUID() + ".jpeg"; + Path filePath = uploadDir.resolve(fileName); + + Files.write(filePath, imageBytes); + + String localUrl = imageAiConfig.getLocalBaseUrl() + "/" + fileName; + + log.info("Image saved locally: {}", filePath.toAbsolutePath()); + log.info("Local URL: {}", localUrl); + + return localUrl; + + } catch (IllegalArgumentException e) { + log.error("Invalid Base64 data: {}", e.getMessage()); + throw new ApiException(ErrorCode.STORAGE_INVALID_FILE, "Invalid Base64 image data: " + e.getMessage()); + } catch (IOException e) { + log.error("Failed to save image locally: {}", e.getMessage(), e); + throw new ApiException(ErrorCode.LOCAL_STORAGE_IO_ERROR, "Failed to save image to local storage: " + e.getMessage()); + } catch (Exception e) { + log.error("Unexpected error during local image upload: {}", e.getMessage(), e); + throw new ApiException(ErrorCode.STORAGE_UPLOAD_FAILED, "Failed to upload image locally: " + e.getMessage()); + } + }); + } + + @Override + public CompletableFuture deleteImage(String imageUrl) { + return CompletableFuture.runAsync(() -> { + try { + // URL에서 파일명 추출 + String fileName = extractFileNameFromUrl(imageUrl); + + // 파일 경로 생성 + Path filePath = Paths.get(imageAiConfig.getLocalStoragePath()).resolve(fileName); + + // 파일 삭제 + if (Files.exists(filePath)) { + Files.delete(filePath); + log.info("Image deleted locally: {}", filePath.toAbsolutePath()); + } else { + log.warn("Image file not found for deletion: {}", filePath.toAbsolutePath()); + } + + } catch (IOException e) { + log.error("Failed to delete local image: {}", e.getMessage(), e); + throw new ApiException(ErrorCode.LOCAL_STORAGE_IO_ERROR, "Failed to delete image from local storage: " + e.getMessage()); + } catch (Exception e) { + log.error("Unexpected error during local image deletion: {}", e.getMessage(), e); + throw new ApiException(ErrorCode.STORAGE_DELETE_FAILED, "Failed to delete image locally: " + e.getMessage()); + } + }); + } + + @Override + public String getStorageType() { + return "local"; + } + + /** + * URL에서 파일명 추출 + * 예: http://localhost:8080/images/scenario-uuid.jpeg → scenario-uuid.jpeg + */ + private String extractFileNameFromUrl(String imageUrl) { + if (imageUrl == null || imageUrl.isEmpty()) { + throw new ApiException(ErrorCode.STORAGE_INVALID_FILE, "Image URL cannot be null or empty"); + } + + String[] parts = imageUrl.split("/"); + return parts[parts.length - 1]; + } +} diff --git a/back/src/main/java/com/back/global/storage/S3StorageService.java b/back/src/main/java/com/back/global/storage/S3StorageService.java new file mode 100644 index 0000000..81e4655 --- /dev/null +++ b/back/src/main/java/com/back/global/storage/S3StorageService.java @@ -0,0 +1,123 @@ +package com.back.global.storage; + +import com.back.global.ai.config.ImageAiConfig; +import com.back.global.exception.ApiException; +import com.back.global.exception.ErrorCode; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.stereotype.Service; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.S3Exception; + +import java.util.Base64; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; + +/** + * AWS S3 스토리지 서비스 구현체 (프로덕션용) + * storageType="s3"일 때만 활성화됩니다. + * + * CompletableFuture.supplyAsync() 사용하여 메모리 효율적 처리 + * S3Client는 기본 Connection Pool 사용 (리소스 최소화) + * 파일명: UUID 기반으로 충돌 방지 + */ +@Slf4j +@Service +@RequiredArgsConstructor +@ConditionalOnProperty(prefix = "ai.image", name = "storage-type", havingValue = "s3") +public class S3StorageService implements StorageService { + + private final S3Client s3Client; + private final ImageAiConfig imageAiConfig; + + @Override + public CompletableFuture uploadBase64Image(String base64Data) { + return CompletableFuture.supplyAsync(() -> { + try { + // Input validation + if (base64Data == null || base64Data.isEmpty()) { + throw new ApiException(ErrorCode.STORAGE_INVALID_FILE, "Base64 data cannot be null or empty"); + } + + // Base64 디코딩 + byte[] imageBytes = Base64.getDecoder().decode(base64Data); + log.debug("Decoded Base64 image, size: {} bytes", imageBytes.length); + + String fileName = "scenario-" + UUID.randomUUID() + ".jpeg"; + + PutObjectRequest putRequest = PutObjectRequest.builder() + .bucket(imageAiConfig.getS3BucketName()) + .key(fileName) + .contentType("image/jpeg") + .build(); + + // S3 업로드 + s3Client.putObject(putRequest, RequestBody.fromBytes(imageBytes)); + + String s3Url = String.format( + "https://%s.s3.%s.amazonaws.com/%s", + imageAiConfig.getS3BucketName(), + imageAiConfig.getS3Region(), + fileName + ); + + log.info("Image uploaded to S3 successfully: {}", s3Url); + return s3Url; + + } catch (IllegalArgumentException e) { + log.error("Invalid Base64 data: {}", e.getMessage()); + throw new ApiException(ErrorCode.STORAGE_INVALID_FILE, "Invalid Base64 image data: " + e.getMessage()); + } catch (S3Exception e) { + log.error("S3 service error: {}", e.getMessage(), e); + throw new ApiException(ErrorCode.S3_CONNECTION_FAILED, "S3 upload failed: " + e.awsErrorDetails().errorMessage()); + } catch (Exception e) { + log.error("S3 upload failed: {}", e.getMessage(), e); + throw new ApiException(ErrorCode.STORAGE_UPLOAD_FAILED, "Failed to upload image to S3: " + e.getMessage()); + } + }); + } + + @Override + public CompletableFuture deleteImage(String imageUrl) { + return CompletableFuture.runAsync(() -> { + try { + String fileName = extractFileNameFromUrl(imageUrl); + + // S3 삭제 요청 + DeleteObjectRequest deleteRequest = DeleteObjectRequest.builder() + .bucket(imageAiConfig.getS3BucketName()) + .key(fileName) + .build(); + + s3Client.deleteObject(deleteRequest); + log.info("Image deleted from S3: {}", fileName); + + } catch (S3Exception e) { + log.error("S3 service error during deletion: {}", e.getMessage(), e); + throw new ApiException(ErrorCode.S3_CONNECTION_FAILED, "S3 delete failed: " + e.awsErrorDetails().errorMessage()); + } catch (Exception e) { + log.error("Failed to delete image from S3: {}", e.getMessage(), e); + throw new ApiException(ErrorCode.STORAGE_DELETE_FAILED, "Failed to delete image from S3: " + e.getMessage()); + } + }); + } + + @Override + public String getStorageType() { + return "s3"; + } + + //S3 URL에서 파일명 추출, 예:https://bucket.s3.region.amazonaws.com/scenario-uuid.jpeg → scenario-uuid.jpeg + private String extractFileNameFromUrl(String imageUrl) { + if (imageUrl == null || imageUrl.isEmpty()) { + throw new ApiException(ErrorCode.STORAGE_INVALID_FILE, "Image URL cannot be null or empty"); + } + + String[] parts = imageUrl.split("/"); + return parts[parts.length - 1]; + } +} diff --git a/back/src/main/java/com/back/global/storage/StorageService.java b/back/src/main/java/com/back/global/storage/StorageService.java new file mode 100644 index 0000000..b90d5e4 --- /dev/null +++ b/back/src/main/java/com/back/global/storage/StorageService.java @@ -0,0 +1,25 @@ +package com.back.global.storage; + +import java.util.concurrent.CompletableFuture; + +/** + * 스토리지 서비스 인터페이스 + * 이미지 업로드/삭제 기능을 추상화합니다. + * + * 구현체: + * - S3StorageService: AWS S3에 업로드 (프로덕션) + * - LocalStorageService: 로컬 파일 시스템에 저장 (개발) + * + * CompletableFuture를 사용하여 AiService와 동일한 비동기 패턴 유지 + */ +public interface StorageService { + + // @return 업로드된 이미지 URL (S3 URL 또는 로컬 URL) + CompletableFuture uploadBase64Image(String base64Data); + + // @return 삭제 완료 Future + CompletableFuture deleteImage(String imageUrl); + + // @return "s3" 또는 "local" + String getStorageType(); +} diff --git a/back/src/main/resources/application-prod.yml b/back/src/main/resources/application-prod.yml index 07ddb37..9d5902a 100644 --- a/back/src/main/resources/application-prod.yml +++ b/back/src/main/resources/application-prod.yml @@ -73,4 +73,12 @@ custom: site: baseDomain: "${custom.prod.baseDomain}" frontUrl: "${custom.prod.frontUrl}" - backUrl: "${custom.prod.backUrl}" \ No newline at end of file + backUrl: "${custom.prod.backUrl}" + +# AI Services Configuration (Production) +ai: + image: + enabled: false # 프로덕션 환경에서는 활성화 + storage-type: s3 # S3 스토리지 사용 + s3-bucket-name: ${AWS_S3_BUCKET_NAME} + s3-region: ${AWS_S3_REGION} \ No newline at end of file diff --git a/back/src/main/resources/application.yml b/back/src/main/resources/application.yml index 73266b8..7b027b4 100644 --- a/back/src/main/resources/application.yml +++ b/back/src/main/resources/application.yml @@ -58,9 +58,10 @@ spring: task: execution: pool: - core-size: 2 # AWS t2.small (1 vCPU) 기준 기본 쓰레드 - max-size: 4 # 최대 쓰레드 (버스트 대응) - queue-capacity: 50 # 대기 큐 크기 + core-size: 2 # AWS Small 티어 (1-2 vCPU) 기준 기본 스레드 + max-size: 4 # 최대 스레드 (버스트 대응) + queue-capacity: 100 # 대기 큐 크기 (메모리 2GB 고려, 버퍼링 증가) + await-termination-seconds: 60 # 종료 대기 시간 (초) thread-name-prefix: "async-ai-" logging: level: @@ -86,12 +87,20 @@ ai: base-url: https://generativelanguage.googleapis.com timeout-seconds: 30 max-retries: 3 + retry-delay-seconds: 2 # 재시도 간격 (초) image: - enabled: false - provider: placeholder + enabled: true + provider: stable-diffusion + api-key: ${STABILITY_API_KEY} + base-url: https://api.stability.ai timeout-seconds: 60 max-retries: 3 - # TODO: 추후 이미지 AI 설정 추가 예정 + retry-delay-seconds: 2 # 재시도 간격 (초) + # Storage 설정 (local 또는 s3) + storage-type: local # 개발 환경: local 사용 + # 로컬 스토리지 설정 + local-storage-path: ./uploads/images + local-base-url: http://localhost:8080/images situation: topK: 3 # 성능 최적화 (5 → 3) maxOutputTokens: 128 # 성능 최적화 완료 diff --git a/back/src/main/resources/db/migration/V4__add_indexes_const_on_scenarios.sql b/back/src/main/resources/db/migration/V4__add_indexes_const_on_scenarios.sql new file mode 100644 index 0000000..bc989d8 --- /dev/null +++ b/back/src/main/resources/db/migration/V4__add_indexes_const_on_scenarios.sql @@ -0,0 +1,11 @@ +-- ======================================== +-- Scenario 테이블 인덱스 및 제약조건 정리 +-- ======================================== + +-- 기존 V1의 constraint 이름 변경 (uq_scenario_decision_line -> uk_scenario_decision_line) +ALTER TABLE scenarios DROP CONSTRAINT IF EXISTS uq_scenario_decision_line; +ALTER TABLE scenarios ADD CONSTRAINT uk_scenario_decision_line UNIQUE (decision_line_id); + +-- 인덱스 추가 +CREATE INDEX IF NOT EXISTS idx_scenario_user_status ON scenarios (user_id, status, created_date); +CREATE INDEX IF NOT EXISTS idx_scenario_baseline ON scenarios (base_line_id); \ No newline at end of file diff --git a/back/src/test/java/com/back/domain/node/controller/AiOnceDelegateTestConfig.java b/back/src/test/java/com/back/domain/node/controller/AiOnceDelegateTestConfig.java index c97f4d7..384764a 100644 --- a/back/src/test/java/com/back/domain/node/controller/AiOnceDelegateTestConfig.java +++ b/back/src/test/java/com/back/domain/node/controller/AiOnceDelegateTestConfig.java @@ -9,6 +9,7 @@ import com.back.global.ai.vector.AIVectorService; import com.back.global.ai.vector.AIVectorServiceImpl; import com.back.global.ai.vector.AIVectorServiceSupportDomain; +import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.boot.test.context.TestConfiguration; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Primary; @@ -33,9 +34,10 @@ public AIVectorService aiOnceDelegate( TextAiClient textAiClient, AIVectorServiceSupportDomain support, SituationAiProperties props, + ObjectMapper objectMapper, AiCallBudget budget ) { - AIVectorService real = new AIVectorServiceImpl(textAiClient, support, props); + AIVectorService real = new AIVectorServiceImpl(textAiClient, support, props, objectMapper); AIVectorService stub = (u, d, nodes) -> new AIVectorService.AiNextHint("테스트-상황(한 문장)", "테스트-추천"); return (userId, lineId, orderedNodes) -> budget.consume() ? real.generateNextHint(userId, lineId, orderedNodes) diff --git a/back/src/test/java/com/back/domain/scenario/controller/ScenarioControllerTest.java b/back/src/test/java/com/back/domain/scenario/controller/ScenarioControllerTest.java index 0c2a8f7..eccdcc9 100644 --- a/back/src/test/java/com/back/domain/scenario/controller/ScenarioControllerTest.java +++ b/back/src/test/java/com/back/domain/scenario/controller/ScenarioControllerTest.java @@ -6,14 +6,13 @@ import com.back.domain.scenario.entity.ScenarioStatus; import com.back.domain.scenario.entity.Type; import com.back.domain.scenario.service.ScenarioService; +import com.back.domain.user.entity.User; import com.back.global.common.PageResponse; import com.back.global.exception.ApiException; import com.back.global.exception.ErrorCode; +import com.back.global.security.CustomUserDetails; import com.fasterxml.jackson.databind.ObjectMapper; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Nested; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.*; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc; import org.springframework.boot.test.context.SpringBootTest; @@ -32,6 +31,8 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.multipart; import static org.springframework.test.web.servlet.result.MockMvcResultHandlers.print; @@ -40,11 +41,11 @@ /** * ScenarioController 통합 테스트. - * 세션 기반 인증이 구현되었지만 테스트에서는 필터를 비활성화하고 - * Service를 모킹하여 테스트합니다. + * 세션 기반 인증을 활성화하고 MockMvc의 .with(user())를 통해 + * 인증된 사용자로 API를 테스트합니다. Service는 모킹하여 테스트합니다. */ @SpringBootTest -@AutoConfigureMockMvc(addFilters = false) // 인증 필터 비활성화로 테스트 단순화 +@AutoConfigureMockMvc // Security 필터 활성화하여 @AuthenticationPrincipal 정상 동작 @ActiveProfiles("test") @Transactional @TestInstance(TestInstance.Lifecycle.PER_CLASS) @@ -60,6 +61,38 @@ class ScenarioControllerTest { @MockBean private ScenarioService scenarioService; + // Mock 사용자 생성 + private CustomUserDetails mockUserDetails; + + @BeforeEach + void setUp() { + // Mock User 생성 (ID = 1L) + User mockUser = User.builder() + .email("test@example.com") + .username("테스트사용자") + .nickname("테스트닉네임") + .birthdayAt(LocalDateTime.of(1990, 1, 1, 0, 0)) + .role(com.back.domain.user.entity.Role.USER) + .build(); + + // BaseEntity의 id는 Reflection으로 설정 + try { + java.lang.reflect.Field idField = mockUser.getClass().getSuperclass().getDeclaredField("id"); + idField.setAccessible(true); + idField.set(mockUser, 1L); + + // Reflection이 제대로 작동했는지 검증 + Long verifyId = (Long) idField.get(mockUser); + if (verifyId == null || !verifyId.equals(1L)) { + throw new RuntimeException("Failed to set user ID via Reflection, got: " + verifyId); + } + } catch (Exception e) { + throw new RuntimeException("Failed to set mock user ID", e); + } + + mockUserDetails = new CustomUserDetails(mockUser); + } + @Nested @DisplayName("시나리오 생성") class CreateScenario { @@ -111,7 +144,9 @@ class CreateScenario { mockMvc.perform(multipart("/api/v1/scenarios") .file(scenarioPart) .file(lastDecisionPart) - .with(req -> { req.setMethod("POST"); return req; })) + .with(req -> { req.setMethod("POST"); return req; }) + .with(csrf()) + .with(user(mockUserDetails))) .andDo(print()) .andExpect(status().isCreated()) .andExpect(jsonPath("$.scenarioId").value(1001)) @@ -132,7 +167,9 @@ class CreateScenario { // When & Then mockMvc.perform(multipart("/api/v1/scenarios") .file(scenarioPart) - .with(req -> { req.setMethod("POST"); return req; })) + .with(req -> { req.setMethod("POST"); return req; }) + .with(csrf()) + .with(user(mockUserDetails))) .andDo(print()) .andExpect(status().isBadRequest()); } @@ -176,7 +213,9 @@ class CreateScenario { mockMvc.perform(multipart("/api/v1/scenarios") .file(scenarioPart) .file(lastDecisionPart) - .with(req -> { req.setMethod("POST"); return req; })) + .with(req -> { req.setMethod("POST"); return req; }) + .with(csrf()) + .with(user(mockUserDetails))) .andDo(print()) .andExpect(status().isNotFound()); } @@ -201,7 +240,8 @@ class GetScenarioStatus { .willReturn(mockResponse); // When & Then - mockMvc.perform(get("/api/v1/scenarios/{scenarioId}/status", scenarioId)) + mockMvc.perform(get("/api/v1/scenarios/{scenarioId}/status", scenarioId) + .with(user(mockUserDetails))) .andDo(print()) .andExpect(status().isOk()) .andExpect(jsonPath("$.scenarioId").value(scenarioId)) @@ -219,7 +259,8 @@ class GetScenarioStatus { .willThrow(new ApiException(ErrorCode.SCENARIO_NOT_FOUND)); // When & Then - mockMvc.perform(get("/api/v1/scenarios/{scenarioId}/status", scenarioId)) + mockMvc.perform(get("/api/v1/scenarios/{scenarioId}/status", scenarioId) + .with(user(mockUserDetails))) .andDo(print()) .andExpect(status().isNotFound()); } @@ -259,7 +300,8 @@ class GetScenarioDetail { .willReturn(mockResponse); // When & Then - mockMvc.perform(get("/api/v1/scenarios/info/{scenarioId}", scenarioId)) + mockMvc.perform(get("/api/v1/scenarios/info/{scenarioId}", scenarioId) + .with(user(mockUserDetails))) .andDo(print()) .andExpect(status().isOk()) .andExpect(jsonPath("$.scenarioId").value(scenarioId)) @@ -292,7 +334,8 @@ class GetScenarioTimeline { .willReturn(mockResponse); // When & Then - mockMvc.perform(get("/api/v1/scenarios/{scenarioId}/timeline", scenarioId)) + mockMvc.perform(get("/api/v1/scenarios/{scenarioId}/timeline", scenarioId) + .with(user(mockUserDetails))) .andDo(print()) .andExpect(status().isOk()) .andExpect(jsonPath("$.scenarioId").value(scenarioId)) @@ -334,7 +377,8 @@ class GetBaselines { // When & Then mockMvc.perform(get("/api/v1/scenarios/baselines") .param("page", "0") - .param("size", "10")) + .param("size", "10") + .with(user(mockUserDetails))) .andDo(print()) .andExpect(status().isOk()) .andExpect(jsonPath("$.items").isArray()) @@ -377,7 +421,8 @@ class CompareScenarios { .willReturn(mockResponse); // When & Then - mockMvc.perform(get("/api/v1/scenarios/compare/{baseId}/{compareId}", baseId, compareId)) + mockMvc.perform(get("/api/v1/scenarios/compare/{baseId}/{compareId}", baseId, compareId) + .with(user(mockUserDetails))) .andDo(print()) .andExpect(status().isOk()) .andExpect(jsonPath("$.baseScenarioId").value(baseId)) diff --git a/back/src/test/java/com/back/global/ai/client/image/StableDiffusionImageClientTest.java b/back/src/test/java/com/back/global/ai/client/image/StableDiffusionImageClientTest.java new file mode 100644 index 0000000..17ca39c --- /dev/null +++ b/back/src/test/java/com/back/global/ai/client/image/StableDiffusionImageClientTest.java @@ -0,0 +1,312 @@ +package com.back.global.ai.client.image; + +import com.back.global.ai.config.ImageAiConfig; +import com.back.global.ai.exception.AiServiceException; +import com.back.global.exception.ErrorCode; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Mono; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.BDDMockito.*; + +/** + * StableDiffusionImageClient 단위 테스트. + * Stable Diffusion API 연동 로직을 검증합니다. + */ +@ExtendWith(MockitoExtension.class) +@DisplayName("StableDiffusionImageClient 단위 테스트") +class StableDiffusionImageClientTest { + + @Mock + private ImageAiConfig imageAiConfig; + + @Mock + private WebClient webClient; + + @Mock + private WebClient.RequestBodyUriSpec requestBodyUriSpec; + + @Mock + private WebClient.RequestBodySpec requestBodySpec; + + @Mock + private WebClient.RequestHeadersSpec requestHeadersSpec; + + @Mock + private WebClient.ResponseSpec responseSpec; + + private StableDiffusionImageClient client; + + @BeforeEach + void setUp() { + lenient().when(imageAiConfig.getBaseUrl()).thenReturn("https://api.stability.ai"); + lenient().when(imageAiConfig.getApiKey()).thenReturn("test-api-key"); + lenient().when(imageAiConfig.getTimeoutSeconds()).thenReturn(60); + lenient().when(imageAiConfig.getMaxRetries()).thenReturn(3); + lenient().when(imageAiConfig.isEnabled()).thenReturn(true); + + client = new StableDiffusionImageClient(imageAiConfig, webClient); + } + + @Nested + @DisplayName("이미지 생성") + class GenerateImageTests { + + @Test + @DisplayName("성공 - 기본 프롬프트로 이미지 생성") + void generateImage_성공_기본_프롬프트() throws ExecutionException, InterruptedException { + // Given + String prompt = "A beautiful landscape"; + String mockResponse = """ + { + "artifacts": [{ + "base64": "mock-base64-data", + "finishReason": "SUCCESS" + }] + } + """; + + // WebClient Mock 체인 설정 + given(webClient.post()).willReturn(requestBodyUriSpec); + given(requestBodyUriSpec.uri(anyString())).willReturn(requestBodySpec); + given(requestBodySpec.header(anyString(), anyString())).willReturn(requestBodySpec); + given(requestBodySpec.body(any())).willReturn(requestHeadersSpec); + given(requestHeadersSpec.retrieve()).willReturn(responseSpec); + given(responseSpec.bodyToMono(String.class)).willReturn(Mono.just(mockResponse)); + + // When + CompletableFuture future = client.generateImage(prompt); + String result = future.get(); + + // Then + assertThat(result).isNotNull(); + assertThat(result).isEqualTo("mock-base64-data"); + + verify(webClient, times(1)).post(); + } + + @Test + @DisplayName("성공 - 옵션과 함께 이미지 생성") + void generateImage_성공_옵션_포함() throws ExecutionException, InterruptedException { + // Given + String prompt = "A beautiful landscape"; + Map options = Map.of( + "aspect_ratio", "16:9", + "seed", 12345, + "negative_prompt", "ugly, blurry" + ); + String mockResponse = """ + { + "artifacts": [{ + "base64": "mock-base64-data-with-options", + "finishReason": "SUCCESS" + }] + } + """; + + given(webClient.post()).willReturn(requestBodyUriSpec); + given(requestBodyUriSpec.uri(anyString())).willReturn(requestBodySpec); + given(requestBodySpec.header(anyString(), anyString())).willReturn(requestBodySpec); + given(requestBodySpec.body(any())).willReturn(requestHeadersSpec); + given(requestHeadersSpec.retrieve()).willReturn(responseSpec); + given(responseSpec.bodyToMono(String.class)).willReturn(Mono.just(mockResponse)); + + // When + CompletableFuture future = client.generateImage(prompt, options); + String result = future.get(); + + // Then + assertThat(result).isEqualTo("mock-base64-data-with-options"); + } + + @Test + @DisplayName("성공 - 프롬프트 자동 향상 (만화풍 스타일)") + void generateImage_성공_프롬프트_자동향상() { + // Given + String prompt = "A cat"; + String mockResponse = """ + { + "artifacts": [{ + "base64": "enhanced-base64-data", + "finishReason": "SUCCESS" + }] + } + """; + + given(webClient.post()).willReturn(requestBodyUriSpec); + given(requestBodyUriSpec.uri(anyString())).willReturn(requestBodySpec); + given(requestBodySpec.header(anyString(), anyString())).willReturn(requestBodySpec); + given(requestBodySpec.body(any())).willReturn(requestHeadersSpec); + given(requestHeadersSpec.retrieve()).willReturn(responseSpec); + given(responseSpec.bodyToMono(String.class)).willReturn(Mono.just(mockResponse)); + + // When + client.generateImage(prompt); + + // Then + // 프롬프트가 자동으로 향상되었는지 확인 (내부적으로 ", cartoon style, ..." 추가됨) + verify(requestBodySpec, times(2)).header(anyString(), anyString()); // Authorization + Accept + } + + @Test + @DisplayName("실패 - finishReason이 SUCCESS가 아닌 경우") + void generateImage_실패_finishReason_실패() { + // Given + String prompt = "test prompt"; + String mockResponse = """ + { + "artifacts": [{ + "base64": "some-data", + "finishReason": "CONTENT_FILTERED" + }] + } + """; + + given(webClient.post()).willReturn(requestBodyUriSpec); + given(requestBodyUriSpec.uri(anyString())).willReturn(requestBodySpec); + given(requestBodySpec.header(anyString(), anyString())).willReturn(requestBodySpec); + given(requestBodySpec.body(any())).willReturn(requestHeadersSpec); + given(requestHeadersSpec.retrieve()).willReturn(responseSpec); + given(responseSpec.bodyToMono(String.class)).willReturn(Mono.just(mockResponse)); + + // When & Then + assertThatThrownBy(() -> client.generateImage(prompt).get()) + .isInstanceOf(ExecutionException.class) + .hasCauseInstanceOf(AiServiceException.class) + .hasMessageContaining("CONTENT_FILTERED"); + } + + @Test + @DisplayName("실패 - 응답 구조가 잘못된 경우") + void generateImage_실패_잘못된_응답구조() { + // Given + String prompt = "test prompt"; + String invalidResponse = """ + { + "error": "Invalid request" + } + """; + + given(webClient.post()).willReturn(requestBodyUriSpec); + given(requestBodyUriSpec.uri(anyString())).willReturn(requestBodySpec); + given(requestBodySpec.header(anyString(), anyString())).willReturn(requestBodySpec); + given(requestBodySpec.body(any())).willReturn(requestHeadersSpec); + given(requestHeadersSpec.retrieve()).willReturn(responseSpec); + given(responseSpec.bodyToMono(String.class)).willReturn(Mono.just(invalidResponse)); + + // When & Then + assertThatThrownBy(() -> client.generateImage(prompt).get()) + .isInstanceOf(ExecutionException.class) + .hasCauseInstanceOf(AiServiceException.class) + .cause() + .hasFieldOrPropertyWithValue("errorCode", ErrorCode.AI_GENERATION_FAILED); + } + + @Test + @DisplayName("실패 - base64 필드가 없는 경우") + void generateImage_실패_base64_없음() { + // Given + String prompt = "test prompt"; + String mockResponse = """ + { + "artifacts": [{ + "finishReason": "SUCCESS" + }] + } + """; + + given(webClient.post()).willReturn(requestBodyUriSpec); + given(requestBodyUriSpec.uri(anyString())).willReturn(requestBodySpec); + given(requestBodySpec.header(anyString(), anyString())).willReturn(requestBodySpec); + given(requestBodySpec.body(any())).willReturn(requestHeadersSpec); + given(requestHeadersSpec.retrieve()).willReturn(responseSpec); + given(responseSpec.bodyToMono(String.class)).willReturn(Mono.just(mockResponse)); + + // When & Then + assertThatThrownBy(() -> client.generateImage(prompt).get()) + .isInstanceOf(ExecutionException.class) + .hasCauseInstanceOf(AiServiceException.class); + } + + @Test + @DisplayName("실패 - JSON 파싱 에러") + void generateImage_실패_JSON_파싱_에러() { + // Given + String prompt = "test prompt"; + String invalidJson = "This is not JSON"; + + given(webClient.post()).willReturn(requestBodyUriSpec); + given(requestBodyUriSpec.uri(anyString())).willReturn(requestBodySpec); + given(requestBodySpec.header(anyString(), anyString())).willReturn(requestBodySpec); + given(requestBodySpec.body(any())).willReturn(requestHeadersSpec); + given(requestHeadersSpec.retrieve()).willReturn(responseSpec); + given(responseSpec.bodyToMono(String.class)).willReturn(Mono.just(invalidJson)); + + // When & Then + assertThatThrownBy(() -> client.generateImage(prompt).get()) + .isInstanceOf(ExecutionException.class) + .hasCauseInstanceOf(AiServiceException.class) + .hasMessageContaining("Failed to parse image generation response"); + } + } + + @Nested + @DisplayName("API 설정") + class ConfigurationTests { + + @Test + @DisplayName("성공 - API 엔드포인트 확인") + void configuration_성공_API_엔드포인트() { + // Given + String prompt = "test"; + String mockResponse = """ + { + "artifacts": [{ + "base64": "data", + "finishReason": "SUCCESS" + }] + } + """; + + ArgumentCaptor uriCaptor = ArgumentCaptor.forClass(String.class); + + given(webClient.post()).willReturn(requestBodyUriSpec); + given(requestBodyUriSpec.uri(anyString())).willReturn(requestBodySpec); + given(requestBodySpec.header(anyString(), anyString())).willReturn(requestBodySpec); + given(requestBodySpec.body(any())).willReturn(requestHeadersSpec); + given(requestHeadersSpec.retrieve()).willReturn(responseSpec); + given(responseSpec.bodyToMono(String.class)).willReturn(Mono.just(mockResponse)); + + // When + client.generateImage(prompt); + + // Then + verify(requestBodyUriSpec).uri(uriCaptor.capture()); + assertThat(uriCaptor.getValue()).contains("/v2beta/stable-image/generate/sd3"); + } + + @Test + @DisplayName("성공 - isEnabled() 확인") + void configuration_성공_isEnabled() { + // When + boolean enabled = client.isEnabled(); + + // Then + assertThat(enabled).isTrue(); + verify(imageAiConfig, times(1)).isEnabled(); + } + } +} diff --git a/back/src/test/java/com/back/global/ai/service/AiServiceImplTest.java b/back/src/test/java/com/back/global/ai/service/AiServiceImplTest.java index 1745a9f..45e7e15 100644 --- a/back/src/test/java/com/back/global/ai/service/AiServiceImplTest.java +++ b/back/src/test/java/com/back/global/ai/service/AiServiceImplTest.java @@ -17,6 +17,7 @@ import com.back.global.ai.dto.result.BaseScenarioResult; import com.back.global.ai.dto.result.DecisionScenarioResult; import com.back.global.ai.exception.AiParsingException; +import com.back.global.ai.exception.AiServiceException; import com.back.global.baseentity.BaseEntity; import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.BeforeEach; @@ -55,6 +56,8 @@ class AiServiceImplTest { @Mock private ObjectMapper objectMapper; + private final ObjectMapper realObjectMapper = new ObjectMapper(); + @Mock private SceneTypeRepository sceneTypeRepository; @@ -212,7 +215,7 @@ void generateBaseScenario_Fail_NullBaseLine() { // when & then assertThatThrownBy(() -> aiService.generateBaseScenario(null).get()) .isInstanceOf(ExecutionException.class) - .hasCauseInstanceOf(AiParsingException.class) + .hasCauseInstanceOf(AiServiceException.class) .hasMessageContaining("BaseLine cannot be null"); verify(textAiClient, never()).generateText(any(AiRequest.class)); @@ -306,7 +309,7 @@ void generateDecisionScenario_Fail_NullDecisionLine() { // when & then assertThatThrownBy(() -> aiService.generateDecisionScenario(null, testScenario).get()) .isInstanceOf(ExecutionException.class) - .hasCauseInstanceOf(AiParsingException.class) + .hasCauseInstanceOf(AiServiceException.class) .hasMessageContaining("DecisionLine cannot be null"); verify(textAiClient, never()).generateText(any(AiRequest.class)); @@ -318,7 +321,7 @@ void generateDecisionScenario_Fail_NullBaseScenario() { // when & then assertThatThrownBy(() -> aiService.generateDecisionScenario(testDecisionLine, null).get()) .isInstanceOf(ExecutionException.class) - .hasCauseInstanceOf(AiParsingException.class) + .hasCauseInstanceOf(AiServiceException.class) .hasMessageContaining("BaseScenario cannot be null"); verify(textAiClient, never()).generateText(any(AiRequest.class)); @@ -383,6 +386,8 @@ class GenerateSituationTests { @DisplayName("성공 - 유효한 이전 선택들로 새로운 상황 생성") void generateSituation_Success() throws Exception { // given + setField(aiService, "objectMapper", realObjectMapper); // 실제 ObjectMapper 사용 + List previousNodes = testDecisionLine.getDecisionNodes(); String mockAiResponse = """ { @@ -406,13 +411,19 @@ void generateSituation_Success() throws Exception { verify(textAiClient, times(1)).generateText(any(AiRequest.class)); } + private void setField(Object target, String fieldName, Object value) throws Exception { + var field = AiServiceImpl.class.getDeclaredField(fieldName); + field.setAccessible(true); + field.set(target, value); + } + @Test @DisplayName("실패 - previousNodes가 null인 경우") void generateSituation_Fail_NullPreviousNodes() { // when & then assertThatThrownBy(() -> aiService.generateSituation(null).get()) .isInstanceOf(ExecutionException.class) - .hasCauseInstanceOf(AiParsingException.class) + .hasCauseInstanceOf(AiServiceException.class) .hasMessageContaining("Previous nodes cannot be null or empty"); verify(textAiClient, never()).generateText(any(AiRequest.class)); @@ -424,7 +435,7 @@ void generateSituation_Fail_EmptyPreviousNodes() { // when & then assertThatThrownBy(() -> aiService.generateSituation(List.of()).get()) .isInstanceOf(ExecutionException.class) - .hasCauseInstanceOf(AiParsingException.class) + .hasCauseInstanceOf(AiServiceException.class) .hasMessageContaining("Previous nodes cannot be null or empty"); verify(textAiClient, never()).generateText(any(AiRequest.class)); @@ -446,7 +457,7 @@ void generateSituation_Fail_InvalidNodeData() { // when & then assertThatThrownBy(() -> aiService.generateSituation(invalidNodes).get()) .isInstanceOf(ExecutionException.class) - .hasCauseInstanceOf(AiParsingException.class) + .hasCauseInstanceOf(AiServiceException.class) .hasMessageContaining("Previous nodes contain invalid data"); verify(textAiClient, never()).generateText(any(AiRequest.class)); @@ -456,6 +467,8 @@ void generateSituation_Fail_InvalidNodeData() { @DisplayName("성공 - AI 응답에서 상황 추출 (JSON 형식)") void generateSituation_Success_ExtractSituationFromJson() throws Exception { // given + setField(aiService, "objectMapper", realObjectMapper); // 실제 ObjectMapper 사용 + List previousNodes = testDecisionLine.getDecisionNodes(); String mockAiResponse = """ { @@ -484,15 +497,191 @@ void generateSituation_Success_ExtractSituationFromJson() throws Exception { @DisplayName("이미지 생성 테스트") class GenerateImageTests { + @Mock + private com.back.global.ai.client.image.ImageAiClient imageAiClient; + + @Mock + private com.back.global.storage.StorageService storageService; + + @Test + @DisplayName("실패 - Image AI 비활성화 시 placeholder 반환") + void generateImage_Fail_ImageAiDisabled() throws Exception { + // given + String prompt = "test prompt"; + + // ImageAiClient를 aiService에 수동 주입 (필드 설정) + setField(aiService, "imageAiClient", imageAiClient); + + given(imageAiClient.isEnabled()).willReturn(false); + + // when + CompletableFuture future = aiService.generateImage(prompt); + String result = future.get(); + + // then + assertThat(result).isEqualTo("placeholder-image-url"); + verify(imageAiClient, times(1)).isEnabled(); + verify(imageAiClient, never()).generateImage(anyString()); + } + + @Test + @DisplayName("실패 - 빈 프롬프트 시 placeholder 반환") + void generateImage_Fail_EmptyPrompt() throws Exception { + // given + setField(aiService, "imageAiClient", imageAiClient); + given(imageAiClient.isEnabled()).willReturn(true); + + // when - 빈 프롬프트 + CompletableFuture future1 = aiService.generateImage(""); + String result1 = future1.get(); + + // when - 공백만 있는 프롬프트 + CompletableFuture future2 = aiService.generateImage(" "); + String result2 = future2.get(); + + // when - null 프롬프트 + CompletableFuture future3 = aiService.generateImage(null); + String result3 = future3.get(); + + // then + assertThat(result1).isEqualTo("placeholder-image-url"); + assertThat(result2).isEqualTo("placeholder-image-url"); + assertThat(result3).isEqualTo("placeholder-image-url"); + + verify(imageAiClient, never()).generateImage(anyString()); + } + + @Test + @DisplayName("성공 - 전체 플로우 (AI 생성 → 스토리지 업로드)") + void generateImage_Success_FullFlow() throws Exception { + // given + String prompt = "A beautiful sunset over mountains"; + String base64Data = "mock-base64-data"; + String uploadedUrl = "https://test-bucket.s3.ap-northeast-2.amazonaws.com/scenario-test.jpeg"; + + setField(aiService, "imageAiClient", imageAiClient); + setField(aiService, "storageService", storageService); + + given(imageAiClient.isEnabled()).willReturn(true); + given(imageAiClient.generateImage(prompt)) + .willReturn(CompletableFuture.completedFuture(base64Data)); + given(storageService.getStorageType()).willReturn("s3"); + given(storageService.uploadBase64Image(base64Data)) + .willReturn(CompletableFuture.completedFuture(uploadedUrl)); + + // when + CompletableFuture future = aiService.generateImage(prompt); + String result = future.get(); + + // then + assertThat(result).isEqualTo(uploadedUrl); + verify(imageAiClient, times(1)).isEnabled(); + verify(imageAiClient, times(1)).generateImage(prompt); + verify(storageService, times(1)).uploadBase64Image(base64Data); + } + + @Test + @DisplayName("실패 - AI가 빈 Base64 반환 시 placeholder") + void generateImage_Fail_EmptyBase64FromAi() throws Exception { + // given + String prompt = "test prompt"; + + setField(aiService, "imageAiClient", imageAiClient); + setField(aiService, "storageService", storageService); + + given(imageAiClient.isEnabled()).willReturn(true); + given(imageAiClient.generateImage(prompt)) + .willReturn(CompletableFuture.completedFuture("")); // 빈 문자열 + given(storageService.getStorageType()).willReturn("local"); + + // when + CompletableFuture future = aiService.generateImage(prompt); + String result = future.get(); + + // then + assertThat(result).isEqualTo("placeholder-image-url"); + verify(imageAiClient, times(1)).generateImage(prompt); + verify(storageService, never()).uploadBase64Image(anyString()); + } + + @Test + @DisplayName("실패 - AI가 placeholder 반환 시 placeholder") + void generateImage_Fail_PlaceholderFromAi() throws Exception { + // given + String prompt = "test prompt"; + + setField(aiService, "imageAiClient", imageAiClient); + setField(aiService, "storageService", storageService); + + given(imageAiClient.isEnabled()).willReturn(true); + given(imageAiClient.generateImage(prompt)) + .willReturn(CompletableFuture.completedFuture("placeholder-image-url")); + given(storageService.getStorageType()).willReturn("local"); + + // when + CompletableFuture future = aiService.generateImage(prompt); + String result = future.get(); + + // then + assertThat(result).isEqualTo("placeholder-image-url"); + verify(storageService, never()).uploadBase64Image(anyString()); + } + + @Test + @DisplayName("실패 - AI 에러 시 placeholder 반환") + void generateImage_Fail_AiError() throws Exception { + // given + String prompt = "test prompt"; + + setField(aiService, "imageAiClient", imageAiClient); + setField(aiService, "storageService", storageService); + + given(imageAiClient.isEnabled()).willReturn(true); + given(imageAiClient.generateImage(prompt)) + .willReturn(CompletableFuture.failedFuture(new RuntimeException("AI API Error"))); + given(storageService.getStorageType()).willReturn("s3"); + + // when + CompletableFuture future = aiService.generateImage(prompt); + String result = future.get(); + + // then + assertThat(result).isEqualTo("placeholder-image-url"); + verify(imageAiClient, times(1)).generateImage(prompt); + verify(storageService, never()).uploadBase64Image(anyString()); + } + @Test - @DisplayName("기본 구현 - placeholder 이미지 URL 반환") - void generateImage_DefaultImplementation() throws Exception { + @DisplayName("실패 - 스토리지 업로드 에러 시 placeholder 반환") + void generateImage_Fail_StorageError() throws Exception { + // given + String prompt = "test prompt"; + String base64Data = "mock-base64-data"; + + setField(aiService, "imageAiClient", imageAiClient); + setField(aiService, "storageService", storageService); + + given(imageAiClient.isEnabled()).willReturn(true); + given(imageAiClient.generateImage(prompt)) + .willReturn(CompletableFuture.completedFuture(base64Data)); + given(storageService.getStorageType()).willReturn("s3"); + given(storageService.uploadBase64Image(base64Data)) + .willReturn(CompletableFuture.failedFuture(new RuntimeException("S3 Upload Error"))); + // when - CompletableFuture future = aiService.generateImage("test prompt"); + CompletableFuture future = aiService.generateImage(prompt); String result = future.get(); // then assertThat(result).isEqualTo("placeholder-image-url"); + verify(imageAiClient, times(1)).generateImage(prompt); + verify(storageService, times(1)).uploadBase64Image(base64Data); + } + + private void setField(Object target, String fieldName, Object value) throws Exception { + var field = AiServiceImpl.class.getDeclaredField(fieldName); + field.setAccessible(true); + field.set(target, value); } } diff --git a/back/src/test/java/com/back/global/storage/LocalStorageServiceTest.java b/back/src/test/java/com/back/global/storage/LocalStorageServiceTest.java new file mode 100644 index 0000000..16d3417 --- /dev/null +++ b/back/src/test/java/com/back/global/storage/LocalStorageServiceTest.java @@ -0,0 +1,278 @@ +package com.back.global.storage; + +import com.back.global.ai.config.ImageAiConfig; +import com.back.global.exception.ApiException; +import com.back.global.exception.ErrorCode; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Base64; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.BDDMockito.*; + +/** + * LocalStorageService 단위 테스트. + * 로컬 파일 시스템 기반 이미지 저장 로직을 검증합니다. + */ +@ExtendWith(MockitoExtension.class) +@DisplayName("LocalStorageService 단위 테스트") +class LocalStorageServiceTest { + + @Mock + private ImageAiConfig imageAiConfig; + + @InjectMocks + private LocalStorageService localStorageService; + + private static final String TEST_STORAGE_PATH = "./test-uploads/images"; + private static final String TEST_BASE_URL = "http://localhost:8080/test-images"; + private static final String VALID_BASE64 = Base64.getEncoder().encodeToString("test image data".getBytes()); + + @BeforeEach + void setUp() { + // 테스트용 설정 모킹 (lenient로 설정하여 불필요한 stubbing 경고 방지) + lenient().when(imageAiConfig.getLocalStoragePath()).thenReturn(TEST_STORAGE_PATH); + lenient().when(imageAiConfig.getLocalBaseUrl()).thenReturn(TEST_BASE_URL); + } + + @AfterEach + void tearDown() throws IOException { + // 테스트 파일 및 디렉토리 정리 + Path testDir = Paths.get(TEST_STORAGE_PATH); + if (Files.exists(testDir)) { + Files.walk(testDir) + .sorted((a, b) -> -a.compareTo(b)) // 파일 먼저 삭제, 디렉토리 나중에 삭제 + .forEach(path -> { + try { + Files.deleteIfExists(path); + } catch (IOException e) { + // 테스트 정리 중 에러는 무시 + } + }); + } + } + + @Nested + @DisplayName("로컬 업로드") + class UploadBase64ImageTests { + + @Test + @DisplayName("성공 - Base64 이미지 업로드") + void uploadBase64Image_성공_Base64_이미지_업로드() throws ExecutionException, InterruptedException, IOException { + // Given + String base64Data = VALID_BASE64; + + // When + CompletableFuture resultFuture = localStorageService.uploadBase64Image(base64Data); + String resultUrl = resultFuture.get(); + + // Then + assertThat(resultUrl).isNotNull(); + assertThat(resultUrl).startsWith(TEST_BASE_URL); + assertThat(resultUrl).contains("scenario-"); + assertThat(resultUrl).endsWith(".jpeg"); + + // 실제 파일이 생성되었는지 확인 + String fileName = resultUrl.substring(resultUrl.lastIndexOf('/') + 1); + Path savedFile = Paths.get(TEST_STORAGE_PATH).resolve(fileName); + assertThat(Files.exists(savedFile)).isTrue(); + assertThat(Files.size(savedFile)).isGreaterThan(0); + } + + @Test + @DisplayName("성공 - 디렉토리 자동 생성") + void uploadBase64Image_성공_디렉토리_자동생성() throws ExecutionException, InterruptedException, IOException { + // Given + String base64Data = VALID_BASE64; + Path testDir = Paths.get(TEST_STORAGE_PATH); + + // 디렉토리가 없는 상태 확인 + if (Files.exists(testDir)) { + Files.walk(testDir) + .sorted((a, b) -> -a.compareTo(b)) + .forEach(path -> { + try { + Files.deleteIfExists(path); + } catch (IOException e) { + // ignore + } + }); + } + assertThat(Files.exists(testDir)).isFalse(); + + // When + CompletableFuture resultFuture = localStorageService.uploadBase64Image(base64Data); + String resultUrl = resultFuture.get(); + + // Then + assertThat(resultUrl).isNotNull(); + assertThat(Files.exists(testDir)).isTrue(); // 디렉토리 자동 생성 확인 + } + + @Test + @DisplayName("실패 - 잘못된 Base64 데이터") + void uploadBase64Image_실패_잘못된_Base64_데이터() { + // Given + String invalidBase64 = "this-is-not-base64!!!"; + + // When + CompletableFuture resultFuture = localStorageService.uploadBase64Image(invalidBase64); + + // Then + assertThatThrownBy(resultFuture::get) + .hasCauseInstanceOf(ApiException.class) + .cause() + .hasFieldOrPropertyWithValue("errorCode", ErrorCode.STORAGE_INVALID_FILE); + } + + @Test + @DisplayName("성공 - UUID 파일명 중복 없음") + void uploadBase64Image_성공_UUID_파일명_중복없음() throws ExecutionException, InterruptedException { + // Given + String base64Data = VALID_BASE64; + + // When - 3번 업로드 + String url1 = localStorageService.uploadBase64Image(base64Data).get(); + String url2 = localStorageService.uploadBase64Image(base64Data).get(); + String url3 = localStorageService.uploadBase64Image(base64Data).get(); + + // Then - 모두 다른 URL이어야 함 (UUID 덕분) + assertThat(url1).isNotEqualTo(url2); + assertThat(url2).isNotEqualTo(url3); + assertThat(url1).isNotEqualTo(url3); + } + } + + @Nested + @DisplayName("로컬 삭제") + class DeleteImageTests { + + @Test + @DisplayName("성공 - 이미지 삭제") + void deleteImage_성공_이미지_삭제() throws ExecutionException, InterruptedException { + // Given - 먼저 이미지 업로드 + String base64Data = VALID_BASE64; + String uploadedUrl = localStorageService.uploadBase64Image(base64Data).get(); + + // 파일 존재 확인 + String fileName = uploadedUrl.substring(uploadedUrl.lastIndexOf('/') + 1); + Path filePath = Paths.get(TEST_STORAGE_PATH).resolve(fileName); + assertThat(Files.exists(filePath)).isTrue(); + + // When - 삭제 + CompletableFuture deleteFuture = localStorageService.deleteImage(uploadedUrl); + deleteFuture.get(); + + // Then - 파일이 삭제되었는지 확인 + assertThat(Files.exists(filePath)).isFalse(); + } + + @Test + @DisplayName("성공 - 파일 없음 (경고 로그만)") + void deleteImage_성공_파일없음_경고로그만() { + // Given - 존재하지 않는 URL + String nonExistentUrl = TEST_BASE_URL + "/non-existent-file.jpeg"; + + // When & Then - 예외 발생 안 함 (경고 로그만) + assertThatCode(() -> localStorageService.deleteImage(nonExistentUrl).get()) + .doesNotThrowAnyException(); + } + + @Test + @DisplayName("실패 - null URL") + void deleteImage_실패_null_URL() { + // Given + String nullUrl = null; + + // When + CompletableFuture deleteFuture = localStorageService.deleteImage(nullUrl); + + // Then + // extractFileNameFromUrl에서 STORAGE_INVALID_FILE을 던지지만 + // deleteImage의 catch (Exception e)에서 STORAGE_DELETE_FAILED로 래핑됨 + assertThatThrownBy(deleteFuture::get) + .hasCauseInstanceOf(ApiException.class) + .cause() + .hasFieldOrPropertyWithValue("errorCode", ErrorCode.STORAGE_DELETE_FAILED); + } + + @Test + @DisplayName("실패 - 빈 URL") + void deleteImage_실패_빈_URL() { + // Given + String emptyUrl = ""; + + // When + CompletableFuture deleteFuture = localStorageService.deleteImage(emptyUrl); + + // Then + // extractFileNameFromUrl에서 STORAGE_INVALID_FILE을 던지지만 + // deleteImage의 catch (Exception e)에서 STORAGE_DELETE_FAILED로 래핑됨 + assertThatThrownBy(deleteFuture::get) + .hasCauseInstanceOf(ApiException.class) + .cause() + .hasFieldOrPropertyWithValue("errorCode", ErrorCode.STORAGE_DELETE_FAILED); + } + } + + @Nested + @DisplayName("URL 파싱") + class ExtractFileNameTests { + + @Test + @DisplayName("성공 - URL에서 파일명 추출") + void extractFileName_성공_URL에서_파일명_추출() throws ExecutionException, InterruptedException { + // Given + String base64Data = VALID_BASE64; + String uploadedUrl = localStorageService.uploadBase64Image(base64Data).get(); + + // When - deleteImage 내부에서 extractFileNameFromUrl 호출됨 + String expectedFileName = uploadedUrl.substring(uploadedUrl.lastIndexOf('/') + 1); + + // Then + assertThat(expectedFileName).startsWith("scenario-"); + assertThat(expectedFileName).endsWith(".jpeg"); + assertThat(expectedFileName).contains("-"); // UUID 구분자 + } + + @Test + @DisplayName("성공 - 복잡한 URL 파싱") + void extractFileName_성공_복잡한_URL_파싱() { + // Given - 복잡한 URL + String complexUrl = "http://localhost:8080/api/v1/images/scenario-123e4567-e89b-12d3-a456-426614174000.jpeg"; + + // When & Then - deleteImage가 정상적으로 파일명 추출하는지 확인 (예외 안 남) + assertThatCode(() -> localStorageService.deleteImage(complexUrl).get()) + .doesNotThrowAnyException(); // 파일 없어도 경고만 출력 + } + } + + @Nested + @DisplayName("스토리지 타입") + class GetStorageTypeTests { + + @Test + @DisplayName("성공 - 스토리지 타입 확인") + void getStorageType_성공_스토리지_타입_확인() { + // When + String storageType = localStorageService.getStorageType(); + + // Then + assertThat(storageType).isEqualTo("local"); + } + } +} diff --git a/back/src/test/java/com/back/global/storage/S3StorageServiceTest.java b/back/src/test/java/com/back/global/storage/S3StorageServiceTest.java new file mode 100644 index 0000000..48a1723 --- /dev/null +++ b/back/src/test/java/com/back/global/storage/S3StorageServiceTest.java @@ -0,0 +1,321 @@ +package com.back.global.storage; + +import com.back.global.ai.config.ImageAiConfig; +import com.back.global.exception.ApiException; +import com.back.global.exception.ErrorCode; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.awscore.exception.AwsErrorDetails; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.S3Exception; + +import java.util.Base64; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import software.amazon.awssdk.services.s3.model.DeleteObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.*; +import static org.mockito.Mockito.*; + +/** + * S3StorageService 단위 테스트. + * AWS S3 기반 이미지 저장 로직을 검증합니다. + */ +@ExtendWith(MockitoExtension.class) +@DisplayName("S3StorageService 단위 테스트") +class S3StorageServiceTest { + + @Mock + private S3Client s3Client; + + @Mock + private ImageAiConfig imageAiConfig; + + @InjectMocks + private S3StorageService s3StorageService; + + private static final String TEST_BUCKET_NAME = "test-bucket"; + private static final String TEST_REGION = "ap-northeast-2"; + private static final String VALID_BASE64 = Base64.getEncoder().encodeToString("test image data".getBytes()); + + @BeforeEach + void setUp() { + // 테스트용 설정 모킹 (lenient로 설정) + lenient().when(imageAiConfig.getS3BucketName()).thenReturn(TEST_BUCKET_NAME); + lenient().when(imageAiConfig.getS3Region()).thenReturn(TEST_REGION); + } + + @Nested + @DisplayName("S3 업로드") + class UploadBase64ImageTests { + + @Test + @DisplayName("성공 - Base64 이미지 S3 업로드") + void uploadBase64Image_성공_Base64_이미지_S3_업로드() throws ExecutionException, InterruptedException { + // Given + String base64Data = VALID_BASE64; + + // S3Client.putObject() 모킹 (PutObjectResponse 반환) + PutObjectResponse mockResponse = PutObjectResponse.builder().build(); + given(s3Client.putObject(any(PutObjectRequest.class), any(RequestBody.class))) + .willReturn(mockResponse); + + // When + CompletableFuture resultFuture = s3StorageService.uploadBase64Image(base64Data); + String resultUrl = resultFuture.get(); + + // Then + assertThat(resultUrl).isNotNull(); + assertThat(resultUrl).startsWith("https://" + TEST_BUCKET_NAME + ".s3." + TEST_REGION + ".amazonaws.com/"); + assertThat(resultUrl).contains("scenario-"); + assertThat(resultUrl).endsWith(".jpeg"); + + // S3Client 호출 검증 + verify(s3Client, times(1)).putObject(any(PutObjectRequest.class), any(RequestBody.class)); + } + + @Test + @DisplayName("실패 - 잘못된 Base64 데이터") + void uploadBase64Image_실패_잘못된_Base64_데이터() { + // Given + String invalidBase64 = "this-is-not-base64!!!"; + + // When + CompletableFuture resultFuture = s3StorageService.uploadBase64Image(invalidBase64); + + // Then + assertThatThrownBy(resultFuture::get) + .hasCauseInstanceOf(ApiException.class) + .cause() + .hasFieldOrPropertyWithValue("errorCode", ErrorCode.STORAGE_INVALID_FILE); + + // S3Client는 호출되지 않아야 함 + verify(s3Client, never()).putObject(any(PutObjectRequest.class), any(RequestBody.class)); + } + + @Test + @DisplayName("실패 - S3 서비스 에러") + void uploadBase64Image_실패_S3_서비스_에러() { + // Given + String base64Data = VALID_BASE64; + + // S3Exception 모킹 + AwsErrorDetails errorDetails = AwsErrorDetails.builder() + .errorMessage("Access Denied") + .build(); + S3Exception s3Exception = (S3Exception) S3Exception.builder() + .awsErrorDetails(errorDetails) + .message("S3 Error") + .build(); + + doThrow(s3Exception).when(s3Client).putObject(any(PutObjectRequest.class), any(RequestBody.class)); + + // When + CompletableFuture resultFuture = s3StorageService.uploadBase64Image(base64Data); + + // Then + assertThatThrownBy(resultFuture::get) + .hasCauseInstanceOf(ApiException.class) + .cause() + .hasFieldOrPropertyWithValue("errorCode", ErrorCode.S3_CONNECTION_FAILED); + + verify(s3Client, times(1)).putObject(any(PutObjectRequest.class), any(RequestBody.class)); + } + + @Test + @DisplayName("성공 - UUID 파일명 중복 없음") + void uploadBase64Image_성공_UUID_파일명_중복없음() throws ExecutionException, InterruptedException { + // Given + String base64Data = VALID_BASE64; + PutObjectResponse mockResponse = PutObjectResponse.builder().build(); + given(s3Client.putObject(any(PutObjectRequest.class), any(RequestBody.class))) + .willReturn(mockResponse); + + // When - 3번 업로드 + String url1 = s3StorageService.uploadBase64Image(base64Data).get(); + String url2 = s3StorageService.uploadBase64Image(base64Data).get(); + String url3 = s3StorageService.uploadBase64Image(base64Data).get(); + + // Then - 모두 다른 URL이어야 함 (UUID 덕분) + assertThat(url1).isNotEqualTo(url2); + assertThat(url2).isNotEqualTo(url3); + assertThat(url1).isNotEqualTo(url3); + + // S3Client 3번 호출 확인 + verify(s3Client, times(3)).putObject(any(PutObjectRequest.class), any(RequestBody.class)); + } + + @Test + @DisplayName("성공 - S3 URL 형식 확인") + void uploadBase64Image_성공_S3_URL_형식_확인() throws ExecutionException, InterruptedException { + // Given + String base64Data = VALID_BASE64; + PutObjectResponse mockResponse = PutObjectResponse.builder().build(); + given(s3Client.putObject(any(PutObjectRequest.class), any(RequestBody.class))) + .willReturn(mockResponse); + + // When + String resultUrl = s3StorageService.uploadBase64Image(base64Data).get(); + + // Then - S3 표준 URL 형식: https://{bucket}.s3.{region}.amazonaws.com/{key} + assertThat(resultUrl).matches("https://test-bucket\\.s3\\.ap-northeast-2\\.amazonaws\\.com/scenario-[a-f0-9\\-]+\\.jpeg"); + } + } + + @Nested + @DisplayName("S3 삭제") + class DeleteImageTests { + + @Test + @DisplayName("성공 - S3 이미지 삭제") + void deleteImage_성공_S3_이미지_삭제() throws ExecutionException, InterruptedException { + // Given + String s3Url = "https://test-bucket.s3.ap-northeast-2.amazonaws.com/scenario-test-uuid.jpeg"; + DeleteObjectResponse mockResponse = DeleteObjectResponse.builder().build(); + given(s3Client.deleteObject(any(DeleteObjectRequest.class))) + .willReturn(mockResponse); + + // When + CompletableFuture deleteFuture = s3StorageService.deleteImage(s3Url); + deleteFuture.get(); + + // Then + verify(s3Client, times(1)).deleteObject(any(DeleteObjectRequest.class)); + } + + @Test + @DisplayName("실패 - S3 서비스 에러") + void deleteImage_실패_S3_서비스_에러() { + // Given + String s3Url = "https://test-bucket.s3.ap-northeast-2.amazonaws.com/scenario-test-uuid.jpeg"; + + AwsErrorDetails errorDetails = AwsErrorDetails.builder() + .errorMessage("NoSuchKey") + .build(); + S3Exception s3Exception = (S3Exception) S3Exception.builder() + .awsErrorDetails(errorDetails) + .message("S3 Error") + .build(); + + doThrow(s3Exception).when(s3Client).deleteObject(any(DeleteObjectRequest.class)); + + // When + CompletableFuture deleteFuture = s3StorageService.deleteImage(s3Url); + + // Then + assertThatThrownBy(deleteFuture::get) + .hasCauseInstanceOf(ApiException.class) + .cause() + .hasFieldOrPropertyWithValue("errorCode", ErrorCode.S3_CONNECTION_FAILED); + + verify(s3Client, times(1)).deleteObject(any(DeleteObjectRequest.class)); + } + + @Test + @DisplayName("실패 - null URL") + void deleteImage_실패_null_URL() { + // Given + String nullUrl = null; + + // When + CompletableFuture deleteFuture = s3StorageService.deleteImage(nullUrl); + + // Then + assertThatThrownBy(deleteFuture::get) + .hasCauseInstanceOf(ApiException.class) + .cause() + .hasFieldOrPropertyWithValue("errorCode", ErrorCode.STORAGE_DELETE_FAILED); + + // S3Client는 호출되지 않아야 함 + verify(s3Client, never()).deleteObject(any(DeleteObjectRequest.class)); + } + + @Test + @DisplayName("실패 - 빈 URL") + void deleteImage_실패_빈_URL() { + // Given + String emptyUrl = ""; + + // When + CompletableFuture deleteFuture = s3StorageService.deleteImage(emptyUrl); + + // Then + assertThatThrownBy(deleteFuture::get) + .hasCauseInstanceOf(ApiException.class) + .cause() + .hasFieldOrPropertyWithValue("errorCode", ErrorCode.STORAGE_DELETE_FAILED); + + // S3Client는 호출되지 않아야 함 + verify(s3Client, never()).deleteObject(any(DeleteObjectRequest.class)); + } + } + + @Nested + @DisplayName("URL 파싱") + class ExtractFileNameTests { + + @Test + @DisplayName("성공 - S3 URL에서 파일명 추출") + void extractFileName_성공_S3_URL에서_파일명_추출() throws ExecutionException, InterruptedException { + // Given + String s3Url = "https://test-bucket.s3.ap-northeast-2.amazonaws.com/scenario-abc-123.jpeg"; + DeleteObjectResponse mockResponse = DeleteObjectResponse.builder().build(); + given(s3Client.deleteObject(any(DeleteObjectRequest.class))) + .willReturn(mockResponse); + + // When - deleteImage 내부에서 extractFileNameFromUrl 호출됨 + s3StorageService.deleteImage(s3Url).get(); + + // Then - 정상적으로 파일명 추출 및 삭제 요청 성공 + verify(s3Client, times(1)).deleteObject(argThat((DeleteObjectRequest request) -> + request.key().equals("scenario-abc-123.jpeg") + )); + } + + @Test + @DisplayName("성공 - 복잡한 S3 URL 파싱") + void extractFileName_성공_복잡한_S3_URL_파싱() throws ExecutionException, InterruptedException { + // Given - 경로가 있는 복잡한 URL + String complexUrl = "https://test-bucket.s3.ap-northeast-2.amazonaws.com/images/2024/scenario-test.jpeg"; + DeleteObjectResponse mockResponse = DeleteObjectResponse.builder().build(); + given(s3Client.deleteObject(any(DeleteObjectRequest.class))) + .willReturn(mockResponse); + + // When + s3StorageService.deleteImage(complexUrl).get(); + + // Then - 마지막 부분만 추출 + verify(s3Client, times(1)).deleteObject(argThat((DeleteObjectRequest request) -> + request.key().equals("scenario-test.jpeg") + )); + } + } + + @Nested + @DisplayName("스토리지 타입") + class GetStorageTypeTests { + + @Test + @DisplayName("성공 - 스토리지 타입 확인") + void getStorageType_성공_스토리지_타입_확인() { + // When + String storageType = s3StorageService.getStorageType(); + + // Then + assertThat(storageType).isEqualTo("s3"); + } + } +}