Skip to content

Commit 8d16c52

Browse files
pytorchbot and Riandy authored
Adding support to demo prompt classification with Llama Guard (#5595)
Adding support to demo prompt classification with Llama Guard (#5553)

Summary:
Pull Request resolved: #5553

Adding support to load Llama Guard model and run prompt classification task

Reviewed By: cmodi-meta, kirklandsign

Differential Revision: D63148252

fbshipit-source-id: 482559e694da05bdec75b9a2dbd76163c686e47d
(cherry picked from commit 61cb5b0)

Co-authored-by: Riandy Riandy <[email protected]>
1 parent 7f8ea44 commit 8d16c52

File tree

3 files changed

+53
-0
lines changed

3 files changed

+53
-0
lines changed

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,16 @@ public void run() {
704704
startPos,
705705
MainActivity.this,
706706
false);
707+
} else if (mCurrentSettingsFields.getModelType() == ModelType.LLAMA_GUARD_3) {
708+
String llamaGuardPromptForClassification =
709+
PromptFormat.getFormattedLlamaGuardPrompt(rawPrompt);
710+
ETLogging.getInstance()
711+
.log("Running inference.. prompt=" + llamaGuardPromptForClassification);
712+
mModule.generate(
713+
llamaGuardPromptForClassification,
714+
llamaGuardPromptForClassification.length() + 64,
715+
MainActivity.this,
716+
false);
707717
} else {
708718
ETLogging.getInstance().log("Running inference.. prompt=" + finalPrompt);
709719
mModule.generate(

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@ public enum ModelType {
1212
LLAMA_3,
1313
LLAMA_3_1,
1414
LLAVA_1_5,
15+
LLAMA_GUARD_3,
1516
}

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public static String getUserPromptTemplate(ModelType modelType) {
3333
switch (modelType) {
3434
case LLAMA_3:
3535
case LLAMA_3_1:
36+
case LLAMA_GUARD_3:
3637
return "<|start_header_id|>user<|end_header_id|>\n"
3738
+ USER_PLACEHOLDER
3839
+ "<|eot_id|>"
@@ -60,6 +61,7 @@ public static String getStopToken(ModelType modelType) {
6061
switch (modelType) {
6162
case LLAMA_3:
6263
case LLAMA_3_1:
64+
case LLAMA_GUARD_3:
6365
return "<|eot_id|>";
6466
case LLAVA_1_5:
6567
return "</s>";
@@ -72,4 +74,44 @@ public static String getLlavaPresetPrompt() {
7274
return "A chat between a curious human and an artificial intelligence assistant. The assistant"
7375
+ " gives helpful, detailed, and polite answers to the human's questions. USER: ";
7476
}
77+
78+
/**
 * Builds the full Llama Guard classification prompt for a single user message.
 *
 * <p>The user message is first substituted for {@code USER_PLACEHOLDER} inside the Llama Guard
 * preset prompt; that result is then substituted for {@code USER_PLACEHOLDER} inside the user
 * prompt template obtained for {@code ModelType.LLAMA_GUARD_3}.
 *
 * @param userPrompt the raw user message to be classified
 * @return the preset prompt, with the user message embedded, wrapped in the user prompt template
 */
public static String getFormattedLlamaGuardPrompt(String userPrompt) {
79+
return getUserPromptTemplate(ModelType.LLAMA_GUARD_3)
80+
.replace(
81+
USER_PLACEHOLDER, getLlamaGuardPresetPrompt().replace(USER_PLACEHOLDER, userPrompt));
82+
}
83+
84+
/**
 * Returns the fixed Llama Guard task prompt used for prompt classification.
 *
 * <p>The text consists of a task instruction, a list of unsafe content categories labelled S1
 * through S14, a conversation section containing {@code USER_PLACEHOLDER} where the user message
 * is to be inserted, and instructions describing the expected 'safe' / 'unsafe' answer format.
 *
 * <p>NOTE(review): the wording "according our safety policy" appears to mirror the published
 * Llama Guard prompt text; confirm against the model card before changing it.
 *
 * @return the preset prompt with {@code USER_PLACEHOLDER} still unsubstituted
 */
public static String getLlamaGuardPresetPrompt() {
85+
return "Task: Check if there is unsafe content in 'User' messages in conversations"
86+
+ " according our safety policy with the below categories.\n"
87+
+ "\n"
88+
+ "<BEGIN UNSAFE CONTENT CATEGORIES>\n"
89+
+ "S1: Violent Crimes. \n"
90+
+ "S2: Non-Violent Crimes. \n"
91+
+ "S3: Sex Crimes. \n"
92+
+ "S4: Child Exploitation. \n"
93+
+ "S5: Defamation. \n"
94+
+ "S6: Specialized Advice. \n"
95+
+ "S7: Privacy. \n"
96+
+ "S8: Intellectual Property. \n"
97+
+ "S9: Indiscriminate Weapons. \n"
98+
+ "S10: Hate. \n"
99+
+ "S11: Self-Harm. \n"
100+
+ "S12: Sexual Content. \n"
101+
+ "S13: Elections. \n"
102+
+ "S14: Code Interpreter Abuse.\n"
103+
+ "<END UNSAFE CONTENT CATEGORIES>\n"
104+
+ "\n"
105+
+ "<BEGIN CONVERSATION>\n"
106+
+ "\n"
107+
+ "User: "
108+
+ USER_PLACEHOLDER
109+
+ "\n"
110+
+ "<END CONVERSATION>\n"
111+
+ "Provide your safety assessment for ONLY THE LAST User message in the above"
112+
+ " conversation:\n"
113+
+ " - First line must read 'safe' or 'unsafe'.\n"
114+
+ " - If unsafe, a second line must include a comma-separated list of violated"
115+
+ " categories.";
116+
}
75117
}

0 commit comments

Comments
 (0)