Skip to content

Commit 9506719

Browse files
This update adds the following
Fixes a bug where vulnerabilities were not properly resolved if there was no linked http request. Updates Spring AI Version Removes unused Prompts Adds Hints to returned vulnerabilities to guide the Agent on the correct fix. Adds Library CVE data to the stacktrace. Which gives the agent a hint as to the fix if the vulnerability actually is occuring in a vulnerability within a 3rd party library. By giving the CVEs for that library.
1 parent 79ab08d commit 9506719

19 files changed

+863
-163
lines changed

pom.xml

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
<?xml version="1.0" encoding="UTF-8"?>
22
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
3-
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
3+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
44
<modelVersion>4.0.0</modelVersion>
55
<parent>
66
<groupId>org.springframework.boot</groupId>
77
<artifactId>spring-boot-starter-parent</artifactId>
8-
<version>3.4.4</version>
8+
<version>3.4.5</version>
99
<relativePath/> <!-- lookup parent from repository -->
1010
</parent>
1111
<groupId>com.contrast.labs.ai.mcp</groupId>
@@ -28,22 +28,22 @@
2828
</scm>
2929
<properties>
3030
<java.version>17</java.version>
31-
<spring-ai.version>1.0.0-M6</spring-ai.version>
31+
<spring-ai.version>1.0.0-RC1</spring-ai.version>
3232
</properties>
33-
<dependencies>
34-
<dependency>
35-
<groupId>com.google.guava</groupId>
36-
<artifactId>guava</artifactId>
37-
<version>33.4.7-jre</version> <!-- Use the latest stable version -->
38-
</dependency>
39-
<dependency>
40-
<groupId>com.contrastsecurity</groupId>
41-
<artifactId>contrast-sdk-java</artifactId>
42-
<version>3.4.2</version>
43-
</dependency>
33+
<dependencies>
34+
<dependency>
35+
<groupId>com.google.guava</groupId>
36+
<artifactId>guava</artifactId>
37+
<version>33.4.7-jre</version> <!-- Use the latest stable version -->
38+
</dependency>
39+
<dependency>
40+
<groupId>com.contrastsecurity</groupId>
41+
<artifactId>contrast-sdk-java</artifactId>
42+
<version>3.4.2</version>
43+
</dependency>
4444
<dependency>
4545
<groupId>org.springframework.ai</groupId>
46-
<artifactId>spring-ai-mcp-server-spring-boot-starter</artifactId>
46+
<artifactId>spring-ai-starter-mcp-server</artifactId>
4747
</dependency>
4848

4949
<dependency>
@@ -73,4 +73,4 @@
7373
</plugins>
7474
</build>
7575

76-
</project>
76+
</project>

src/main/java/com/contrast/labs/ai/mcp/contrast/AssessService.java

Lines changed: 59 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,15 @@
1717

1818

1919
import com.contrast.labs.ai.mcp.contrast.data.ApplicationData;
20+
import com.contrast.labs.ai.mcp.contrast.data.LibraryLibraryObservation;
21+
import com.contrast.labs.ai.mcp.contrast.data.StackLib;
2022
import com.contrast.labs.ai.mcp.contrast.data.VulnLight;
2123
import com.contrast.labs.ai.mcp.contrast.data.Vulnerability;
24+
import com.contrast.labs.ai.mcp.contrast.hints.HintGenerator;
25+
import com.contrast.labs.ai.mcp.contrast.sdkexstension.SDKExtension;
2226
import com.contrast.labs.ai.mcp.contrast.sdkexstension.SDKHelper;
27+
import com.contrast.labs.ai.mcp.contrast.sdkexstension.data.LibraryExtended;
28+
import com.contrast.labs.ai.mcp.contrast.sdkexstension.data.sca.LibraryObservation;
2329
import com.contrastsecurity.models.Application;
2430
import com.contrastsecurity.models.EventResource;
2531
import com.contrastsecurity.models.EventSummaryResponse;
@@ -37,13 +43,16 @@
3743

3844
import java.io.IOException;
3945
import java.util.ArrayList;
46+
import java.util.HashSet;
4047
import java.util.List;
4148
import java.util.Optional;
49+
import java.util.Set;
4250

4351
@Service
4452
public class AssessService {
4553

4654
private static final Logger logger = LoggerFactory.getLogger(AssessService.class);
55+
4756

4857
@Value("${contrast.host-name:${CONTRAST_HOST_NAME:}}")
4958
private String hostName;
@@ -61,7 +70,7 @@ public class AssessService {
6170
private String orgID;
6271

6372

64-
@Tool(name = "get_vulnerability", description = "takes a vulnerability ID ( vulnID ) and Application ID ( appID ) and returns details about the specific security vulnerability")
73+
@Tool(name = "get_vulnerability", description = "takes a vulnerability ID ( vulnID ) and Application ID ( appID ) and returns details about the specific security vulnerability. If based on the stacktrace, the vulnerability looks like it is in code that is not in the codebase, the vulnerability may be in a 3rd party library, review the CVE data attached to that stackframe you believe the vulnerability exists in and if possible upgrade that library to the next non vulnerable version based on the remediation guidance.")
6574
public Vulnerability getVulnerability(String vulnID, String appID) throws IOException {
6675
logger.info("Retrieving vulnerability details for vulnID: {} in application ID: {}", vulnID, appID);
6776
ContrastSDK contrastSDK = SDKHelper.getSDK(hostName, apiKey, serviceKey, userName);
@@ -90,25 +99,64 @@ public Vulnerability getVulnerability(String vulnID, String appID) throws IOExce
9099
logger.debug("Found {} stack traces for vulnerability", stackTraces.size());
91100
}
92101
}
93-
102+
List<LibraryExtended> libs = SDKHelper.getLibsForID(appID,orgID, new SDKExtension(contrastSDK));
103+
List<LibraryLibraryObservation> lobs = new ArrayList<>();
104+
for(LibraryExtended lib : libs) {
105+
LibraryLibraryObservation llob = new LibraryLibraryObservation(lib,new SDKExtension(contrastSDK).getLibraryObservations(orgID,appID,lib.getHash(),50));
106+
lobs.add(llob);
107+
}
108+
List<StackLib> stackLibs = new ArrayList<>();
109+
Set<LibraryExtended> libsToReturn = new HashSet<>();
110+
for(String stackTrace : stackTraces) {
111+
Optional<LibraryLibraryObservation> matchingLlobOpt = findMatchingLibraryData(stackTrace, lobs);
112+
if (matchingLlobOpt.isPresent()) {
113+
LibraryLibraryObservation llob = matchingLlobOpt.get();
114+
LibraryExtended library = llob.library();
115+
if (!library.getVulnerabilities().isEmpty()) {
116+
libsToReturn.add(library); // Set.add() handles uniqueness efficiently
117+
stackLibs.add(new StackLib(stackTrace, library.getHash()));
118+
} else {
119+
stackLibs.add(new StackLib(stackTrace, null));
120+
}
121+
} else {
122+
stackLibs.add(new StackLib(stackTrace, null));
123+
}
124+
}
125+
126+
String httpRequestText = null;
127+
if( requestResponse.getHttpRequest()!=null) {
128+
httpRequestText = requestResponse.getHttpRequest().getText();
129+
}
130+
String hint = HintGenerator.generateVulnerabilityFixHint(trace.getRule());
94131
logger.info("Successfully retrieved vulnerability details for vulnID: {}", vulnID);
95-
return new Vulnerability(vulnID, trace.getTitle(), trace.getRule(),
96-
recommendationResponse.getRecommendation().getText(), stackTraces,
97-
requestResponse.getHttpRequest().getText());
132+
return new Vulnerability(hint, vulnID, trace.getTitle(), trace.getRule(),
133+
recommendationResponse.getRecommendation().getText(), stackLibs, new ArrayList<>(libsToReturn), // Convert Set to List
134+
httpRequestText);
98135
} catch (Exception e) {
99136
logger.error("Error retrieving vulnerability details for vulnID: {}", vulnID, e);
100137
throw new IOException("Failed to retrieve vulnerability details: " + e.getMessage(), e);
101138
}
102139
}
103140

141+
private Optional<LibraryLibraryObservation> findMatchingLibraryData(String stackTrace, List<LibraryLibraryObservation> lobs) {
142+
String lowerStackTrace = stackTrace.toLowerCase();
143+
for (LibraryLibraryObservation llob : lobs) {
144+
for (LibraryObservation lob : llob.libraryObservation()) {
145+
if (lob.getName() != null && lowerStackTrace.startsWith(lob.getName().toLowerCase())) {
146+
return Optional.of(llob);
147+
}
148+
}
149+
}
150+
return Optional.empty();
151+
}
104152

105-
@Tool(name = "get_vulnerability_by_app_name", description = "Takes a vulnerability ID (vulnID) and application name (appName) and returns details about the specific security vulnerability")
153+
@Tool(name = "get_vulnerability_by_app_name", description = "Takes a vulnerability ID (vulnID) and application name (appName) and returns details about the specific security vulnerability. If based on the stacktrace, the vulnerability looks like it is in code that is not in the codebase, the vulnerability may be in a 3rd party library, review the CVE data attached to that stackframe you believe the vulnerability exists in and if possible upgrade that library to the next non vulnerable version based on the remediation guidance.")
106154
public Vulnerability getVulnerabilityByAppName(String vulnID, String appName) throws IOException {
107155
logger.info("Retrieving vulnerability details for vulnID: {} in application: {}", vulnID, appName);
108156
ContrastSDK contrastSDK = SDKHelper.getSDK(hostName, apiKey, serviceKey, userName);
109157
Optional<String> appID = Optional.empty();
110158
logger.debug("Searching for application ID matching name: {}", appName);
111-
159+
112160
for(Application app : contrastSDK.getApplications(orgID).getApplications()) {
113161
if(app.getName().toLowerCase().contains(appName.toLowerCase())) {
114162
appID = Optional.of(app.getId());
@@ -117,38 +165,7 @@ public Vulnerability getVulnerabilityByAppName(String vulnID, String appName) th
117165
}
118166
}
119167
if(appID.isPresent()) {
120-
try {
121-
Trace trace = contrastSDK.getTraces(orgID, appID.get(), new TraceFilterBody()).getTraces().stream()
122-
.filter(t -> t.getUuid().toLowerCase().equals(vulnID.toLowerCase()))
123-
.findFirst()
124-
.orElseThrow();
125-
logger.debug("Found trace with title: {} and rule: {}", trace.getTitle(), trace.getRule());
126-
127-
RecommendationResponse recommendationResponse = contrastSDK.getRecommendation(orgID, vulnID);
128-
HttpRequestResponse requestResponse = contrastSDK.getHttpRequest(orgID, vulnID);
129-
EventSummaryResponse eventSummaryResponse = contrastSDK.getEventSummary(orgID, vulnID);
130-
131-
Optional<EventResource> triggerEvent = eventSummaryResponse.getEvents().stream()
132-
.filter(e -> e.getType().equalsIgnoreCase("trigger"))
133-
.findFirst();
134-
135-
List<String> stackTraces = new ArrayList<>();
136-
if (triggerEvent.isPresent()) {
137-
List<Stacktrace> sTrace = triggerEvent.get().getEvent().getStacktraces();
138-
if (sTrace != null) {
139-
stackTraces.addAll(sTrace.stream().map(Stacktrace::getDescription).toList());
140-
logger.debug("Found {} stack traces for vulnerability", stackTraces.size());
141-
}
142-
}
143-
144-
logger.info("Successfully retrieved vulnerability details for vulnID: {} in app: {}", vulnID, appName);
145-
return new Vulnerability(vulnID, trace.getTitle(), trace.getRule(),
146-
recommendationResponse.getRecommendation().getText(), stackTraces,
147-
requestResponse.getHttpRequest().getText());
148-
} catch (Exception e) {
149-
logger.error("Error retrieving vulnerability details for vulnID: {} in app: {}", vulnID, appName, e);
150-
throw new IOException("Failed to retrieve vulnerability details: " + e.getMessage(), e);
151-
}
168+
return getVulnerability(vulnID, appID.get());
152169
} else {
153170
logger.error("Application with name {} not found", appName);
154171
throw new IllegalArgumentException("Application with name " + appName + " not found");
@@ -166,7 +183,7 @@ public List<VulnLight> listVulnsInApp(String appID) throws IOException {
166183

167184
List<VulnLight> vulns = new ArrayList<>();
168185
for(Trace trace : traces) {
169-
vulns.add(new VulnLight(trace.getTitle(), trace.getRule(), trace.getUuid()));
186+
vulns.add(new VulnLight(trace.getTitle(), trace.getRule(), trace.getUuid(),trace.getSeverity()));
170187
}
171188

172189
logger.info("Successfully retrieved {} vulnerabilities for application ID: {}", vulns.size(), appID);
@@ -177,7 +194,7 @@ public List<VulnLight> listVulnsInApp(String appID) throws IOException {
177194
}
178195
}
179196

180-
@Tool(name = "list_vulnerabilities_with_app_name", description = "Takes an application name ( appName ) and returns a list of vulnerabilities, please remember to include the vulnID in the response.")
197+
@Tool(name = "list_vulnerabilities_with_app_name", description = "Takes an application name ( appName ) and returns a list of vulnerabilities, please remember to include the vulnID in the response. ")
181198
public List<VulnLight> listVulnsInAppByName(String appName) throws IOException {
182199
logger.info("Listing vulnerabilities for application: {}", appName);
183200
ContrastSDK contrastSDK = SDKHelper.getSDK(hostName, apiKey, serviceKey, userName);
@@ -192,19 +209,9 @@ public List<VulnLight> listVulnsInAppByName(String appName) throws IOException {
192209
break;
193210
}
194211
}
195-
196212
if(appID.isPresent()) {
197213
try {
198-
List<Trace> traces = contrastSDK.getTraces(orgID, appID.get(), new TraceFilterBody()).getTraces();
199-
logger.debug("Found {} vulnerability traces for application: {}", traces.size(), appName);
200-
201-
List<VulnLight> vulns = new ArrayList<>();
202-
for (Trace trace : traces) {
203-
vulns.add(new VulnLight(trace.getTitle(), trace.getRule(), trace.getUuid()));
204-
}
205-
206-
logger.info("Successfully retrieved {} vulnerabilities for application: {}", vulns.size(), appName);
207-
return vulns;
214+
return listVulnsInApp(appID.get());
208215
} catch (Exception e) {
209216
logger.error("Error listing vulnerabilities for application: {}", appName, e);
210217
throw new IOException("Failed to list vulnerabilities: " + e.getMessage(), e);

src/main/java/com/contrast/labs/ai/mcp/contrast/McpContrastApplication.java

Lines changed: 1 addition & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,14 @@
1515
*/
1616
package com.contrast.labs.ai.mcp.contrast;
1717

18-
import io.modelcontextprotocol.server.McpServerFeatures;
19-
import io.modelcontextprotocol.spec.McpSchema;
18+
import org.springframework.ai.support.ToolCallbacks;
2019
import org.springframework.ai.tool.ToolCallback;
21-
import org.springframework.ai.tool.ToolCallbacks;
2220
import org.springframework.boot.SpringApplication;
2321
import org.springframework.boot.autoconfigure.SpringBootApplication;
2422
import org.springframework.context.annotation.Bean;
2523

26-
import java.util.Collections;
2724
import java.util.List;
2825

29-
import static com.contrast.labs.ai.mcp.contrast.PromptRegistration.PROMPT_TEMPLATE;
3026
import static java.util.List.of;
3127

3228
@SpringBootApplication
@@ -36,96 +32,9 @@ public static void main(String[] args) {
3632
SpringApplication.run(McpContrastApplication.class, args);
3733
}
3834

39-
4035
@Bean
4136
public List<ToolCallback> tools(AssessService assessService, SastService sastService,SCAService scaService,ADRService adrService) {
4237
return of(ToolCallbacks.from(assessService,sastService,scaService,adrService));
4338
}
4439

45-
@Bean
46-
public List<McpServerFeatures.SyncPromptRegistration> prompts() {
47-
// AssessService prompts
48-
var getVulnPrompt = new McpSchema.Prompt("get_vulnerability_prompt", "Get details about an Assess vulnerability",
49-
List.of(
50-
new McpSchema.PromptArgument("vulnID", "The vulnerability ID", true),
51-
new McpSchema.PromptArgument("appID", "The application ID", true)
52-
));
53-
54-
var getVulnByAppNamePrompt = new McpSchema.Prompt("get_vulnerability_by_app_name_prompt", "Get details about an Assess vulnerability",
55-
List.of(
56-
new McpSchema.PromptArgument("vulnID", "The vulnerability ID", true),
57-
new McpSchema.PromptArgument("appName", "The application name", true)
58-
));
59-
60-
var listVulnsPrompt = new McpSchema.Prompt("list_vulnerabilities_prompt", "List vulnerabilities for an application",
61-
List.of(new McpSchema.PromptArgument("appID", "The application ID", true)));
62-
63-
var listVulnsByAppNamePrompt = new McpSchema.Prompt("list_vulnerabilities_with_app_name_prompt", "List vulnerabilities by app name",
64-
List.of(new McpSchema.PromptArgument("appName", "The application name", true)));
65-
66-
var listAppsPrompt = new McpSchema.Prompt("list_applications_prompt", "List active applications by name",
67-
List.of(new McpSchema.PromptArgument("appName", "The application name to filter by", true)));
68-
69-
// SastService prompts
70-
var scanProjectPrompt = new McpSchema.Prompt("list_scan_project_prompt", "Get scan project details",
71-
List.of(new McpSchema.PromptArgument("projectName", "The scan project name", true)));
72-
73-
var scanResultsPrompt = new McpSchema.Prompt("list_scan_results_prompt", "Get latest scan results for project name",
74-
List.of(new McpSchema.PromptArgument("projectName", "The scan project name", true)));
75-
76-
// SCAService prompts
77-
var appLibsByIdPrompt = new McpSchema.Prompt("list_application_libraries_by_app_id_prompt", "Get libraries used by app ID",
78-
List.of(new McpSchema.PromptArgument("appID", "The application ID", true)));
79-
80-
var appLibsPrompt = new McpSchema.Prompt("list_application_libraries_prompt", "Get libraries used by app name",
81-
List.of(new McpSchema.PromptArgument("appName", "The application name", true)));
82-
83-
var cveAppsPrompt = new McpSchema.Prompt("list_applications_vulnerable_to_cve_prompt", "Find apps vulnerable to a CVE",
84-
List.of(new McpSchema.PromptArgument("cveid", "The CVE ID", true)));
85-
86-
// ADRService prompts
87-
var protectRulesPrompt = new McpSchema.Prompt("get_adr_protect_rules_prompt", "Get protect/ADR rules by app name",
88-
List.of(new McpSchema.PromptArgument("applicationName", "The application name", true)));
89-
90-
var protectRulesByIdPrompt = new McpSchema.Prompt("get_adr_protect_rules_by_app_id_prompt", "Get protect/ADR rules by app ID",
91-
List.of(new McpSchema.PromptArgument("appID", "The application ID", true)));
92-
93-
// Create generic message handler for all prompts
94-
return List.of(
95-
getAssistantPrompt(),
96-
createToolPromptRegistration(getVulnPrompt, "get_vulnerability"),
97-
createToolPromptRegistration(getVulnByAppNamePrompt, "get_vulnerability_by_app_name"),
98-
createToolPromptRegistration(listVulnsPrompt, "list_vulnerabilities"),
99-
createToolPromptRegistration(listVulnsByAppNamePrompt, "list_vulnerabilities_with_app_name"),
100-
createToolPromptRegistration(listAppsPrompt, "list_applications"),
101-
createToolPromptRegistration(scanProjectPrompt, "list_Scan_Project"),
102-
createToolPromptRegistration(scanResultsPrompt, "list_Scan_Results"),
103-
createToolPromptRegistration(appLibsByIdPrompt, "list_application_libraries_by_app_id"),
104-
createToolPromptRegistration(appLibsPrompt, "list_application_libraries"),
105-
createToolPromptRegistration(cveAppsPrompt, "list_applications_vulnerable_to_cve"),
106-
createToolPromptRegistration(protectRulesPrompt, "get_ADR_Protect_Rules"),
107-
createToolPromptRegistration(protectRulesByIdPrompt, "get_ADR_Protect_Rules_by_app_id")
108-
);
109-
}
110-
111-
112-
113-
private McpServerFeatures.SyncPromptRegistration getAssistantPrompt() {
114-
var prompt = new McpSchema.Prompt("default-contrast-prompt", "A prompt to seed system prompt for Contrast chat with mcp.", Collections.emptyList());
115-
return new McpServerFeatures.SyncPromptRegistration(prompt, getPromptRequest -> {
116-
var promptMessage = new McpSchema.PromptMessage(McpSchema.Role.ASSISTANT,
117-
new McpSchema.TextContent(PROMPT_TEMPLATE));
118-
return new McpSchema.GetPromptResult("default-contrast-prompt", List.of(promptMessage));
119-
});
120-
}
121-
122-
123-
// Helper method to create tool prompt registrations with consistent formatting
124-
private McpServerFeatures.SyncPromptRegistration createToolPromptRegistration(McpSchema.Prompt prompt, String toolName) {
125-
return new McpServerFeatures.SyncPromptRegistration(prompt, getPromptRequest -> {
126-
var userMessage = new McpSchema.PromptMessage(McpSchema.Role.USER,
127-
new McpSchema.TextContent("Please use the " + toolName + " tool with the following parameters: " + getPromptRequest.arguments()));
128-
return new McpSchema.GetPromptResult("A prompt to use the " + toolName + " tool", List.of(userMessage));
129-
});
130-
}
13140
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package com.contrast.labs.ai.mcp.contrast;
2+
3+
import org.springframework.stereotype.Service;
4+
5+
@Service
6+
public class PromptService {
7+
8+
9+
10+
}

0 commit comments

Comments
 (0)