Skip to content

Commit 9a58c61

Browse files
committed
Fixup issue with regexes
1 parent 41ac5e8 commit 9a58c61

File tree

4 files changed

+59
-41
lines changed

4 files changed

+59
-41
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
CREATE EXTENSION IF NOT EXISTS pg_trgm;
2+
3+
CREATE INDEX idx_command_pattern_trgm ON command_categories USING gin (pattern gin_trgm_ops);

core/src/main/java/io/sentrius/sso/core/repository/CommandCategoryRepository.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import io.sentrius.sso.core.model.categorization.CommandCategory;
55
import org.springframework.data.jpa.repository.JpaRepository;
66
import org.springframework.data.jpa.repository.Query;
7+
import org.springframework.data.repository.query.Param;
78
import org.springframework.stereotype.Repository;
89

910
@Repository
@@ -12,4 +13,8 @@ public interface CommandCategoryRepository extends JpaRepository<CommandCategory
1213
List<CommandCategory> findAllOrderedByPriority();
1314

1415
List<CommandCategory> findByPattern(String pattern);
16+
17+
@Query(value = "SELECT * FROM command_categories WHERE :command ~ pattern", nativeQuery = true)
18+
List<CommandCategory> findMatchingCategories(@Param("command") String command);
19+
1520
}

core/src/main/java/io/sentrius/sso/core/services/openai/categorization/CommandCategorizer.java

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package io.sentrius.sso.core.services.openai.categorization;
22

3+
import java.util.Comparator;
34
import java.util.List;
45
import java.util.concurrent.TimeUnit;
56
import java.util.regex.Pattern;
7+
import java.util.regex.PatternSyntaxException;
68
import com.fasterxml.jackson.core.JsonProcessingException;
79
import com.github.benmanes.caffeine.cache.Cache;
810
import com.github.benmanes.caffeine.cache.Caffeine;
@@ -33,42 +35,44 @@ public class CommandCategorizer {
3335
private final CommandCategoryRepository commandCategoryRepository;
3436

3537

36-
private final CommandTrie commandTrie = new CommandTrie();
37-
3838
private final Cache<String, CommandCategory> commandCache = Caffeine.newBuilder()
39-
.maximumSize(10000)
40-
.expireAfterWrite(24, TimeUnit.HOURS)
39+
.maximumSize(1000)
40+
.expireAfterWrite(1, TimeUnit.HOURS)
4141
.build();
4242

43-
@PostConstruct
44-
public void initializeTrie() {
45-
List<CommandCategory> categories = commandCategoryRepository.findAll();
46-
for (CommandCategory category : categories) {
47-
log.info("Adding command category {} to trie", category);
48-
commandTrie.insert(category.getPattern(), category);
49-
}
43+
44+
45+
private CommandCategory fetchFromDatabase(String command) {
46+
List<CommandCategory> matchingCategories = commandCategoryRepository.findMatchingCategories(command);
47+
return matchingCategories.stream()
48+
.min(Comparator.comparingInt(CommandCategory::getPriority))
49+
.orElse(null);
5050
}
5151

52+
5253
@Transactional
5354
public CommandCategory categorizeCommand(String command) {
5455
return commandCache.get(command, this::categorizeWithRulesOrML);
5556
}
5657

57-
5858
protected List<CommandCategory> getDBCommandCategory(String command){
5959
return commandCategoryRepository.findByPattern(command);
6060
}
6161

62-
@Transactional
63-
protected void addCommandCategory(String command, CommandCategory category) {
64-
commandTrie.insert(command, category);
65-
commandCategoryRepository.save(category);
62+
63+
public boolean isValidRegex(String regex) {
64+
try {
65+
Pattern.compile(regex);
66+
return true; // Valid regex
67+
} catch (PatternSyntaxException e) {
68+
return false; // Invalid regex
69+
}
6670
}
6771

6872

6973
@Transactional
7074
protected CommandCategory categorizeWithRulesOrML(String command) {
71-
CommandCategory category = commandTrie.searchByPrefix(command);
75+
CommandCategory category = fetchFromDatabase(command);
7276
if (category != null) {
7377
log.info("Found command category {} for {} ", category, command);
7478
return category;
@@ -93,8 +97,9 @@ protected CommandCategory categorizeWithRulesOrML(String command) {
9397
try {
9498
category = commandCategorizer.generate(command);
9599

96-
addCommandCategory(category.getPattern(), category);
97-
100+
if (isValidRegex(category.getPattern())) {
101+
addCommandCategory(category.getPattern(), category);
102+
}
98103
log.info("Categorized command: {}", category);
99104
return category;
100105
} catch (Exception e) {
@@ -110,4 +115,9 @@ protected CommandCategory categorizeWithRulesOrML(String command) {
110115

111116
return CommandCategory.builder().build();
112117
}
118+
119+
private void addCommandCategory(String pattern, CommandCategory category) {
120+
commandCategoryRepository.save(category);
121+
commandCache.put(pattern, category);
122+
}
113123
}

core/src/main/java/io/sentrius/sso/genai/LLMCommandCategorizer.java

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -58,28 +58,28 @@ public CommandCategory generate(String on) throws HttpException, JsonProcessingE
5858
public String generateInput(String on) {
5959
return """
6060
Categorize the following command with a generalized pattern, defined as a **regex**, that captures the intent and considers risk factors. Include specific arguments or paths in the regex only if they significantly impact the risk level. If no regex is predefined for the command, generate one that appropriately generalizes the command's behavior while retaining any risk-relevant components.
61-
62-
For example:
63-
- 'cat /etc/passwd' is risky due to sensitive user data, so it should be included as-is with the regex '^cat /etc/passwd$'.
64-
- 'cat /etc/hosts' has low risk, so it should be generalized to '^cat /etc/.*$' to cover all files in the `/etc` directory.
65-
- 'sudo rm -rf /important_dir' should include '/important_dir' if it's sensitive, but otherwise generalized to '^sudo rm -rf .*'.
66-
67-
Command: "%s"
68-
69-
**Categories to choose from:**
70-
- PRIVILEGED: Commands that require elevated permissions or pose a risk if misused.
71-
- DESTRUCTIVE: Commands that can delete or alter critical files or system configurations.
72-
- INFORMATIONAL: Commands that retrieve information without altering the system state.
73-
- GENERAL: Commands that do not fit into the above categories.
74-
75-
Respond in **JSON format** as follows:
76-
{
77-
"category": "<Category Name>",
78-
"priority": <Numerical Priority>,
79-
"pattern": "<Generalized regex Pattern>",
80-
"rationale": "<Explain why this category and regex were chosen>"
81-
}
82-
61+
62+
For example:
63+
- 'cat /etc/passwd' is risky due to sensitive user data, so it should be included as-is with the regex '^cat /etc/passwd$'.
64+
- 'cat /etc/hosts' has low risk, so it should be generalized to '^cat /etc/.*$' to cover all files in the `/etc` directory.
65+
- 'sudo rm -rf /important_dir' should include '/important_dir' if it's sensitive, but otherwise generalized to '^sudo rm -rf .*'.
66+
67+
Command: "%s"
68+
69+
**Categories to choose from:**
70+
- PRIVILEGED: Commands that require elevated permissions or pose a risk if misused.
71+
- DESTRUCTIVE: Commands that can delete or alter critical files or system configurations.
72+
- INFORMATIONAL: Commands that retrieve information without altering the system state.
73+
- GENERAL: Commands that do not fit into the above categories.
74+
75+
Respond in **JSON format** as follows:
76+
{
77+
"category": "<Category Name>",
78+
"priority": <Numerical Priority>,
79+
"pattern": "<Generalized regex Pattern>",
80+
"rationale": "<Explain why this category and regex were chosen>"
81+
}
82+
8383
""".formatted(on);
8484
}
8585

0 commit comments

Comments
 (0)