|
1 | 1 | package io.sentrius.sso.controllers.api; |
2 | 2 |
|
3 | 3 | import io.sentrius.sso.core.annotations.LimitAccess; |
| 4 | +import io.sentrius.sso.core.config.SystemOptions; |
| 5 | +import io.sentrius.sso.core.controllers.BaseController; |
4 | 6 | import io.sentrius.sso.core.model.security.enums.ApplicationAccessEnum; |
| 7 | +import io.sentrius.sso.core.model.verbs.Endpoint; |
| 8 | +import io.sentrius.sso.core.promptadvisor.model.RefinePromptResponse; |
| 9 | +import io.sentrius.sso.core.promptadvisor.model.ValidatePromptRequest; |
| 10 | +import io.sentrius.sso.core.promptadvisor.service.PromptAdvisorService; |
| 11 | +import io.sentrius.sso.core.services.ErrorOutputService; |
| 12 | +import io.sentrius.sso.core.services.UserService; |
| 13 | +import jakarta.servlet.http.HttpServletRequest; |
| 14 | +import jakarta.servlet.http.HttpServletResponse; |
5 | 15 | import lombok.extern.slf4j.Slf4j; |
6 | 16 | import org.springframework.beans.factory.annotation.Value; |
7 | 17 | import org.springframework.http.*; |
|
17 | 27 | @Slf4j |
18 | 28 | @RestController |
19 | 29 | @RequestMapping("/api/v1/prompt-advisor") |
20 | | -public class PromptAdvisorApiController { |
| 30 | +public class PromptAdvisorApiController extends BaseController { |
21 | 31 |
|
22 | 32 | private final RestTemplate restTemplate = new RestTemplate(); |
| 33 | + private final PromptAdvisorService promptAdvisorService; |
| 34 | + |
23 | 35 |
|
24 | 36 | @Value("${sentrius.prompt-advisor.url:http://sentrius-prompt-advisor:80}") |
25 | 37 | private String promptAdvisorUrl; |
26 | 38 |
|
| 39 | + protected PromptAdvisorApiController( |
| 40 | + UserService userService, SystemOptions systemOptions, |
| 41 | + ErrorOutputService errorOutputService, PromptAdvisorService promptAdvisorService |
| 42 | + ) { |
| 43 | + super(userService, systemOptions, errorOutputService); |
| 44 | + this.promptAdvisorService = promptAdvisorService; |
| 45 | + } |
| 46 | + |
27 | 47 | /** |
28 | 48 | * Get current ATPL criteria and their weights |
29 | 49 | */ |
@@ -72,54 +92,54 @@ public ResponseEntity<Map<String, Object>> validatePrompt(@RequestBody Map<Strin |
72 | 92 | * Interactive prompt refinement session |
73 | 93 | */ |
74 | 94 | @PostMapping("/refine") |
75 | | - @LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_MANAGE_APPLICATION}) |
76 | | - public ResponseEntity<Map<String, Object>> refinePrompt(@RequestBody Map<String, Object> request) { |
77 | | - try { |
78 | | - // First validate the prompt |
79 | | - String prompt = (String) request.get("prompt"); |
80 | | - String sessionId = (String) request.getOrDefault("sessionId", UUID.randomUUID().toString()); |
81 | | - |
82 | | - Map<String, Object> validateRequest = new HashMap<>(); |
83 | | - validateRequest.put("prompt", prompt); |
84 | | - |
85 | | - // Only include context if it's a non-empty Map (prompt-advisor expects Dict or null) |
86 | | - Object contextObj = request.get("context"); |
87 | | - if (contextObj instanceof Map && !((Map<?, ?>) contextObj).isEmpty()) { |
88 | | - validateRequest.put("context", contextObj); |
89 | | - } |
90 | | - // If context is null or not provided, don't include it - let the service use its default |
91 | | - |
92 | | - String url = promptAdvisorUrl + "/validate_prompt"; |
93 | | - |
94 | | - HttpHeaders headers = new HttpHeaders(); |
95 | | - headers.setContentType(MediaType.APPLICATION_JSON); |
96 | | - |
97 | | - HttpEntity<Map<String, Object>> entity = new HttpEntity<>(validateRequest, headers); |
98 | | - |
99 | | - ResponseEntity<Map> response = restTemplate.postForEntity(url, entity, Map.class); |
100 | | - Map<String, Object> validationResult = response.getBody(); |
101 | | - |
102 | | - // Build refinement response with suggestions |
103 | | - Map<String, Object> refinementResponse = new HashMap<>(); |
104 | | - refinementResponse.put("sessionId", sessionId); |
105 | | - refinementResponse.put("originalPrompt", prompt); |
106 | | - refinementResponse.put("score", validationResult.get("score")); |
107 | | - refinementResponse.put("ratings", validationResult.get("ratings")); |
108 | | - refinementResponse.put("explanation", validationResult.get("explanation")); |
109 | | - refinementResponse.put("recommendations", validationResult.get("recommendations")); |
110 | | - |
111 | | - // Generate refinement suggestions based on scores |
112 | | - List<String> suggestions = generateRefinementSuggestions(validationResult); |
113 | | - refinementResponse.put("suggestions", suggestions); |
114 | | - |
115 | | - return ResponseEntity.ok(refinementResponse); |
116 | | - } catch (Exception e) { |
117 | | - log.error("Error refining prompt with prompt-advisor", e); |
118 | | - Map<String, Object> error = new HashMap<>(); |
119 | | - error.put("status", "error"); |
120 | | - error.put("message", "Failed to refine prompt: " + e.getMessage()); |
121 | | - return ResponseEntity.status(HttpStatus.SERVICE_UNAVAILABLE).body(error); |
| 95 | + @Endpoint(description = "Refine a prompt using LLM to apply recommendations and improve quality") |
| 96 | + public ResponseEntity<?> refinePrompt( |
| 97 | + @RequestBody ValidatePromptRequest request, |
| 98 | + HttpServletRequest httpRequest, |
| 99 | + HttpServletResponse httpResponse |
| 100 | + ) { |
| 101 | + if (!systemOptions.getEnablePromptAdvisor()) { |
| 102 | + return ResponseEntity.status(HttpStatus.SERVICE_UNAVAILABLE) |
| 103 | + .body(Map.of("error", "Prompt advisor service is disabled")); |
122 | 104 | } |
| 105 | + |
| 106 | + var operatingUser = getOperatingUser(httpRequest, httpResponse); |
| 107 | + if (operatingUser == null) { |
| 108 | + return ResponseEntity.status(HttpStatus.UNAUTHORIZED) |
| 109 | + .body(Map.of("error", "Authentication required")); |
| 110 | + } |
| 111 | + |
| 112 | + log.info("Refining prompt using LLM for user: {}", operatingUser.getUsername()); |
| 113 | + |
| 114 | + // Use the new LLM-based refinement that actually rewrites the prompt |
| 115 | + RefinePromptResponse refineResponse = promptAdvisorService.refinePromptWithLLM( |
| 116 | + request.getPrompt(), |
| 117 | + request.getContext() |
| 118 | + ); |
| 119 | + |
| 120 | + if (refineResponse == null) { |
| 121 | + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR) |
| 122 | + .body(Map.of("error", "Failed to refine prompt")); |
| 123 | + } |
| 124 | + |
| 125 | + // Build response with all refinement data |
| 126 | + Map<String, Object> result = new HashMap<>(); |
| 127 | + result.put("original_prompt", refineResponse.getOriginalPrompt()); |
| 128 | + result.put("refined_prompt", refineResponse.getRefinedPrompt()); |
| 129 | + if (refineResponse.getScore() != null) { |
| 130 | + result.put("score", refineResponse.getScore()); |
| 131 | + } |
| 132 | + if (refineResponse.getRatings() != null) { |
| 133 | + result.put("ratings", refineResponse.getRatings()); |
| 134 | + } |
| 135 | + if (refineResponse.getExplanation() != null) { |
| 136 | + result.put("explanation", refineResponse.getExplanation()); |
| 137 | + } |
| 138 | + if (refineResponse.getRecommendations() != null) { |
| 139 | + result.put("recommendations", refineResponse.getRecommendations()); |
| 140 | + } |
| 141 | + |
| 142 | + return ResponseEntity.ok(result); |
123 | 143 | } |
124 | 144 |
|
125 | 145 | /** |
|
0 commit comments