diff --git a/back/.gitignore b/back/.gitignore index b562344..30a9cd6 100644 --- a/back/.gitignore +++ b/back/.gitignore @@ -37,7 +37,8 @@ out/ .vscode/ ### Environment Variables ### -.env +# A .gitignore pattern cannot reach above its own directory ("../" never matches); ignore the repo-root .env from the root .gitignore instead. +.env # 테스트 이미지 경로 (LocalStorageServiceTest) test-uploads/ diff --git a/back/build.gradle.kts b/back/build.gradle.kts index dbec77d..802a6ba 100644 --- a/back/build.gradle.kts +++ b/back/build.gradle.kts @@ -47,6 +47,12 @@ dependencies { testImplementation("org.springframework.boot:spring-boot-starter-test") testImplementation("org.springframework.security:spring-security-test") testRuntimeOnly("org.junit.platform:junit-platform-launcher") + testImplementation("org.testcontainers:junit-jupiter") + testImplementation("org.testcontainers:postgresql") + testImplementation("org.springframework.boot:spring-boot-testcontainers") + implementation("org.testcontainers:testcontainers") + implementation("org.testcontainers:jdbc") + implementation("org.testcontainers:postgresql") // Swagger implementation("org.springdoc:springdoc-openapi-starter-webmvc-ui:2.8.9") @@ -63,6 +69,7 @@ dependencies { // Database runtimeOnly("com.h2database:h2") runtimeOnly("org.postgresql:postgresql") + implementation("org.postgresql:postgresql") // Migration implementation("org.flywaydb:flyway-core:11.11.2") @@ -78,6 +85,8 @@ dependencies { // AI Services - WebFlux for non-blocking HTTP clients implementation("org.springframework.boot:spring-boot-starter-webflux") implementation("com.fasterxml.jackson.module:jackson-module-kotlin") + implementation("io.netty:netty-tcnative-boringssl-static:2.0.65.Final") + // AWS SDK for S3 implementation("software.amazon.awssdk:s3:2.20.+") diff --git a/back/src/main/java/com/back/domain/scenario/service/ScenarioService.java b/back/src/main/java/com/back/domain/scenario/service/ScenarioService.java index fe02b3b..b75f429 100644 --- a/back/src/main/java/com/back/domain/scenario/service/ScenarioService.java +++ 
b/back/src/main/java/com/back/domain/scenario/service/ScenarioService.java @@ -30,6 +30,7 @@ import org.springframework.transaction.annotation.Transactional; import java.time.LocalDateTime; +import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Optional; @@ -163,6 +164,39 @@ protected Long createScenarioInTransaction( // lastDecision 처리 (필요 시) if (lastDecision != null) { decisionFlowService.createDecisionNodeNext(lastDecision); + List ordered = decisionNodeRepository.findByDecisionLine_IdOrderByAgeYearAscIdAsc(decisionLine.getId()); + DecisionNode parent = ordered.isEmpty() ? null : ordered.get(ordered.size() - 1); + + // 베이스 라인의 tail BaseNode 해석(“결말” 우선, 없으면 최대 age) + BaseLine baseLine = decisionLine.getBaseLine(); + List baseNodes = baseLine.getBaseNodes(); + BaseNode tailBase = baseNodes.stream() + .filter(b -> { + String s = b.getSituation() == null ? "" : b.getSituation(); + String d = b.getDecision() == null ? "" : b.getDecision(); + return s.contains("결말") || d.contains("결말"); + }) + .max(Comparator.comparingInt(BaseNode::getAgeYear).thenComparingLong(BaseNode::getId)) + .orElseGet(() -> baseNodes.stream() + .max(Comparator.comparingInt(BaseNode::getAgeYear).thenComparingLong(BaseNode::getId)) + .orElseThrow(() -> new ApiException(ErrorCode.INVALID_INPUT_VALUE, "tail base not found")) + ); + + // 엔티티 빌더로 ‘결말’ 결정노드 저장(테일과 동일 age) + DecisionNode ending = DecisionNode.builder() + .user(decisionLine.getUser()) + .nodeKind(NodeType.DECISION) + .decisionLine(decisionLine) + .baseNode(tailBase) + .parent(parent) + .category(tailBase.getCategory()) + .situation("결말") + .decision("결말") + .ageYear(tailBase.getAgeYear()) + .background(tailBase.getSituation() == null ? 
"" : tailBase.getSituation()) + .build(); + + decisionNodeRepository.save(ending); } // DecisionLine 완료 처리 diff --git a/back/src/main/java/com/back/domain/search/entity/NodeSnippet.java b/back/src/main/java/com/back/domain/search/entity/NodeSnippet.java new file mode 100644 index 0000000..e12959f --- /dev/null +++ b/back/src/main/java/com/back/domain/search/entity/NodeSnippet.java @@ -0,0 +1,39 @@ +/* + * 이 파일은 RAG 검색용 스니펫 엔티티를 정의한다. + * 라인/나이/카테고리/텍스트/임베딩을 저장하며 pgvector 컬럼을 float[]로 매핑한다. + */ +package com.back.domain.search.entity; + +import com.back.infra.pgvector.PgVectorConverter; +import jakarta.persistence.*; +import lombok.*; +import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.type.SqlTypes; + +@Entity +@Table(name = "node_snippet") +@Getter @Setter +@NoArgsConstructor @AllArgsConstructor @Builder +public class NodeSnippet { + + @Id @GeneratedValue(strategy = GenerationType.IDENTITY) + private Long id; + + @Column(name = "line_id", nullable = false) + private Long lineId; + + @Column(name = "age_year", nullable = false) + private Integer ageYear; + + private String category; + + @Column(name = "text", nullable = false, columnDefinition = "text") + private String text; + + @JdbcTypeCode(SqlTypes.OTHER) + @Convert(converter = PgVectorConverter.class) + @Column(name = "embedding", nullable = false, columnDefinition = "vector(768)") + private float[] embedding; + + +} diff --git a/back/src/main/java/com/back/domain/search/repository/NodeSnippetRepository.java b/back/src/main/java/com/back/domain/search/repository/NodeSnippetRepository.java new file mode 100644 index 0000000..c883c6e --- /dev/null +++ b/back/src/main/java/com/back/domain/search/repository/NodeSnippetRepository.java @@ -0,0 +1,45 @@ +/* + * 이 파일은 라인/나이 윈도우로 후보를 좁힌 뒤 pgvector 유사도로 정렬해 topK를 반환하는 네이티브 쿼리를 제공한다. 
+ */ +package com.back.domain.search.repository; + +import com.back.domain.search.entity.NodeSnippet; +import org.springframework.data.jpa.repository.*; +import org.springframework.data.repository.query.Param; + +import java.util.List; + +public interface NodeSnippetRepository extends JpaRepository { + + // 라인/나이 윈도우 필터 + pgvector 유사도(<=>) 정렬로 상위 K를 조회한다. + @Query(value = """ + SELECT * FROM node_snippet + WHERE line_id = :lineId + AND age_year BETWEEN :minAge AND :maxAge + ORDER BY embedding <=> CAST(:q AS vector) + LIMIT :k + """, nativeQuery = true) + List searchTopKByLineAndAgeWindow( + @Param("lineId") Long lineId, + @Param("minAge") Integer minAge, + @Param("maxAge") Integer maxAge, + @Param("q") String vectorLiteral, + @Param("k") int k + ); + + // 텍스트만(가볍게) 가져오기 — 네트워크·파싱 비용 급감 + @Query(value = """ + SELECT text FROM node_snippet + WHERE line_id = :lineId + AND age_year BETWEEN :minAge AND :maxAge + ORDER BY embedding <=> CAST(:q AS vector) + LIMIT :k + """, nativeQuery = true) + List searchTopKTextByLineAndAgeWindow( + @Param("lineId") Long lineId, + @Param("minAge") Integer minAge, + @Param("maxAge") Integer maxAge, + @Param("q") String vectorLiteral, + @Param("k") int k + ); +} diff --git a/back/src/main/java/com/back/global/ai/client/text/GeminiJsonTextClient.java b/back/src/main/java/com/back/global/ai/client/text/GeminiJsonTextClient.java new file mode 100644 index 0000000..e550435 --- /dev/null +++ b/back/src/main/java/com/back/global/ai/client/text/GeminiJsonTextClient.java @@ -0,0 +1,337 @@ +/* + * 2.0 전용 JSON 강제 & 안전화(최소 변경) + VectorResponse 매핑 + */ +package com.back.global.ai.client.text; + +import com.back.global.ai.config.TextAiConfig; +import com.back.global.ai.dto.AiRequest; +import com.back.global.ai.dto.gemini.GeminiResponse; +import com.back.global.ai.dto.gemini.VectorResponse; +import com.back.global.ai.exception.AiApiException; +import com.back.global.ai.exception.AiParsingException; +import 
com.back.global.ai.exception.AiTimeoutException; +import com.back.global.exception.ErrorCode; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.context.annotation.Primary; +import org.springframework.http.HttpStatusCode; +import org.springframework.http.MediaType; +import org.springframework.stereotype.Component; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.client.WebClientResponseException; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +import java.io.IOException; +import java.time.Duration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeoutException; + +@Component("gemini20JsonClient") +@Primary +@Slf4j +public class GeminiJsonTextClient implements TextAiClient { + + private final WebClient webClient; + private final TextAiConfig textAiConfig; + private final ObjectMapper objectMapper; // ← 추가 + + public GeminiJsonTextClient(@Qualifier("geminiWebClient") WebClient webClient, + TextAiConfig textAiConfig, + ObjectMapper objectMapper) { // ← 추가 + this.webClient = webClient; + this.textAiConfig = textAiConfig; + this.objectMapper = objectMapper; // ← 추가 + } + + @Override + public CompletableFuture generateText(String prompt) { + return generateText(new AiRequest(prompt, Map.of())); + } + + @Override + public CompletableFuture generateText(AiRequest aiRequest) { + if (aiRequest == null || aiRequest.prompt() == null) { + return CompletableFuture.failedFuture(new AiParsingException("Prompt is null")); + } + + int ctxLimit = Math.max(4096, safeInt(textAiConfig.getMaxContextTokens(), 8192)); + int inTokens = 
estimateTokens(aiRequest.prompt()); + int userMaxOut = Math.max(1, aiRequest.maxTokens()); + int safety = 128; + int minOut = 256; + int roomForOut = Math.max(minOut, ctxLimit - inTokens - safety); + int allowedOut = Math.max(minOut, Math.min(roomForOut, userMaxOut)); + + String fittedPrompt = fitPromptToContext(aiRequest.prompt(), ctxLimit, allowedOut, safety); + + Map body = createGeminiRequest(aiRequest.parameters(), fittedPrompt, allowedOut); + Mono call = invoke(body); + + log.debug("[Gemini-2.0] ctxLimit={}, in≈{}, allowedOut={}, promptChars={}", + ctxLimit, inTokens, allowedOut, aiRequest.prompt().length()); + + return call + .map(resp -> tryExtract(resp, false)) + .onErrorResume(e -> { + if (e instanceof AiParsingException apx && + apx.getMessage() != null && + apx.getMessage().contains("MAX_TOKENS")) { + + log.warn("[Gemini-2.0] fallback retry due to MAX_TOKENS: {}", apx.getMessage()); + + int fallbackOut = Math.max(256, Math.min(userMaxOut * 2, 512)); + String shortPrompt = + """ + 너는 JSON 출력기야. 아래 정보를 오직 JSON 객체 1개로만 반환해. + 필드명: situation, recommendedOption (각 값은 한국어 1문장, 20자 이내) + 마크다운/코드펜스/설명/추가텍스트 금지. JSON만! 
+ --- + """ + safeHead(fittedPrompt, 800); + + Map fallbackBody = createGeminiRequest(aiRequest.parameters(), shortPrompt, fallbackOut); + return invoke(fallbackBody).map(resp -> tryExtract(resp, true)); + } + return Mono.error(e); + }) + .timeout(Duration.ofSeconds(textAiConfig.getTimeoutSeconds())) + .onErrorMap(TimeoutException.class, + t -> new AiTimeoutException("Gemini request timeout after " + + textAiConfig.getTimeoutSeconds() + "s")) + .retryWhen( + Retry.backoff(textAiConfig.getMaxRetries(), + Duration.ofSeconds(textAiConfig.getRetryDelaySeconds())) + .filter(this::isTransient) + ) + .doOnError(e -> log.error("Gemini API call failed: {}", safeTruncate(e.toString(), 2000))) + .toFuture(); + } + + /** Convenience method: returns the result directly as a VectorResponse DTO. */ + public CompletableFuture generateVector(AiRequest aiRequest) { + return generateText(aiRequest).thenApply(this::toVectorResponse); + } + + // --- internal invocation / parsing --- + + private boolean isTransient(Throwable ex) { + if (ex instanceof IOException || ex instanceof TimeoutException || ex instanceof AiTimeoutException) return true; // generateText maps TimeoutException to AiTimeoutException before retryWhen, so the mapped type must be matched for timeouts to be retried + if (ex instanceof WebClientResponseException w && w.getStatusCode().is5xxServerError()) return true; // NOTE(review): invoke() routes every HTTP error through handleErrorResponse (AiApiException), so this branch looks unreachable - confirm + return false; + } + + private Mono invoke(Map body) { + return webClient.post() + .uri("/v1beta/models/{model}:generateContent", textAiConfig.getModel20()) + .header("x-goog-api-key", textAiConfig.getApiKey()) + .contentType(MediaType.APPLICATION_JSON) + .bodyValue(body) + .retrieve() + .onStatus(HttpStatusCode::isError, this::handleErrorResponse) + .bodyToMono(GeminiResponse.class); + } + + // ===== Force JSON output (minimal change) ===== + private Map createGeminiRequest(Map params, String prompt, int outTokens) { + Map gen = new HashMap<>(); + // 2.0: favor fast, consistent responses + gen.put("temperature", 0.2); + gen.put("topK", 1); + gen.put("topP", 0.9); + gen.put("candidateCount", 1); + gen.put("maxOutputTokens", outTokens); // honor the budget computed by the caller (was hard-coded to 120, leaving outTokens unused); params may still override it below + + // Force JSON (set both camelCase and snake_case keys) + gen.put("response_mime_type", "application/json"); + gen.put("responseMimeType", "application/json"); + + Map 
responseSchema = Map.of( + "type", "OBJECT", + "properties", Map.of( + "situation", Map.of("type", "STRING"), + "recommendedOption", Map.of("type", "STRING") + ), + "required", List.of("situation", "recommendedOption") + ); + gen.put("response_schema", responseSchema); + gen.put("responseSchema", responseSchema); + + // 외부 파라미터(화이트리스트) — camel/snake 모두 허용 + if (params != null) { + copyIfPresent(params, gen, "temperature"); + copyIfPresent(params, gen, "topK"); + copyIfPresent(params, gen, "topP"); + copyIfPresent(params, gen, "maxOutputTokens"); + copyIfPresent(params, gen, "candidateCount"); + + // MIME / 스키마 키 양쪽 지원 + copyIfPresent(params, gen, "responseMimeType"); + copyIfPresent(params, gen, "response_mime_type"); + copyIfPresent(params, gen, "responseSchema"); + copyIfPresent(params, gen, "response_schema"); + // stopSequences는 조기 컷 리스크로 미사용 + } + + Map userContent = Map.of( + "role", "user", + "parts", List.of(Map.of("text", prompt)) + ); + + // JSON-only 시스템 규율 + Map systemInstruction = Map.of( + "parts", List.of(Map.of("text", + "You are a JSON emitter. Output ONLY one compact JSON object that matches the schema. " + + "No markdown, no code fences, no explanations. 
Keys: situation, recommendedOption.")) + ); + + Map body = new HashMap<>(); + body.put("systemInstruction", systemInstruction); + body.put("contents", List.of(userContent)); + body.put("generationConfig", gen); + body.put("toolConfig", Map.of("functionCallingConfig", Map.of("mode", "NONE"))); + return body; + } + // ========================== + + private void copyIfPresent(Map src, Map dst, String key) { + Object v = src.get(key); + if (v != null) dst.put(key, v); + } + + private String extractContent(GeminiResponse response) { + if (response == null || response.candidates() == null || response.candidates().isEmpty()) { + log.warn("[Gemini] empty candidates: body=null/empty"); + throw new AiParsingException("No candidates in Gemini response"); + } + var c = response.candidates().get(0); + var finish = c.finishReason(); + if ("SAFETY".equalsIgnoreCase(finish)) { + log.warn("[Gemini] content blocked by safety filters. finishReason=SAFETY"); + throw new AiParsingException("Content blocked by safety filters"); + } + if (c.content() == null || c.content().parts() == null || c.content().parts().isEmpty()) { + String reason = finish != null ? " (" + finish + ")" : ""; + log.warn("[Gemini] no parts in candidate content. finishReason={}", finish); + throw new AiParsingException("No parts in candidate content" + reason); + } + String text = c.content().parts().get(0).text(); + if (text == null || text.isBlank()) { + String reason = finish != null ? " (" + finish + ")" : ""; + log.warn("[Gemini] blank text part. 
finishReason={}", finish); + throw new AiParsingException("Blank text part" + reason); + } + return sanitizeJsonText(text); + } + + private String tryExtract(GeminiResponse response, boolean fallback) { + try { + return extractContent(response); + } catch (AiParsingException e) { + String finish = null; + if (response != null && response.candidates() != null && !response.candidates().isEmpty()) { + finish = response.candidates().get(0).finishReason(); + } + if ("MAX_TOKENS".equalsIgnoreCase(finish)) { + String msg = "No parts in candidate content (MAX_TOKENS)"; + // Same exception on the first attempt and on the fallback retry; generateText looks for "MAX_TOKENS" in the message. + throw new AiParsingException(msg); + } + throw e; + } + } + + private Mono handleErrorResponse(ClientResponse response) { + return response.bodyToMono(String.class).defaultIfEmpty("") // an empty error body must still produce an exception: an empty Mono returned to onStatus suppresses the error and treats the response as success + .map(errorBody -> { + log.warn("[Gemini] HTTP error: status={}, body={}", + response.statusCode(), safeTruncate(errorBody, 2000)); + return new AiApiException( + ErrorCode.AI_SERVICE_UNAVAILABLE, + "Gemini API call failed: " + response.statusCode() + ); + }); + } + + // --- utils --- + + private int estimateTokens(String s) { + if (s == null || s.isEmpty()) return 0; + int len = s.length(); + int han = 0; + for (int i = 0; i < Math.min(len, 4000); i++) { + char ch = s.charAt(i); + if (Character.UnicodeBlock.of(ch) == Character.UnicodeBlock.HANGUL_SYLLABLES + || Character.UnicodeBlock.of(ch) == Character.UnicodeBlock.CJK_UNIFIED_IDEOGRAPHS) { + han++; + } + } + double hanRatio = (len == 0) ? 0 : (double) han / Math.min(len, 4000); + int div = (hanRatio > 0.3) ? 
2 : 4; + return Math.max(1, len / div); + } + + private String fitPromptToContext(String prompt, int ctxLimit, int outTokens, int safety) { + int in = estimateTokens(prompt); + int limitIn = Math.max(256, ctxLimit - outTokens - safety); + if (in <= limitIn) return prompt; + + int targetChars = Math.max(200, (int) (prompt.length() * (limitIn / (double) in))); + int head = Math.min(prompt.length(), (int) (targetChars * 0.7)); + int tail = Math.min(prompt.length() - head, (int) (targetChars * 0.3)); + return prompt.substring(0, head) + "\n...\n" + prompt.substring(prompt.length() - tail); + } + + private static String safeHead(String s, int maxChars) { + if (s == null) return ""; + if (s.length() <= maxChars) return s; + return s.substring(0, maxChars) + "\n..."; + } + + private static int safeInt(Integer v, int defVal) { + return (v == null || v <= 0) ? defVal : v; + } + + private static String safeTruncate(String s, int max) { + return (s == null || s.length() <= max) ? s : s.substring(0, max) + "..."; + } + + // ```json ... 
``` fence stripping plus removal of leading/trailing chatter: returns only the first complete JSON object + private static String sanitizeJsonText(String raw) { + String s = raw.trim(); + if (s.startsWith("```")) { + int first = s.indexOf('\n'); + int last = s.lastIndexOf("```"); + if (first >= 0 && last > first) { + s = s.substring(first + 1, last).trim(); + } + } + int depth = 0, start = -1; boolean inString = false, escaped = false; // string state is tracked so braces inside JSON string values are not counted + for (int i = 0; i < s.length(); i++) { + char c = s.charAt(i); if (inString) { if (escaped) escaped = false; else if (c == '\\') escaped = true; else if (c == '"') inString = false; continue; } + if (c == '"' && depth > 0) inString = true; else if (c == '{') { if (depth == 0) start = i; depth++; } + else if (c == '}' && depth > 0) { depth--; if (depth == 0) { return s.substring(start, i + 1); } } // a stray '}' before the first '{' is ignored instead of driving depth negative + } + return s; + } + + // JSON string → VectorResponse + private VectorResponse toVectorResponse(String json) { + try { + String clean = sanitizeJsonText(json); + VectorResponse dto = objectMapper.readValue(clean, VectorResponse.class); + if (dto.situation() == null || dto.situation().isBlank() + || dto.recommendedOption() == null || dto.recommendedOption().isBlank()) { + throw new AiParsingException("Missing required fields in VectorResponse"); + } + return dto; + } catch (JsonProcessingException e) { + throw new AiParsingException("Failed to parse VectorResponse: " + e.getOriginalMessage()); + } + } +} diff --git a/back/src/main/java/com/back/global/ai/client/text/GeminiTextClient.java b/back/src/main/java/com/back/global/ai/client/text/GeminiTextClient.java index f8be86d..1e46115 100644 --- a/back/src/main/java/com/back/global/ai/client/text/GeminiTextClient.java +++ b/back/src/main/java/com/back/global/ai/client/text/GeminiTextClient.java @@ -25,7 +25,7 @@ * Gemini AI 텍스트 생성 클라이언트 * Google Gemini API를 통한 비동기 텍스트 생성, 재시도, 에러 처리를 담당합니다. 
*/ -@Component +@Component("gemini25TextClient") @Slf4j public class GeminiTextClient implements TextAiClient { @@ -47,21 +47,21 @@ public CompletableFuture generateText(String prompt) { public CompletableFuture generateText(AiRequest aiRequest) { return webClient .post() - .uri("/v1beta/models/{model}:generateContent?key={apiKey}", - textAiConfig.getModel(), textAiConfig.getApiKey()) - .contentType(MediaType.APPLICATION_JSON) - .bodyValue(createGeminiRequest(aiRequest.prompt(), aiRequest.maxTokens())) - .retrieve() - .onStatus(HttpStatusCode::isError, this::handleErrorResponse) - .bodyToMono(GeminiResponse.class) - .doOnNext(response -> log.debug("Gemini API response received: candidates={}, finishReason={}", - response.candidates().size(), - response.candidates().isEmpty() ? "N/A" : response.candidates().get(0).finishReason())) - .map(this::extractContent) - .timeout(Duration.ofSeconds(textAiConfig.getTimeoutSeconds())) - .retryWhen(Retry.backoff(textAiConfig.getMaxRetries(), Duration.ofSeconds(textAiConfig.getRetryDelaySeconds()))) - .doOnError(error -> log.error("Gemini API call failed: {}", error.getMessage(), error)) - .toFuture(); + .uri("/v1beta/models/{model}:generateContent", textAiConfig.getModel()) + .contentType(MediaType.APPLICATION_JSON) + .bodyValue(createGeminiRequest(aiRequest.prompt(), aiRequest.maxTokens())) + .retrieve() + .onStatus(HttpStatusCode::isError, this::handleErrorResponse) + .bodyToMono(GeminiResponse.class) + .doOnNext(response -> log.debug("Gemini API response received: candidates={}, finishReason={}", + response.candidates().size(), + response.candidates().isEmpty() ? 
"N/A" : response.candidates().get(0).finishReason())) + .map(this::extractContent) + .timeout(Duration.ofSeconds(textAiConfig.getTimeoutSeconds())) + .retryWhen(Retry.backoff(textAiConfig.getMaxRetries(), + Duration.ofSeconds(textAiConfig.getRetryDelaySeconds()))) + .doOnError(error -> log.error("Gemini API call failed: {}", error.getMessage(), error)) + .toFuture(); } private Map createGeminiRequest(String prompt, int maxTokens) { diff --git a/back/src/main/java/com/back/global/ai/config/PgVectorContainerConfig.java b/back/src/main/java/com/back/global/ai/config/PgVectorContainerConfig.java new file mode 100644 index 0000000..c4b1987 --- /dev/null +++ b/back/src/main/java/com/back/global/ai/config/PgVectorContainerConfig.java @@ -0,0 +1,31 @@ +//package com.back.global.ai.config; +// +//import com.zaxxer.hikari.HikariDataSource; +//import org.springframework.context.annotation.*; +//import org.testcontainers.containers.PostgreSQLContainer; +// +//import javax.sql.DataSource; +// +//@Configuration +//@Profile("test") +//public class PgVectorContainerConfig { +// +// @Bean(initMethod = "start", destroyMethod = "stop") +// public PostgreSQLContainer pgContainer() { +// // 여기서 이미지 고정 +// PostgreSQLContainer c = new PostgreSQLContainer<>("pgvector/pgvector:pg16"); +// // 필요하면 파라미터 추가 +// // c.withReuse(true); +// return c; +// } +// +// @Bean +// public DataSource dataSource(PostgreSQLContainer pg) { +// HikariDataSource ds = new HikariDataSource(); +// ds.setJdbcUrl(pg.getJdbcUrl()); +// ds.setUsername(pg.getUsername()); +// ds.setPassword(pg.getPassword()); +// // 드라이버는 Hikari가 JDBC URL로 자동 판단 (org.postgresql.Driver) +// return ds; +// } +//} \ No newline at end of file diff --git a/back/src/main/java/com/back/global/ai/config/TextAiConfig.java b/back/src/main/java/com/back/global/ai/config/TextAiConfig.java index cfdaae8..bd9a8b3 100644 --- a/back/src/main/java/com/back/global/ai/config/TextAiConfig.java +++ 
b/back/src/main/java/com/back/global/ai/config/TextAiConfig.java @@ -1,40 +1,133 @@ package com.back.global.ai.config; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.netty.channel.ChannelOption; +import io.netty.handler.timeout.ReadTimeoutHandler; +import io.netty.handler.timeout.WriteTimeoutHandler; +import io.netty.resolver.DefaultAddressResolverGroup; +import jakarta.annotation.PostConstruct; import lombok.Data; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.client.reactive.ReactorClientHttpConnector; +import org.springframework.http.codec.json.Jackson2JsonDecoder; +import org.springframework.http.codec.json.Jackson2JsonEncoder; +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.ExchangeFilterFunction; import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Mono; +import reactor.netty.http.HttpProtocol; +import reactor.netty.http.client.HttpClient; +import reactor.netty.resources.ConnectionProvider; + +import java.time.Duration; + -/** - * 텍스트 생성 AI 서비스 설정 클래스 - */ @Configuration @ConfigurationProperties(prefix = "ai.text.gemini") @EnableConfigurationProperties({ - SituationAiProperties.class, - BaseScenarioAiProperties.class, - DecisionScenarioAiProperties.class + SituationAiProperties.class, + BaseScenarioAiProperties.class, + DecisionScenarioAiProperties.class }) @Data public class TextAiConfig { String apiKey; String baseUrl = "https://generativelanguage.googleapis.com"; - String model = "gemini-2.5-flash"; // 추후 변경 가능 + String model = "gemini-2.5-flash"; + String model20 = 
"gemini-2.0-flash"; int timeoutSeconds = 30; int maxRetries = 3; - int retryDelaySeconds = 2; // 재시도 간격 (초) + int retryDelaySeconds = 2; + int maxConnections = 200; + + // ▼ 추가: 성능용 밀리초 단위 설정(필요시 yml로 노출) + private int inferenceTimeoutMillis = 1100; // 전체 호출 상한(p50 목표) + private int connectTimeoutMillis = 150; // TCP connect + private int readTimeoutMillis = 900; // 응답 수신 + private int writeTimeoutMillis = 300; // 요청 전송 + private int poolMaxConnections = 200; + private int pendingAcquireMaxCount = 1000; + private int pendingAcquireTimeoutSeconds = 2; + private int poolMaxIdleTimeSeconds = 300; + private int poolMaxLifeTimeSeconds = 900; + + private Integer maxContextTokens = 8192; + + @PostConstruct + void validateKey() { + if (apiKey == null || apiKey.isBlank()) { + throw new IllegalStateException("ai.text.gemini.api-key is missing"); + } + } - /** - * Gemini API 전용 WebClient Bean 생성 - */ @Bean("geminiWebClient") - public WebClient geminiWebClient() { + public WebClient geminiWebClient(ObjectMapper objectMapper) { + + // 커넥션 풀: idle/lifetime/evict 설정 추가 + ConnectionProvider pool = ConnectionProvider.builder("gemini-pool") + .maxConnections(maxConnections) // 200 + .pendingAcquireMaxCount(pendingAcquireMaxCount) // 1000 + .pendingAcquireTimeout(Duration.ofSeconds(pendingAcquireTimeoutSeconds)) // 2s + .maxIdleTime(Duration.ofMinutes(2)) // 유휴 연결 유지 + .maxLifeTime(Duration.ofMinutes(10)) // 오래된 연결 교체 + .evictInBackground(Duration.ofSeconds(30)) // 백그라운드 정리 + .lifo() // 핫 커넥션 우선 재사용 + .build(); + + HttpClient http = HttpClient.create(pool) + .compress(true) + .keepAlive(true) + .wiretap(false) + .secure() // TLS + ALPN + .protocol(HttpProtocol.H2) // ★ H2만 사용(멀티플렉싱 최대화) + .resolver(DefaultAddressResolverGroup.INSTANCE) // 네이티브 DNS 해소 + .responseTimeout(Duration.ofSeconds(timeoutSeconds)) + .doOnConnected(conn -> conn + .addHandlerLast(new ReadTimeoutHandler(timeoutSeconds)) + .addHandlerLast(new WriteTimeoutHandler(timeoutSeconds))) + 
.option(ChannelOption.TCP_NODELAY, true) + .option(ChannelOption.SO_KEEPALIVE, true) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeoutMillis); + + // ★ 쿼리스트링 key 제거(헤더만 사용) → URI 재작성/캐싱저해 방지 + // ExchangeFilterFunction apiKeyQueryFilter = ... <<-- 삭제 + + // ★ 요청/응답 타이밍 로깅(간단 버전) + ExchangeFilterFunction timing = ExchangeFilterFunction.ofRequestProcessor(req -> { + long start = System.nanoTime(); + return Mono.just( + ClientRequest.from(req) + .headers(h -> h.add("x-req-start-nanos", String.valueOf(start))) + .build() + ); + }).andThen(ExchangeFilterFunction.ofResponseProcessor(res -> { + return Mono.deferContextual(ctx -> { + // 필요시 로그프레임워크로 교체 + return Mono.just(res); + }); + })); + return WebClient.builder() .baseUrl(baseUrl) - .defaultHeader("Content-Type", "application/json") - .codecs(configurer -> configurer.defaultCodecs().maxInMemorySize(1024 * 1024)) + .defaultHeader("x-goog-api-key", apiKey) // 헤더만 사용 + .defaultHeader(HttpHeaders.ACCEPT_ENCODING, "gzip") // 압축 명시 + .clientConnector(new ReactorClientHttpConnector(http)) + .defaultHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .filter(timing) + .codecs(c -> { + c.defaultCodecs().maxInMemorySize(8 * 1024 * 1024); + c.defaultCodecs().jackson2JsonDecoder(new Jackson2JsonDecoder(objectMapper)); + c.defaultCodecs().jackson2JsonEncoder(new Jackson2JsonEncoder(objectMapper, MediaType.APPLICATION_JSON)); + }) .build(); } + + + public Integer getMaxContextTokens() { return maxContextTokens == null ? 
8192 : maxContextTokens; } + public void setMaxContextTokens(Integer maxContextTokens) { this.maxContextTokens = maxContextTokens; } } diff --git a/back/src/main/java/com/back/global/ai/dto/gemini/VectorResponse.java b/back/src/main/java/com/back/global/ai/dto/gemini/VectorResponse.java new file mode 100644 index 0000000..22d2def --- /dev/null +++ b/back/src/main/java/com/back/global/ai/dto/gemini/VectorResponse.java @@ -0,0 +1,6 @@ +package com.back.global.ai.dto.gemini; + +public record VectorResponse( + String situation, + String recommendedOption +) {} \ No newline at end of file diff --git a/back/src/main/java/com/back/global/ai/service/AiServiceImpl.java b/back/src/main/java/com/back/global/ai/service/AiServiceImpl.java index ca03210..0bdd817 100644 --- a/back/src/main/java/com/back/global/ai/service/AiServiceImpl.java +++ b/back/src/main/java/com/back/global/ai/service/AiServiceImpl.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.stereotype.Service; import java.util.List; @@ -38,7 +39,7 @@ @Slf4j public class AiServiceImpl implements AiService { - private final TextAiClient textAiClient; + private final @Qualifier("gemini25TextClient") TextAiClient textAiClient; /* NOTE(review): this field is injected through Lombok's @RequiredArgsConstructor; Lombok copies @Qualifier onto the constructor parameter only when lombok.config lists it in lombok.copyableAnnotations. Otherwise the @Primary gemini20JsonClient bean would be injected here - confirm. */ private final ObjectMapper objectMapper; private final SceneTypeRepository sceneTypeRepository; private final SituationAiProperties situationAiProperties; diff --git a/back/src/main/java/com/back/global/ai/vector/AIVectorServiceImpl.java b/back/src/main/java/com/back/global/ai/vector/AIVectorServiceImpl.java index 1beab61..f55133f 100644 --- a/back/src/main/java/com/back/global/ai/vector/AIVectorServiceImpl.java +++ b/back/src/main/java/com/back/global/ai/vector/AIVectorServiceImpl.java @@ -1,8 +1,7 @@ /* - * [파일 요약] - * - 인메모리 검색 결과(얇은 콘텍스트) + 이전 경로 요약으로 초경량 프롬프트를 만들고 - * TextAiClient 를 동기 호출하여 JSON 
2필드(situation, recommendedOption)만 추출한다. - * - 토큰/콘텍스트 길이는 프로퍼티로 제어 가능. + * 이 파일은 pgvector 기반 얇은 콘텍스트와 이전 경로 요약으로 초경량 프롬프트를 생성하여 + * 제미나이를 동기 호출하고 JSON 2필드(situation, recommendedOption)만 추출한다. + * 라인 전환 섞임 방지를 위해 항상 lineId/age 윈도우 필터를 사용한다. */ package com.back.global.ai.vector; @@ -29,65 +28,81 @@ public class AIVectorServiceImpl implements AIVectorService { private final SituationAiProperties props; private final ObjectMapper objectMapper; - // 프로퍼티 바인딩 필드 - private int topK = 5; - private int contextCharLimit = 1000; - private int maxOutputTokens = 384; + // 프로퍼티 바인딩(기본값은 예시이며 yml로 조정) + private int topK = 1; + private int contextCharLimit = 200; + private int maxOutputTokens = 48; public void setTopK(int topK) { this.topK = topK; } public void setContextCharLimit(int contextCharLimit) { this.contextCharLimit = contextCharLimit; } public void setMaxOutputTokens(int maxOutputTokens) { this.maxOutputTokens = maxOutputTokens; } - /** 한 줄 요약: 경로 요약 + 얇은 콘텍스트로 프롬프트 최소화 후 AI 힌트 생성 */ + // 경로 요약 + 라인/나이 윈도우 RAG로 프롬프트를 최소화한 뒤 AI 힌트를 생성한다. 
@Override public AiNextHint generateNextHint(Long userId, Long decisionLineId, List orderedNodes) { if (orderedNodes == null || orderedNodes.isEmpty()) { return new AiNextHint(null, null); } - // 1) 질의/콘텍스트 준비 + int currAge = orderedNodes.get(orderedNodes.size() - 1).getAgeYear(); + + // 질의(경로 요약) String query = support.buildQueryFromNodes(orderedNodes); - List ctxSnippets = support.searchRelatedContexts(query, topK, Math.max(120, contextCharLimit / Math.max(1, topK))); + + // 관련 스니펫 상위 K + List ctxSnippets = support.searchRelatedContexts( + decisionLineId, currAge, query, topK, Math.max(120, contextCharLimit / Math.max(1, topK)) + ); String relatedContext = support.joinWithLimit(ctxSnippets, contextCharLimit); - // 2) 초경량 RAG 프롬프트 생성 + // 초경량 RAG 프롬프트 String prompt = buildRagPrompt(query, relatedContext); - // 3) 동기 호출 (CompletableFuture.join 사용) — 응답 즉시 필요 - AiRequest req = new AiRequest(prompt, Map.of(), Math.max(128, maxOutputTokens)); + // 제미나이 동기 호출(JSON 반환 유도 옵션 포함 권장) + AiRequest req = new AiRequest( + prompt, + Map.of( + "temperature", 0.2, + "topP", 0.9, + "topK", 1, + "candidateCount", 1, + "response_mime_type", "application/json" + ), + maxOutputTokens + ); String response = textAiClient.generateText(req).join(); - // 4) JSON 2필드만 추출 + // JSON 2필드 추출 String situation = SituationPrompt.extractSituation(response, objectMapper); String option = SituationPrompt.extractRecommendedOption(response, objectMapper); return new AiNextHint(emptyToNull(situation), emptyToNull(option)); } + // 프롬프트 문자열을 생성한다. private String buildRagPrompt(String previousSummary, String relatedContext) { String ctx = (relatedContext == null || relatedContext.isBlank()) ? "(관련 콘텍스트 없음)" : relatedContext; return """ - 당신은 인생 시뮬레이션 도우미입니다. - 아래의 '이전 선택 요약'과 '관련 콘텍스트'를 참고하여, - **동일 연도 시점**에서 자연스러운 새로운 상황을 **한 문장**으로 생성하세요. 
- - ## 이전 선택 요약 - %s - - ## 관련 콘텍스트(발췌) - %s - - ### 제약 - - 반드시 현재 베이스 상황과 동일한 연/시점 - - 과장/모호 금지, 구체적이고 현실적인 한 문장 - - 선택 분기가 필요 - - ### 응답(JSON) - { - "situation": "한 문장", - "recommendedOption": "15자 이내 선택지" - } - """.formatted(previousSummary, ctx); + 아래 규칙을 철저히 따르세요. + + [규칙] + - 출력은 딱 한 줄의 JSON만(개행·주석·설명 절대 금지) + - 모든 텍스트는 한국어만 사용(영문자 A~Z, a~z 금지) + - 허용 문자: 한글, 숫자, 공백, 기본 문장부호(.,!?\"'()-:;·…) + - 스키마: {"situation":"문장","recommendedOption":"15자 이내 선택지"} + - 값에 줄바꿈/탭/백틱/백슬래시 금지 + + [이전 선택 요약] + %s + + [관련 콘텍스트(발췌)] + %s + + [요구] + - 동일 연/시점의 자연스러운 새로운 상황을 한국어 한 문장으로 "situation"에 + - 한국어 15자 이내의 구체 선택지를 "recommendedOption"에 + - 예: {"situation":"장학금 발표가 내일로 다가왔다.","recommendedOption":"면접 대비 정리"} + """.formatted(previousSummary, ctx); } private String emptyToNull(String s) { diff --git a/back/src/main/java/com/back/global/ai/vector/AIVectorServiceSupportDomain.java b/back/src/main/java/com/back/global/ai/vector/AIVectorServiceSupportDomain.java index 50f09f7..36115bb 100644 --- a/back/src/main/java/com/back/global/ai/vector/AIVectorServiceSupportDomain.java +++ b/back/src/main/java/com/back/global/ai/vector/AIVectorServiceSupportDomain.java @@ -1,18 +1,25 @@ -/** - * [요약] 경로 요약/인메모리 검색 보조 유틸 +/* + * 이 파일은 경로 요약 문자열 생성, pgvector 기반 관련 스니펫 검색, 스니펫 병합 유틸리티를 제공한다. + * 항상 lineId와 현재 나이를 받아 라인 전환 시 섞임을 방지한다. */ package com.back.global.ai.vector; import com.back.domain.node.entity.DecisionNode; +import com.back.domain.search.entity.NodeSnippet; +import lombok.RequiredArgsConstructor; import org.springframework.stereotype.Component; import java.util.*; import java.util.stream.Collectors; @Component +@RequiredArgsConstructor public class AIVectorServiceSupportDomain { - // 이전 결정 경로를 요약 쿼리 문자열로 만든다 + private final PgVectorSearchService vectorSearch; + private final EmbeddingClient embeddingClient; + + // 이전 결정 경로를 간단한 요약 문자열로 만든다. 
public String buildQueryFromNodes(List nodes) { return nodes.stream() .map(n -> String.format("- (%d세) %s → %s", @@ -22,19 +29,27 @@ public String buildQueryFromNodes(List nodes) { .collect(Collectors.joining("\n")); } - // 간단한 인메모리 유사 검색으로 상위 K 콘텍스트를 수집한다 (스텁) - public List searchRelatedContexts(String query, int topK, int eachSnippetLimit) { - return Collections.emptyList(); // 추후 RAM 인덱스 교체 지점 + // 라인/나이 윈도우로 제한하여 관련 스니펫을 상위 K개 가져온다. + public List searchRelatedContexts(Long lineId, int currAge, String query, int topK, int eachSnippetLimit) { + float[] qEmb = embeddingClient.embed(query); + List top = vectorSearch.topK(lineId, currAge, 2, qEmb, Math.max(topK, 1)); + List out = new ArrayList<>(); + for (NodeSnippet s : top) { + String t = s.getText(); + if (t == null || t.isBlank()) continue; + out.add(trim(t, eachSnippetLimit)); + } + return out; } - // 여러 스니펫을 합치되 총 글자수 제한을 적용한다 + // 여러 스니펫을 결합하되 총 길이를 제한한다. public String joinWithLimit(List snippets, int totalCharLimit) { StringBuilder sb = new StringBuilder(); for (String s : snippets) { if (s == null || s.isBlank()) continue; if (sb.length() + s.length() + 1 > totalCharLimit) break; - if (!sb.isEmpty()) sb.append("\n"); - sb.append(trim(s, Math.min(s.length(), Math.max(50, totalCharLimit / 5)))); + if (sb.length() > 0) sb.append("\n"); + sb.append(s); } return sb.toString(); } diff --git a/back/src/main/java/com/back/global/ai/vector/EmbeddingClient.java b/back/src/main/java/com/back/global/ai/vector/EmbeddingClient.java new file mode 100644 index 0000000..3a3871a --- /dev/null +++ b/back/src/main/java/com/back/global/ai/vector/EmbeddingClient.java @@ -0,0 +1,11 @@ +/* + * 이 파일은 텍스트를 임베딩 벡터(float[])로 변환하는 클라이언트 인터페이스를 정의한다. + * 구현체는 OpenAI/Vertex/사내 API 등으로 자유롭게 교체 가능하다. + */ +package com.back.global.ai.vector; + +public interface EmbeddingClient { + + // 입력 텍스트를 임베딩 벡터로 변환한다. 
+ float[] embed(String text); +} diff --git a/back/src/main/java/com/back/global/ai/vector/EmbeddingProperties.java b/back/src/main/java/com/back/global/ai/vector/EmbeddingProperties.java new file mode 100644 index 0000000..04c3279 --- /dev/null +++ b/back/src/main/java/com/back/global/ai/vector/EmbeddingProperties.java @@ -0,0 +1,17 @@ +/* + * 이 파일은 임베딩 관련 설정(dim)을 프로퍼티에서 주입받기 위한 구성 클래스를 제공한다. + * 기본값은 768이며, ai.embedding.dim 으로 변경할 수 있다. + */ +package com.back.global.ai.vector; + +import lombok.Getter; +import lombok.Setter; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.context.annotation.Configuration; + +@Configuration +@ConfigurationProperties(prefix = "ai.embedding") +@Getter @Setter +public class EmbeddingProperties { + private int dim = 768; +} diff --git a/back/src/main/java/com/back/global/ai/vector/LocalHashEmbeddingClient.java b/back/src/main/java/com/back/global/ai/vector/LocalHashEmbeddingClient.java new file mode 100644 index 0000000..69689f8 --- /dev/null +++ b/back/src/main/java/com/back/global/ai/vector/LocalHashEmbeddingClient.java @@ -0,0 +1,96 @@ +/* + * 이 파일은 외부 API 없이 텍스트를 고정 차원(float[])으로 변환하는 경량 임베딩 클라이언트를 제공한다. + * 토큰을 해시해서 차원에 매핑한 뒤 L2 정규화한다. 품질은 간이지만 개발/테스트/임시 운영에 충분하다. + */ +package com.back.global.ai.vector; + +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Component; + +import java.nio.charset.StandardCharsets; + +@Component +@RequiredArgsConstructor +public class LocalHashEmbeddingClient implements EmbeddingClient { + + private final EmbeddingProperties props; + + // 텍스트를 고정 차원 해시 임베딩으로 변환한다. 
+ @Override + public float[] embed(String text) { + int dim = Math.max(32, props.getDim()); + float[] v = new float[dim]; + if (text == null || text.isBlank()) return v; + + String[] toks = text.toLowerCase() + .replaceAll("[^\\p{L}\\p{Nd}\\s]", " ") + .trim() + .split("\\s+"); + + for (String t : toks) { + if (t.isBlank()) continue; + int h = murmur32(t.getBytes(StandardCharsets.UTF_8)); + int idx = Math.floorMod(h, dim); + v[idx] += 1.0f; + } + l2NormalizeInPlace(v); + return v; + } + + // L2 정규화 수행 + private void l2NormalizeInPlace(float[] v) { + double s = 0.0; + for (float x : v) s += x * x; + if (s == 0) return; + float inv = (float) (1.0 / Math.sqrt(s)); + for (int i = 0; i < v.length; i++) v[i] *= inv; + } + + // 간단한 MurmurHash3(32-bit) 구현 + private int murmur32(byte[] data) { + int c1 = 0xcc9e2d51; + int c2 = 0x1b873593; + int r1 = 15; + int r2 = 13; + int m = 5; + int n = 0xe6546b64; + + int hash = 0; + int len4 = data.length / 4; + + for (int i = 0; i < len4; i++) { + int i4 = i * 4; + int k = (data[i4] & 0xff) | ((data[i4 + 1] & 0xff) << 8) + | ((data[i4 + 2] & 0xff) << 16) | (data[i4 + 3] << 24); + k *= c1; + k = Integer.rotateLeft(k, r1); + k *= c2; + + hash ^= k; + hash = Integer.rotateLeft(hash, r2) * m + n; + } + + int idx = len4 * 4; + int k1 = 0; + switch (data.length & 3) { + case 3 -> k1 = (data[idx + 2] & 0xff) << 16; + case 2 -> k1 |= (data[idx + 1] & 0xff) << 8; + case 1 -> { + k1 |= (data[idx] & 0xff); + k1 *= c1; + k1 = Integer.rotateLeft(k1, r1); + k1 *= c2; + hash ^= k1; + } + } + + hash ^= data.length; + hash ^= (hash >>> 16); + hash *= 0x85ebca6b; + hash ^= (hash >>> 13); + hash *= 0xc2b2ae35; + hash ^= (hash >>> 16); + + return hash; + } +} diff --git a/back/src/main/java/com/back/global/ai/vector/PgVectorSearchService.java b/back/src/main/java/com/back/global/ai/vector/PgVectorSearchService.java new file mode 100644 index 0000000..f6f77ac --- /dev/null +++ b/back/src/main/java/com/back/global/ai/vector/PgVectorSearchService.java @@ 
-0,0 +1,47 @@ +/* + * 이 파일은 pgvector 리포지토리를 감싸 라인/나이 윈도우로 유사 스니펫을 조회하는 서비스를 제공한다. + * 쿼리 임베딩은 float[]로 받아 SQL 캐스트 가능한 문자열 "[...]"로 변환한다. + */ +package com.back.global.ai.vector; + +import com.back.domain.search.entity.NodeSnippet; +import com.back.domain.search.repository.NodeSnippetRepository; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Service; + +import java.util.List; + +@Service +@RequiredArgsConstructor +public class PgVectorSearchService { + + private final NodeSnippetRepository repo; + + // 라인/나이 윈도우 필터 + pgvector 유사도 검색을 수행한다. + public List topK(Long lineId, int currAge, int deltaAge, float[] queryEmbedding, int k) { + String q = toVectorLiteral(queryEmbedding); + int minAge = currAge - deltaAge; + int maxAge = currAge + deltaAge; + return repo.searchTopKByLineAndAgeWindow(lineId, minAge, maxAge, q, k); + } + + public List topKText(Long lineId, int currAge, int deltaAge, float[] queryEmbedding, int k) { + String q = toVectorLiteral(queryEmbedding); + int minAge = currAge - deltaAge; + int maxAge = currAge + deltaAge; + return repo.searchTopKTextByLineAndAgeWindow(lineId, minAge, maxAge, q, k); + } + + // float[] 임베딩을 "[a,b,c]" 형식으로 변환한다. 
+ private String toVectorLiteral(float[] v) { + if (v == null || v.length == 0) return "[]"; + StringBuilder sb = new StringBuilder(v.length * 8 + 2); + sb.append('['); + for (int i = 0; i < v.length; i++) { + if (i > 0) sb.append(','); + sb.append(Float.toString(v[i])); + } + sb.append(']'); + return sb.toString(); + } +} diff --git a/back/src/main/java/com/back/global/initdata/InitData.java b/back/src/main/java/com/back/global/initdata/InitData.java index 5e1278e..fafb9ce 100644 --- a/back/src/main/java/com/back/global/initdata/InitData.java +++ b/back/src/main/java/com/back/global/initdata/InitData.java @@ -3,8 +3,6 @@ import com.back.domain.comment.entity.Comment; import com.back.domain.comment.repository.CommentRepository; import com.back.domain.node.dto.PivotListDto; -import lombok.extern.slf4j.Slf4j; -import org.springframework.context.annotation.Profile; import com.back.domain.node.dto.base.BaseLineBulkCreateRequest; import com.back.domain.node.dto.base.BaseLineBulkCreateResponse; import com.back.domain.node.dto.decision.DecNodeDto; @@ -19,22 +17,20 @@ import com.back.domain.post.entity.Post; import com.back.domain.post.enums.PostCategory; import com.back.domain.post.repository.PostRepository; -import com.back.domain.scenario.entity.SceneCompare; -import com.back.domain.scenario.entity.SceneCompareResultType; -import com.back.domain.scenario.entity.SceneType; -import com.back.domain.scenario.entity.Scenario; -import com.back.domain.scenario.entity.ScenarioStatus; -import com.back.domain.scenario.entity.Type; +import com.back.domain.scenario.entity.*; +import com.back.domain.scenario.repository.ScenarioRepository; import com.back.domain.scenario.repository.SceneCompareRepository; import com.back.domain.scenario.repository.SceneTypeRepository; -import com.back.domain.scenario.repository.ScenarioRepository; import com.back.domain.user.entity.Gender; import com.back.domain.user.entity.Mbti; import com.back.domain.user.entity.Role; import 
com.back.domain.user.entity.User; import com.back.domain.user.repository.UserRepository; import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; import org.springframework.boot.CommandLineRunner; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.context.annotation.Profile; import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.stereotype.Component; @@ -54,6 +50,7 @@ @Component @Profile("!prod") // prod 프로파일에서는 실행 안 함 @RequiredArgsConstructor +@ConditionalOnProperty(name = "app.initdata.enabled", havingValue = "true", matchIfMissing = true) public class InitData implements CommandLineRunner { private final UserRepository userRepository; diff --git a/back/src/main/java/com/back/infra/pgvector/PgVectorConverter.java b/back/src/main/java/com/back/infra/pgvector/PgVectorConverter.java new file mode 100644 index 0000000..8215784 --- /dev/null +++ b/back/src/main/java/com/back/infra/pgvector/PgVectorConverter.java @@ -0,0 +1,69 @@ +/* + * [파일 요약/코드 흐름] + * - JPA <-> PostgreSQL(pgvector) 매핑 컨버터 + * - DB 저장 시 float[]를 PGobject(type="vector")로 바인딩하여 드라이버가 네이티브 vector 타입으로 전달 + * - DB 조회 시 PGobject 또는 문자열("[a,b,c]")을 안전하게 파싱해 float[]로 복원 + * - 문자열로 바인딩할 때 발생하던 "column is of type vector but expression is of type character varying" 오류를 제거 + */ +package com.back.infra.pgvector; + +import jakarta.persistence.AttributeConverter; +import jakarta.persistence.Converter; +import org.postgresql.util.PGobject; + +@Converter(autoApply = false) +public class PgVectorConverter implements AttributeConverter { + + // 가장 중요한 함수: float[] -> PGobject(vector)로 직렬화해 네이티브 타입 바인딩 + @Override + public Object convertToDatabaseColumn(float[] attribute) { + if (attribute == null || attribute.length == 0) return null; // NOT NULL 컬럼이면 상위에서 보장 + try { + PGobject obj = new PGobject(); + obj.setType("vector"); // pgvector 타입 지정 + obj.setValue(toLiteral(attribute)); // "[a,b,c]" 형식 값 설정 + 
return obj; + } catch (Exception e) { + throw new IllegalArgumentException("Failed to convert float[] to PGobject(vector)", e); + } + } + + // 가장 많이 사용하는 함수: PGobject/문자열 -> float[]로 역직렬화 + @Override + public float[] convertToEntityAttribute(Object dbData) { + if (dbData == null) return new float[0]; + String s; + if (dbData instanceof PGobject pgo) { + s = pgo.getValue(); + } else { + s = dbData.toString(); + } + if (s == null) return new float[0]; + s = s.trim(); + if (s.isEmpty() || "[]".equals(s)) return new float[0]; + + if (s.startsWith("[") && s.endsWith("]")) { + s = s.substring(1, s.length() - 1); + } + if (s.isBlank()) return new float[0]; + + String[] parts = s.split("\\s*,\\s*"); + float[] out = new float[parts.length]; + for (int i = 0; i < parts.length; i++) { + out[i] = Float.parseFloat(parts[i]); + } + return out; + } + + // 가장 중요한 함수: float[]를 pgvector 리터럴("[...]") 문자열로 변환 + private String toLiteral(float[] v) { + StringBuilder sb = new StringBuilder(v.length * 8 + 2); + sb.append('['); + for (int i = 0; i < v.length; i++) { + if (i > 0) sb.append(','); + sb.append(Float.toString(v[i])); + } + sb.append(']'); + return sb.toString(); + } +} diff --git a/back/src/main/resources/application-test-pg.yml b/back/src/main/resources/application-test-pg.yml new file mode 100644 index 0000000..160241e --- /dev/null +++ b/back/src/main/resources/application-test-pg.yml @@ -0,0 +1,27 @@ +spring: + data: + redis: + host: localhost + port: 6379 + + datasource: + + jpa: + properties: + hibernate: + dialect: org.hibernate.dialect.PostgreSQLDialect + hibernate: + ddl-auto: create-drop + + flyway: + enabled: true + locations: classpath:db/mig + + h2: + console: + enabled: false # H2는 안쓰므로 끔 + +management: + health: + redis: + enabled: false diff --git a/back/src/main/resources/application.yml b/back/src/main/resources/application.yml index c60f398..045dab9 100644 --- a/back/src/main/resources/application.yml +++ b/back/src/main/resources/application.yml @@ -89,10 
+89,12 @@ ai: gemini: api-key: ${GEMINI_API_KEY} model: gemini-2.5-flash + model20: gemini-2.0-flash base-url: https://generativelanguage.googleapis.com timeout-seconds: 70 # AI 응답 대기 시간 (시나리오 생성: 30-40초 + 여유 30초) max-retries: 0 # 재시도 비활성화 (타임아웃 방지) retry-delay-seconds: 2 # 재시도 간격 (초) + max-context-tokens: 8192 image: enabled: true provider: stable-diffusion @@ -116,6 +118,8 @@ ai: decision-scenario: maxOutputTokens: 16384 # 8192 -> 16384 (gemini-2.5-flash 최대 65536, 충분한 여유) timeout-seconds: 60 # 결정 시나리오 생성 타임아웃 (실제: 30-40초 + 여유) + embedding: + dim: 768 server: servlet: diff --git a/back/src/main/resources/db/mig/V1__init_pgvector.sql b/back/src/main/resources/db/mig/V1__init_pgvector.sql new file mode 100644 index 0000000..f92338e --- /dev/null +++ b/back/src/main/resources/db/mig/V1__init_pgvector.sql @@ -0,0 +1,18 @@ +-- pgvector 확장 +CREATE EXTENSION IF NOT EXISTS vector; + +-- 임베딩 차원은 실제 모델에 맞추세요 (예: 768/1024/1536...) +CREATE TABLE IF NOT EXISTS node_snippet ( + id BIGSERIAL PRIMARY KEY, + line_id BIGINT NOT NULL, + age_year INT NOT NULL, + text TEXT, + embedding VECTOR(768) + ); + +CREATE INDEX IF NOT EXISTS ix_node_snippet_line_age + ON node_snippet(line_id, age_year); + +-- 코사인 거리 KNN 인덱스 +CREATE INDEX IF NOT EXISTS ix_node_snippet_embedding + ON node_snippet USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100); diff --git a/back/src/main/resources/db/migration/V7__add_vector_db.sql b/back/src/main/resources/db/migration/V7__add_vector_db.sql new file mode 100644 index 0000000..e83337f --- /dev/null +++ b/back/src/main/resources/db/migration/V7__add_vector_db.sql @@ -0,0 +1,27 @@ +-- ============================================== +-- pgvector 기반 node_snippet 테이블 생성 +-- ============================================== + +-- 1. pgvector extension 활성화 +-- 이미 활성화 되어 있을 수 있으므로 IF NOT EXISTS 사용 +CREATE EXTENSION IF NOT EXISTS vector; + +-- 2. 
node_snippet 테이블 생성 +CREATE TABLE IF NOT EXISTS node_snippet ( + id BIGINT GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY, + line_id BIGINT NOT NULL, + age_year INT NOT NULL, + category VARCHAR(50), + text TEXT NOT NULL, + embedding VECTOR(768) NOT NULL +); + +-- 3. 검색 성능을 위한 인덱스 생성 +CREATE INDEX IF NOT EXISTS idx_node_snippet_line_age + ON node_snippet(line_id, age_year); + +-- 코사인 거리 기반 KNN 인덱스 생성 +-- hnsw 인덱스 성능이 좋지만 db.t3.micro 상 메모리 문제로 인해 ivfflat 사용 +CREATE INDEX IF NOT EXISTS idx_node_snippet_embedding_ivfflat + ON node_snippet USING ivfflat (embedding vector_cosine_ops) + WITH (lists = 50) \ No newline at end of file diff --git a/back/src/test/java/com/back/global/ai/vector/AiLatencyProbeConfig.java b/back/src/test/java/com/back/global/ai/vector/AiLatencyProbeConfig.java new file mode 100644 index 0000000..7cc979e --- /dev/null +++ b/back/src/test/java/com/back/global/ai/vector/AiLatencyProbeConfig.java @@ -0,0 +1,45 @@ +/* + * [코드 흐름 요약] + * - Gemini WebClient에 타이밍 필터를 주입해 실제 HTTP 왕복 시간을 ms로 기록한다. + * - 기록값은 LAST_LATENCY_MS(AtomicLong)에 저장되어 테스트에서 읽어 확인한다. + * - 기존 geminiWebClient 빈을 @Primary로 대체하므로 실제 호출 경로는 그대로 유지된다. 
+ */ +package com.back.global.ai.vector; + +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.test.context.TestConfiguration; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Primary; +import org.springframework.web.reactive.function.client.ExchangeFilterFunction; +import org.springframework.web.reactive.function.client.WebClient; + +import java.util.concurrent.atomic.AtomicLong; + +@TestConfiguration +public class AiLatencyProbeConfig { + + // 가장 중요한 함수 위에 한줄로만 + public static final AtomicLong LAST_LATENCY_MS = new AtomicLong(-1); + + // 가장 많이 사용되는 함수 호출 위에 한줄로만 + @Bean + @Primary + @Qualifier("geminiWebClient") + WebClient timedGeminiWebClient(WebClient.Builder builder, + @Value("${ai.text.gemini.base-url}") String baseUrl) { + + ExchangeFilterFunction timingFilter = (request, next) -> { + long t0 = System.nanoTime(); + LAST_LATENCY_MS.set(-1); // 시작 시 초기화 + return next.exchange(request) + .doOnSuccess(resp -> LAST_LATENCY_MS.set((System.nanoTime() - t0) / 1_000_000)) + .doOnError(e -> LAST_LATENCY_MS.set((System.nanoTime() - t0) / 1_000_000)); + }; + + return builder + .baseUrl(baseUrl) + .filter(timingFilter) + .build(); + } +} diff --git a/back/src/test/java/com/back/global/ai/vector/DecisionAssistVectorFlowIT.java b/back/src/test/java/com/back/global/ai/vector/DecisionAssistVectorFlowIT.java new file mode 100644 index 0000000..4222eaa --- /dev/null +++ b/back/src/test/java/com/back/global/ai/vector/DecisionAssistVectorFlowIT.java @@ -0,0 +1,222 @@ +///* +// * [코드 흐름 요약] +// * 1) Docker 가능 시 pgvector 컨테이너(pg16) 선기동 → DynamicPropertySource로 Postgres 주입 +// * 2) NodeSnippet을 저장해 pgvector(vector 타입) 라운드트립/유사도 정렬/윈도우 필터를 검증 +// * 3) AIVectorService.generateNextHint를 호출해 지원도메인→pgvector검색→프롬프트 구성→AI 응답 반영 흐름을 캡처/검증 +// */ +//package com.back.global.ai.vector; +// +//import 
com.back.domain.node.controller.AiCallBudget; +//import com.back.domain.node.controller.AiOnceDelegateTestConfig; +//import com.back.domain.node.entity.DecisionNode; +//import com.back.domain.search.entity.NodeSnippet; +//import com.back.domain.search.repository.NodeSnippetRepository; +//import com.back.global.ai.client.text.TextAiClient; +//import com.back.global.ai.dto.AiRequest; +//import org.junit.jupiter.api.Assumptions; +//import org.junit.jupiter.api.BeforeEach; +//import org.junit.jupiter.api.DisplayName; +//import org.junit.jupiter.api.Test; +//import org.mockito.ArgumentCaptor; +//import org.springframework.beans.factory.annotation.Autowired; +//import org.springframework.boot.test.context.SpringBootTest; +//import org.springframework.boot.test.mock.mockito.MockBean; +//import org.springframework.context.annotation.Import; +//import org.springframework.jdbc.core.JdbcTemplate; +//import org.springframework.test.context.ActiveProfiles; +//import org.springframework.test.context.DynamicPropertyRegistry; +//import org.springframework.test.context.DynamicPropertySource; +//import org.springframework.test.context.TestPropertySource; +//import org.testcontainers.DockerClientFactory; +//import org.testcontainers.containers.PostgreSQLContainer; +//import org.testcontainers.junit.jupiter.Testcontainers; +//import org.testcontainers.utility.DockerImageName; +// +//import java.time.LocalDateTime; +//import java.util.List; +//import java.util.concurrent.CompletableFuture; +// +//import static org.assertj.core.api.Assertions.assertThat; +//import static org.assertj.core.api.Assertions.within; +//import static org.mockito.ArgumentMatchers.any; +//import static org.mockito.Mockito.*; +// +//@Testcontainers(disabledWithoutDocker = true) +//@SpringBootTest +//@ActiveProfiles("test") +//@Import(AiOnceDelegateTestConfig.class) +//@TestPropertySource(properties = { +// "spring.jpa.open-in-view=false", +// "spring.flyway.enabled=false", +// // 테스트 임베딩 차원(엔티티 DDL vector(768)과 
맞춘다) +// "ai.embedding.dim=768" +//}) +//class DecisionAssistVectorFlowIT { +// +// private static final boolean DOCKER; +// static { +// boolean ok; +// try { ok = DockerClientFactory.instance().isDockerAvailable(); } +// catch (Throwable t) { ok = false; } +// DOCKER = ok; +// } +// +// // 가장 중요한 함수 위에 한줄로만 +// static final PostgreSQLContainer POSTGRES = +// new PostgreSQLContainer<>( +// DockerImageName.parse("pgvector/pgvector:pg16").asCompatibleSubstituteFor("postgres") +// ) +// .withDatabaseName("relife_test") +// .withUsername("test") +// .withPassword("test") +// .withInitScript("sql/init_vector.sql"); // CREATE EXTENSION IF NOT EXISTS vector; +// +// // 컨테이너 선기동(프로퍼티 평가 전에) +// static { if (DOCKER) POSTGRES.start(); } +// +// // 가장 중요한 함수 위에 한줄로만 +// @DynamicPropertySource +// static void props(DynamicPropertyRegistry r) { +// Assumptions.assumeTrue(DOCKER, "Docker not available — skipping pgvector IT"); +// r.add("spring.datasource.url", POSTGRES::getJdbcUrl); +// r.add("spring.datasource.username", POSTGRES::getUsername); +// r.add("spring.datasource.password", POSTGRES::getPassword); +// r.add("spring.datasource.driver-class-name", () -> "org.postgresql.Driver"); +// r.add("spring.jpa.hibernate.ddl-auto", () -> "update"); +// r.add("spring.jpa.properties.hibernate.dialect", () -> "org.hibernate.dialect.PostgreSQLDialect"); +// } +// +// @Autowired EmbeddingClient embeddingClient; +// @Autowired PgVectorSearchService vectorSearch; +// @Autowired AIVectorService aivectorService; // test 프로파일에선 AiOnceDelegateTestConfig로 1회 실구현/이후 스텁 가능 +// @Autowired AIVectorServiceSupportDomain support; +// @Autowired NodeSnippetRepository snippetRepo; +// @Autowired JdbcTemplate jdbc; +// @Autowired +// AiCallBudget budget; // 1회 실호출 예산 +// +// @MockBean TextAiClient textAiClient; +// +// @BeforeEach +// void clearTable() { +// Assumptions.assumeTrue(DOCKER); +// // 간단 초기화(테스트 격리) +// jdbc.execute("DELETE FROM node_snippet"); +// } +// +// @Test +// 
@DisplayName("pgvector 라운드트립: float[] 저장→vector 타입 확인→재조회") +// void vector_roundtrip_and_type() { +// Assumptions.assumeTrue(DOCKER); +// +// float[] v = embeddingClient.embed("라운드트립 검증 문장"); +// // 가장 많이 사용하는 함수 호출 위에 한줄로만 +// saveSnippet(100L, 20, "roundtrip", "TEST", v); +// +// String typ = jdbc.queryForObject("select pg_typeof(embedding)::text from node_snippet limit 1", String.class); +// assertThat(typ).isEqualTo("vector"); +// +// float[] back = snippetRepo.findAll().get(0).getEmbedding(); +// assertThat(back.length).isEqualTo(v.length); +// double l2 = 0; for (float x : back) l2 += x*x; +// assertThat(Math.sqrt(l2)).isCloseTo(1.0, within(1e-3)); // L2 정규화 유지 +// } +// +// @Test +// @DisplayName("리포지토리 네이티브 쿼리: line/age 윈도우 + <=> 유사도 정렬 + LIMIT") +// void repository_age_window_similarity_limit() { +// Assumptions.assumeTrue(DOCKER); +// +// long lineId = 123L; int anchor = 30; +// saveSnippet(lineId, 28, "주거 이전 비용과 출퇴근 시간을 계산한다.", "LIFE"); +// saveSnippet(lineId, 30, "이직 제안을 받고 면접을 준비한다.", "CAREER"); +// saveSnippet(lineId, 32, "대학원 진학을 고민한다.", "EDU"); +// saveSnippet(999L, 30, "다른 라인 데이터", "NOPE"); +// +// float[] q = embeddingClient.embed("- (30세) 이직 및 면접 준비 사항 점검"); +// String literal = toVectorLiteral(q); +// +// List top = snippetRepo.searchTopKByLineAndAgeWindow(lineId, anchor - 2, anchor + 2, literal, 2); +// +// assertThat(top).hasSize(2); +// assertThat(top.stream().allMatch(s -> s.getLineId().equals(lineId))).isTrue(); +// assertThat(top.stream().allMatch(s -> Math.abs(s.getAgeYear() - anchor) <= 2)).isTrue(); +// } +// +// // 가장 중요한 함수 위에 한줄로만 +// @Test +// @DisplayName("AIVectorService: 지원도메인→pgvector 검색 포함 프롬프트 구성 및 AI 응답 반영") +// void service_flow_prompt_and_aihint() { +// Assumptions.assumeTrue(DOCKER); +// +// long lineId = 777L; int age = 22; +// saveSnippet(lineId, age, "수도권 컴공 진학 비용과 통학 고민", "EDUCATION"); +// saveSnippet(lineId, age, "서울 스타트업 인턴 이력서 준비", "CAREER"); +// +// // 🔴 호출 전에 '실호출 1회' 예산 설정 (이게 없으면 스텁 경로로 빠져서 캡처가 안 잡힘) +// 
budget.reset(1); +// +// // 응답 더미(실제 값은 verify 이후에만 캡처) +// when(textAiClient.generateText(any(AiRequest.class))) +// .thenReturn(CompletableFuture.completedFuture( +// "{\"situation\":\"한 문장 상황\",\"recommendedOption\":\"짧은선택\"}" +// )); +// +// // orderedNodes 최소 구성 +// DecisionNode n = DecisionNode.builder() +// .ageYear(age).situation("대학 진학을 앞두고 컴공 고려").decision("컴공 선택") +// .build(); +// +// // 가장 많이 사용하는 함수 호출 위에 한줄로만 +// var hint = aivectorService.generateNextHint(1L, lineId, List.of(n)); +// +// // ✅ verify로 캡처 (stubbing에서 capture하지 말고, 호출 '후' 검증에서 capture) +// ArgumentCaptor reqCap = ArgumentCaptor.forClass(AiRequest.class); +// verify(textAiClient, timeout(5000)).generateText(reqCap.capture()); +// AiRequest sent = reqCap.getValue(); +// +// // 프롬프트/옵션 검증 +// assertThat(sent.parameters()).containsEntry("response_mime_type", "application/json"); +// assertThat(sent.maxTokens()).isGreaterThan(40); +// assertThat(sent.prompt()) +// .contains("이전 선택 요약") +// .contains("관련 콘텍스트") +// .contains("컴공") +// .contains("인턴"); +// +// // 서비스 반환값 검증 +// assertThat(hint.aiNextSituation()).isEqualTo("한 문장 상황"); +// assertThat(hint.aiNextRecommendedOption()).isEqualTo("짧은선택"); +// } +// +// +// // --- 헬퍼 --- +// +// // 가장 중요한 함수 위에 한줄로만 +// private void saveSnippet(Long lineId, int age, String text, String category, float[] emb) { +// NodeSnippet s = NodeSnippet.builder() +// .lineId(lineId) +// .ageYear(age) +// .category(category) +// .text(text) +// .embedding(emb != null ? 
emb : embeddingClient.embed(text)) +// .updatedAt(LocalDateTime.now()) +// .build(); +// snippetRepo.save(s); +// } +// +// private void saveSnippet(Long lineId, int age, String text, String category) { +// saveSnippet(lineId, age, text, category, null); +// } +// +// private String toVectorLiteral(float[] v) { +// StringBuilder sb = new StringBuilder(v.length * 8 + 2); +// sb.append('['); +// for (int i = 0; i < v.length; i++) { +// if (i > 0) sb.append(','); +// sb.append(Float.toString(v[i])); +// } +// sb.append(']'); +// return sb.toString(); +// } +//} diff --git a/back/src/test/java/com/back/global/ai/vector/DecisionAssistVectorLatencyIT.java b/back/src/test/java/com/back/global/ai/vector/DecisionAssistVectorLatencyIT.java new file mode 100644 index 0000000..6b6292e --- /dev/null +++ b/back/src/test/java/com/back/global/ai/vector/DecisionAssistVectorLatencyIT.java @@ -0,0 +1,199 @@ +///* +// * [코드 흐름 요약] +// * 1) 테스트 컨텍스트에서 AIVectorService를 실구현으로 강제 오버라이드(@Primary) +// * 2) AiLatencyProbeConfig로 Gemini HTTP 왕복 시간(ms) 계측 +// * 3) budget.reset(1)로 1회 실호출 강제, 호출 후 realAi=true + e2e/http 시간 출력 +// */ +//package com.back.global.ai.vector; +// +//import com.back.domain.node.controller.AiCallBudget; +//import com.back.domain.node.controller.AiOnceDelegateTestConfig; +//import com.back.domain.node.entity.DecisionNode; +//import com.back.domain.search.entity.NodeSnippet; +//import com.back.domain.search.repository.NodeSnippetRepository; +//import com.back.global.ai.client.text.TextAiClient; +//import com.back.global.ai.config.SituationAiProperties; +//import com.fasterxml.jackson.databind.ObjectMapper; +//import org.junit.jupiter.api.Assumptions; +//import org.junit.jupiter.api.BeforeEach; +//import org.junit.jupiter.api.DisplayName; +//import org.junit.jupiter.api.Test; +//import org.springframework.beans.factory.annotation.Autowired; +//import org.springframework.boot.test.context.SpringBootTest; +//import org.springframework.boot.test.context.TestConfiguration; 
+//import org.springframework.context.annotation.Bean; +//import org.springframework.context.annotation.Import; +//import org.springframework.test.context.ActiveProfiles; +//import org.springframework.test.context.DynamicPropertyRegistry; +//import org.springframework.test.context.DynamicPropertySource; +//import org.springframework.test.context.TestPropertySource; +//import org.testcontainers.DockerClientFactory; +//import org.testcontainers.containers.PostgreSQLContainer; +//import org.testcontainers.junit.jupiter.Testcontainers; +//import org.testcontainers.utility.DockerImageName; +// +//import java.time.LocalDateTime; +//import java.util.Arrays; +//import java.util.List; +// +//import static org.assertj.core.api.Assertions.assertThat; +// +//@Testcontainers(disabledWithoutDocker = true) +//@SpringBootTest +//@TestPropertySource(properties = { +// "app.initdata.enabled=false" +//}) +// +//@ActiveProfiles("test") // 스텁이 떠 있어도 아래 @Primary 오버라이드가 이긴다 +//@Import({DecisionAssistVectorLatencyIT.RealAiOverrideConfig.class, AiLatencyProbeConfig.class, +//AiOnceDelegateTestConfig.class}) +//class DecisionAssistVectorLatencyIT { +// +// // Docker 선기동 + pgvector +// static final boolean DOCKER = initDocker(); +// static boolean initDocker() { +// try { return DockerClientFactory.instance().isDockerAvailable(); } catch (Throwable t) { return false; } +// } +// static final PostgreSQLContainer PG = +// new PostgreSQLContainer<>(DockerImageName.parse("pgvector/pgvector:pg16") +// .asCompatibleSubstituteFor("postgres")) +// .withDatabaseName("relife_test") +// .withUsername("test") +// .withPassword("test") +// .withInitScript("sql/init_vector.sql"); +// static { if (DOCKER) PG.start(); } +// +// @DynamicPropertySource +// static void props(DynamicPropertyRegistry r) { +// Assumptions.assumeTrue(DOCKER, "Docker not available — skipping"); +// r.add("spring.datasource.url", PG::getJdbcUrl); +// r.add("spring.datasource.username", PG::getUsername); +// 
r.add("spring.datasource.password", PG::getPassword); +// r.add("spring.datasource.driver-class-name", () -> "org.postgresql.Driver"); +// r.add("spring.jpa.hibernate.ddl-auto", () -> "update"); +// r.add("spring.jpa.properties.hibernate.dialect", () -> "org.hibernate.dialect.PostgreSQLDialect"); +// +// // 실키 사용 — 환경변수 반드시 세팅: GEMINI_API_KEY +// String apiKey = System.getenv("GEMINI_API_KEY"); +// if (apiKey == null || apiKey.isBlank()) { +// apiKey = System.getProperty("GEMINI_API_KEY"); +// } +// Assumptions.assumeTrue(apiKey != null && !apiKey.isBlank(), "GEMINI_API_KEY not set — skipping"); +// String finalApiKey = apiKey; +// r.add("ai.text.gemini.api-key", () -> finalApiKey); +// r.add("ai.text.gemini.model", () -> "gemini-2.0-flash"); +// r.add("ai.text.gemini.base-url", () -> "https://generativelanguage.googleapis.com"); +// } +// +// @Autowired EmbeddingClient embeddingClient; +// @Autowired PgVectorSearchService vectorSearch; +// @Autowired AIVectorService aivectorService; // 오버라이드된 실구현 주입 +// @Autowired DecisionAssistVectorFlowITTestOps ops; // 단순 세이브 헬퍼 +// +// @Autowired +// AiCallBudget budget; // 예산 주입 +// +// @BeforeEach +// void setup() { +// ops.clearSnippets(); +// } +// +// @Test +// @DisplayName("실제 Gemini 호출 레이턴시 측정(E2E / HTTP)") +// void latency_real_ai() { +// // 1) 실호출 1회 강제 +// budget.reset(1); +// +// // 2) RAG 콘텍스트 준비 +// long lineId = 900L; int age = 22; +// ops.saveSnippet(lineId, age, "수도권 컴공 진학 비용과 통학 고민", "EDUCATION"); +// ops.saveSnippet(lineId, age, "서울 스타트업 인턴 이력서 준비", "CAREER"); +// +// // 3) 호출 +// long t0 = System.nanoTime(); +// DecisionNode last = DecisionNode.builder().ageYear(age).situation("컴공 고려").decision("컴공 선택").build(); +// var hint = aivectorService.generateNextHint(1L, lineId, List.of(last)); +// long e2eMs = (System.nanoTime() - t0) / 1_000_000; +// +// long httpMs = AiLatencyProbeConfig.LAST_LATENCY_MS.get(); +// boolean realAi = httpMs >= 0; // 데코레이터가 시간 기록했으면 실호출 +// +// System.out.println("[LAT] realAi=" + 
realAi + " e2eMs=" + e2eMs + " httpMs=" + httpMs + +// " situation=" + hint.aiNextSituation() + " option=" + hint.aiNextRecommendedOption()); +// +// assertThat(realAi).isTrue(); // 반드시 실호출이어야 함 +// assertThat(hint.aiNextSituation()).isNotBlank(); +// } +// +// @Test +// @DisplayName("워밍업 후 레이턴시 P50 검증(상황 생성 경로)") +// void latency_after_warmup_p50() { +// // next 노드 생성: 워밍업 1회(최소 출력·짧은 프롬프트 권장) +// budget.reset(1); // 무결성 검증 +// aivectorService.generateNextHint(1L, 900L, List.of( +// DecisionNode.builder().ageYear(22).situation("컴공 고려").decision("컴공 선택").build() +// )); +// +// // 측정 N회 +// int N = 5; +// long[] e2e = new long[N]; +// long[] http = new long[N]; +// for (int i = 0; i < N; i++) { +// budget.reset(1); // 무결성 검증 +// long t0 = System.nanoTime(); +// var hint = aivectorService.generateNextHint(1L, 900L, List.of( +// DecisionNode.builder().ageYear(22).situation("컴공 고려").decision("컴공 선택").build() +// )); +// long e2eMs = (System.nanoTime() - t0) / 1_000_000; +// long httpMs = AiLatencyProbeConfig.LAST_LATENCY_MS.get(); +// e2e[i] = e2eMs; +// http[i] = httpMs; +// } +// +// Arrays.sort(e2e); +// Arrays.sort(http); +// long p50E2E = e2e[N/2]; +// long p50HTTP = http[N/2]; +// System.out.println("[P50] e2e=" + p50E2E + "ms, http=" + p50HTTP + "ms"); +// +// // 무결성 검증: 목표 상한(예: 1200ms)을 테스트 기준으로 잡아두기 +// assertThat(p50HTTP).isLessThan(1200); +// } +// +// // --- 테스트용 헬퍼/오버라이드 --- +// +// @TestConfiguration +// static class DecisionAssistVectorFlowITTestOps { +// private final NodeSnippetRepository repo; +// private final EmbeddingClient emb; +// DecisionAssistVectorFlowITTestOps(NodeSnippetRepository repo, EmbeddingClient emb) { +// this.repo = repo; this.emb = emb; +// } +// // 가장 많이 사용하는 함수 호출 위에 한줄로만 +// void clearSnippets() { repo.deleteAllInBatch(); } +// // 가장 중요한 함수 위에 한줄로만 +// void saveSnippet(Long lineId, int age, String text, String category) { +// repo.save(NodeSnippet.builder() +// .lineId(lineId).ageYear(age).category(category) +// 
.text(text).embedding(emb.embed(text)).updatedAt(LocalDateTime.now()) +// .build()); +// } +// } +// +// @TestConfiguration +// static class RealAiOverrideConfig { +// // 한줄 요약: AIVectorService를 실구현으로 강제(@Primary)해서 test 스텁을 덮어쓴다 +// @Bean +// AIVectorService realAIVectorService( +// TextAiClient textAiClient, +// AIVectorServiceSupportDomain support, +// SituationAiProperties props, +// ObjectMapper objectMapper +// ) { +// var impl = new AIVectorServiceImpl(textAiClient, support, props, objectMapper); +// impl.setTopK(1); impl.setContextCharLimit(160); impl.setMaxOutputTokens(48); +// return impl; +// } +// +// } +//} diff --git a/back/src/test/java/com/back/global/ai/vector/PgVectorSearchServiceIT.java b/back/src/test/java/com/back/global/ai/vector/PgVectorSearchServiceIT.java new file mode 100644 index 0000000..dc91daa --- /dev/null +++ b/back/src/test/java/com/back/global/ai/vector/PgVectorSearchServiceIT.java @@ -0,0 +1,112 @@ +///* +// * [코드 흐름 요약] +// * 1) 클래스 로딩 시 Docker 가용성 체크 → 가능하면 컨테이너를 static 블록에서 즉시 start() +// * 2) @DynamicPropertySource 에서는 컨테이너가 이미 시작된 상태에서 JDBC 프로퍼티를 등록 +// * 3) Docker 불가면 Assumptions로 테스트 전부를 "skipped" 처리(빨간불 방지) +// */ +//package com.back.global.ai.vector; +// +//import com.back.domain.search.entity.NodeSnippet; +//import com.back.domain.search.repository.NodeSnippetRepository; +//import org.junit.jupiter.api.Assumptions; +//import org.junit.jupiter.api.DisplayName; +//import org.junit.jupiter.api.Test; +//import org.junit.jupiter.api.TestInstance; +//import org.springframework.beans.factory.annotation.Autowired; +//import org.springframework.boot.test.context.SpringBootTest; +//import org.springframework.test.context.ActiveProfiles; +//import org.springframework.test.context.DynamicPropertyRegistry; +//import org.springframework.test.context.DynamicPropertySource; +//import org.testcontainers.DockerClientFactory; +//import org.testcontainers.containers.PostgreSQLContainer; +//import 
org.testcontainers.junit.jupiter.Testcontainers; +//import org.testcontainers.utility.DockerImageName; +// +//import java.time.LocalDateTime; +//import java.util.List; +// +//import static org.assertj.core.api.Assertions.assertThat; +// +//@Testcontainers(disabledWithoutDocker = true) +//@SpringBootTest +//@ActiveProfiles("test") +//@TestInstance(TestInstance.Lifecycle.PER_CLASS) +//class PgVectorSearchServiceIT { +// +// private static final boolean DOCKER_AVAILABLE; +// static { +// boolean ok; +// try { ok = DockerClientFactory.instance().isDockerAvailable(); } +// catch (Throwable t) { ok = false; } +// DOCKER_AVAILABLE = ok; +// } +// +// // 가장 중요한 함수 위에 한줄로만 +// static final PostgreSQLContainer postgres = +// new PostgreSQLContainer<>( +// DockerImageName.parse("pgvector/pgvector:pg16") +// .asCompatibleSubstituteFor("postgres") // 타입 세이프티 +// ) +// .withDatabaseName("relife_test") +// .withUsername("test") +// .withPassword("test") +// .withInitScript("sql/init_vector.sql"); // CREATE EXTENSION IF NOT EXISTS vector; +// +// // 컨테이너를 클래스 로딩 시점에 기동(프로퍼티 평가보다 먼저) +// static { +// if (DOCKER_AVAILABLE) { +// postgres.start(); +// } +// } +// +// // 가장 중요한 함수 위에 한줄로만 +// @DynamicPropertySource +// static void props(DynamicPropertyRegistry r) { +// Assumptions.assumeTrue(DOCKER_AVAILABLE, "Docker not available — skipping pgvector IT"); +// r.add("spring.datasource.url", postgres::getJdbcUrl); +// r.add("spring.datasource.username", postgres::getUsername); +// r.add("spring.datasource.password", postgres::getPassword); +// r.add("spring.datasource.driver-class-name", () -> "org.postgresql.Driver"); +// r.add("spring.jpa.hibernate.ddl-auto", () -> "update"); +// r.add("spring.jpa.properties.hibernate.dialect", () -> "org.hibernate.dialect.PostgreSQLDialect"); +// r.add("spring.jpa.open-in-view", () -> "false"); +// r.add("spring.flyway.enabled", () -> "false"); +// r.add("ai.embedding.dim", () -> 768); +// } +// +// @Autowired EmbeddingClient embeddingClient; +// 
@Autowired NodeSnippetRepository snippetRepo; +// @Autowired PgVectorSearchService vectorSearch; +// +// @Test +// @DisplayName("pgvector: 라인/나이 윈도우 + <=> 유사도 정렬 topK") +// void topK_should_respect_line_age_window_and_similarity_order() { +// Assumptions.assumeTrue(DOCKER_AVAILABLE, "Docker not available — skipping pgvector IT"); +// +// long lineId = 777L; +// int age = 22; +// +// // 가장 많이 사용하는 함수 호출 위에 한줄로만 +// saveSnippet(lineId, age, "수도권 컴퓨터공학과 진학을 고민한다. 등록금과 통학 문제를 따져본다.", "EDUCATION"); +// saveSnippet(lineId, age, "서울 스타트업 인턴을 제안받아 이력서를 준비한다.", "CAREER"); +// saveSnippet(lineId, age, "군 입대 전 체력 관리를 위해 헬스장을 등록했다.", "HEALTH"); +// +// String query = "- (22세) 대학 진학을 앞두고 컴퓨터공학과 선택과 비용을 고민한다"; +// float[] qEmb = embeddingClient.embed(query); +// +// List top = vectorSearch.topK(lineId, age, 1, qEmb, 2); +// +// assertThat(top).hasSize(2); +// assertThat(top.get(0).getCategory()).isIn("EDUCATION", "CAREER"); +// } +// +// // 가장 중요한 함수 위에 한줄로만 +// private void saveSnippet(Long lineId, int age, String text, String category) { +// float[] emb = embeddingClient.embed(text); +// NodeSnippet s = NodeSnippet.builder() +// .lineId(lineId).ageYear(age).category(category) +// .text(text).embedding(emb).updatedAt(LocalDateTime.now()) +// .build(); +// snippetRepo.save(s); +// } +//} diff --git a/back/src/test/resources/sql/init_vector.sql b/back/src/test/resources/sql/init_vector.sql new file mode 100644 index 0000000..e734541 --- /dev/null +++ b/back/src/test/resources/sql/init_vector.sql @@ -0,0 +1,4 @@ +-- [코드 흐름 요약] +-- 1) pgvector 확장을 설치한다. +-- 2) Hibernate가 이후에 vector(768) 컬럼을 포함한 테이블을 생성할 수 있게 준비한다. +CREATE EXTENSION IF NOT EXISTS vector;