-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathindex.js
More file actions
156 lines (137 loc) · 4.09 KB
/
index.js
File metadata and controls
156 lines (137 loc) · 4.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import fs from 'fs';
import path from 'path';
import { fileURLToPath } from 'url';
import { TfidfVectorizer } from './model/tfidf.js';
import { LogisticRegression } from './model/logistic_regression.js';
// ES modules do not provide CommonJS's `__dirname`; derive it from this
// module's file URL so model files can be located relative to this file.
const __dirname = path.dirname(fileURLToPath(import.meta.url));

// Lazily populated cache of loaded models, one slot per supported model name.
// A slot stays `null` until loadModel() stores `{ vectorizer, model }` in it.
const models = Object.fromEntries(
  ['prompt_injection', 'jailbreak', 'malicious'].map((modelName) => [modelName, null]),
);
/**
 * Load a model (vectorizer + classifier) by name, caching it for reuse.
 *
 * Reads `model/<modelName>_model.json` next to this module and constructs both
 * the TfidfVectorizer and the LogisticRegression from the same parsed JSON.
 *
 * @param {string} modelName - 'prompt_injection', 'jailbreak', or 'malicious'
 * @returns {{vectorizer: TfidfVectorizer, model: LogisticRegression}} The cached
 *   or freshly loaded model pair.
 * @throws {Error} If `modelName` is not one of the known model names, or if the
 *   model file cannot be read, parsed, or used to construct the model (the
 *   underlying error is attached as `cause`).
 */
function loadModel(modelName) {
  // Only accept names that own a cache slot. This keeps arbitrary strings
  // (e.g. containing '../') out of the file path below, and stops inherited
  // Object.prototype keys such as 'constructor' from passing the cache check.
  if (!Object.prototype.hasOwnProperty.call(models, modelName)) {
    throw new Error(`Unknown model: ${modelName}`);
  }

  const cached = models[modelName];
  if (cached) {
    return cached;
  }

  try {
    const modelPath = path.join(__dirname, 'model', `${modelName}_model.json`);
    const modelData = JSON.parse(fs.readFileSync(modelPath, 'utf8'));
    const loaded = {
      vectorizer: new TfidfVectorizer(modelData),
      model: new LogisticRegression(modelData),
    };
    models[modelName] = loaded;
    return loaded;
  } catch (error) {
    // Keep the original error (and its stack) reachable for debugging.
    throw new Error(`Failed to load ${modelName} model: ${error.message}`, { cause: error });
  }
}
/**
 * Run a single named model against a prompt.
 *
 * @param {string} prompt - The text to check
 * @param {string} modelName - 'prompt_injection', 'jailbreak', or 'malicious'
 * @returns {Promise<{allowed: boolean, detected: boolean, prediction: *,
 *   confidence: *, probabilities: {safe: *, threat: *}}>}
 *   The fields are:
 *   - `allowed`: true when the model predicts class 0.
 *   - `detected`: true when the model predicts class 1.
 *   - `confidence`: the `positiveProb` value returned by `model.predictProba`.
 *   - `probabilities.safe` / `probabilities.threat`: entries 0 and 1 of its
 *     `probabilities`.
 * @throws Rejects if `prompt` is not a string or the model cannot be loaded.
 */
async function checkWithModel(prompt, modelName) {
  // An async function already converts a thrown error into a rejected
  // promise, so no explicit `new Promise` wrapper is needed.
  if (typeof prompt !== "string") {
    throw new Error("Prompt must be a string");
  }

  const { vectorizer, model } = loadModel(modelName);
  const features = vectorizer.transform(prompt);
  const prediction = model.predict(features);
  const { probabilities, positiveProb } = model.predictProba(features);

  return {
    allowed: prediction === 0,
    detected: prediction === 1,
    prediction,
    confidence: positiveProb,
    probabilities: {
      safe: probabilities[0],
      threat: probabilities[1],
    },
  };
}
/**
 * Classify a prompt with the 'prompt_injection' model.
 * @param {string} prompt - The text to check
 * @returns {Promise<object>} The per-model result produced by checkWithModel.
 */
export function checkInjection(prompt) {
  const modelName = 'prompt_injection';
  return checkWithModel(prompt, modelName);
}

/**
 * Classify a prompt with the 'jailbreak' model.
 * @param {string} prompt - The text to check
 * @returns {Promise<object>} The per-model result produced by checkWithModel.
 */
export function checkJailbreak(prompt) {
  const modelName = 'jailbreak';
  return checkWithModel(prompt, modelName);
}

/**
 * Classify a prompt with the 'malicious' model.
 * @param {string} prompt - The text to check
 * @returns {Promise<object>} The per-model result produced by checkWithModel.
 */
export function checkMalicious(prompt) {
  const modelName = 'malicious';
  return checkWithModel(prompt, modelName);
}
/**
 * Map the highest detected-threat confidence to a coarse risk label.
 * Thresholds are strict: > 0.7 is 'high', > 0.4 'medium', > 0 'low',
 * anything else (including 0 and NaN) 'safe'.
 * @param {number} maxThreat - Highest confidence among detected threats
 * @returns {'safe'|'low'|'medium'|'high'}
 */
function riskLevel(maxThreat) {
  if (maxThreat > 0.7) return 'high';
  if (maxThreat > 0.4) return 'medium';
  if (maxThreat > 0) return 'low';
  return 'safe';
}

/**
 * Run the injection, jailbreak, and malicious-content checks in parallel and
 * aggregate them into a single verdict.
 *
 * @param {string} prompt - The text to check
 * @returns {Promise<{injection: object, jailbreak: object, malicious: object,
 *   allowed: boolean, overallRisk: 'safe'|'low'|'medium'|'high',
 *   maxThreatConfidence: number, threatsDetected: string[]}>}
 *   `allowed` is true only if every check allows the prompt.
 *   `maxThreatConfidence` is the highest confidence among checks that
 *   detected a threat (0 if none did).
 * @throws Rejects with the first error raised by any underlying check.
 */
export async function checkAll(prompt) {
  // Promise.all rejects on the first failing check; that rejection propagates
  // to the caller unchanged, so no try/catch is needed here.
  const [injection, jailbreak, malicious] = await Promise.all([
    checkInjection(prompt),
    checkJailbreak(prompt),
    checkMalicious(prompt),
  ]);

  // One labelled table drives both the risk score and the detected list.
  const labelled = [
    ['injection', injection],
    ['jailbreak', jailbreak],
    ['malicious', malicious],
  ];

  // Undetected checks contribute 0 regardless of their reported confidence.
  const maxThreat = Math.max(
    ...labelled.map(([, result]) => (result.detected ? result.confidence : 0)),
  );

  const threatsDetected = labelled
    .filter(([, result]) => result.detected)
    .map(([label]) => label);

  return {
    injection,
    jailbreak,
    malicious,
    allowed: injection.allowed && jailbreak.allowed && malicious.allowed,
    overallRisk: riskLevel(maxThreat),
    maxThreatConfidence: maxThreat,
    threatsDetected,
  };
}
/**
 * Legacy single-check entry point, kept so existing callers keep working.
 * It runs only the prompt-injection model.
 * @param {string} prompt - The text to check
 * @returns {Promise<object>} Exactly what checkInjection(prompt) returns.
 * @deprecated Use checkInjection() instead for clarity
 */
export function check(prompt) {
  const injectionResult = checkInjection(prompt);
  return injectionResult;
}
// Default export bundling every public checker, for callers that import the
// module as a single object (e.g. `guard.checkAll(prompt)`).
const defaultApi = {
  check,
  checkInjection,
  checkJailbreak,
  checkMalicious,
  checkAll,
};

export default defaultApi;