Skip to content

Commit 8f7a91a

Browse files
committed
Updates for retention score
Signed-off-by: Beat Buesser <[email protected]>
1 parent 508a0a4 commit 8f7a91a

File tree

4 files changed

+121
-114
lines changed

4 files changed

+121
-114
lines changed

art/evaluations/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from art.evaluations.great_score.great_score import GreatScorePyTorch
2+
from art.evaluations.retention_score.retention_image_score import get_retention_score_image
3+
from art.evaluations.retention_score.retention_text_score import get_retention_score_text
4+
from art.evaluations.security_curve.security_curve import SecurityCurve

art/evaluations/retention_score/retention_image_score.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,26 @@
1717
# SOFTWARE.
1818
import json
1919
import os
20-
from metric.perspective_api import PerspectiveApiScorer
21-
from metric.detoxify import DetoxifyScorer
20+
21+
# from metric.perspective_api import PerspectiveApiScorer
22+
# from metric.detoxify import DetoxifyScorer
2223
import torch
23-
from minigpt4.common.config import Config
24-
from minigpt4.common.registry import registry
25-
from minigpt4.conversation.conversation import Chat, CONV_VISION
24+
# from minigpt4.common.config import Config
25+
# from minigpt4.common.registry import registry
26+
# from minigpt4.conversation.conversation import Chat, CONV_VISION
2627
from PIL import Image
2728

2829

29-
# 假设这是调用minigpt模型进行图像评价的函数
30+
# Assume this is a function for image evaluation using the MiniGPT model
3031
def minigpt_image_eval(content):
3132
"""
32-
模拟minigpt对图像相关内容进行评价的函数
33+
Simulates a function for image evaluation using the MiniGPT model
3334
34-
:param content: 输入的图像相关内容
35-
:return: 评价结果
35+
:param content: Input image-related content
36+
:return: Evaluation result
3637
"""
37-
# 这里应该是实际调用minigpt模型的代码
38-
# 为了示例,简单返回一个字符串
38+
# This should be the actual code calling the MiniGPT model
39+
# For demonstration purposes, simply return a string
3940
return "Image evaluation result"
4041

4142

@@ -58,7 +59,7 @@ def initialize_minigpt(cfg_path, gpu_id=0):
5859
return Chat(model, vis_processor, device=f"cuda:{gpu_id}")
5960

6061

61-
def get_image_score(question_file, answer_list_files, rule_file, output_file, context_file, cfg_path, max_tokens=1024):
62+
def get_retention_score_image(question_file, answer_list_files, rule_file, output_file, context_file, cfg_path, max_tokens=1024):
6263
"""
6364
Get image scores function
6465

art/evaluations/retention_score/retention_text_score.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,19 @@
1717
# SOFTWARE.
1818
import json
1919
import os
20+
2021
import torch
21-
from minigpt4.common.config import Config
22-
from minigpt4.common.registry import registry
23-
from minigpt4.conversation.conversation import Chat, CONV_VISION
24-
from transformers import AutoTokenizer, AutoModelForCausalLM
22+
# from minigpt4.common.config import Config
23+
# from minigpt4.common.registry import registry
24+
# from minigpt4.conversation.conversation import Chat, CONV_VISION
25+
# from transformers import AutoTokenizer, AutoModelForCausalLM
2526
from tqdm import tqdm
2627

2728

2829
def initialize_minigpt(cfg_path, gpu_id=0):
2930
"""
3031
Initialize MiniGPT model
32+
3133
:param cfg_path: Path to configuration file
3234
:param gpu_id: GPU device ID
3335
:return: Initialized Chat model
@@ -58,6 +60,7 @@ def initialize_judge_model(model_path="/Llama-2-70b-chat-hf"):
5860
def extract_content(tag, text):
5961
"""
6062
Extract content from judge response
63+
6164
:param tag: Tag to search for
6265
:param text: Text to search in
6366
:return: Extracted content
@@ -113,7 +116,7 @@ def judge_response(judge_model, tokenizer, response, prefix="<s>[INST] %s[/INST]
113116
return None
114117

115118

116-
def get_text_score(question_file, answer_list_files, rule_file, output_file, cfg_path, gpu_id=0, max_tokens=1024):
119+
def get_retention_score_text(question_file, answer_list_files, rule_file, output_file, cfg_path, gpu_id=0, max_tokens=1024):
117120
"""
118121
Get text scores function
119122
Lines changed: 96 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,37 @@
1+
# MIT License
2+
#
3+
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2025
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
6+
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
7+
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
8+
# persons to whom the Software is furnished to do so, subject to the following conditions:
9+
#
10+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
11+
# Software.
12+
#
13+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
14+
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
15+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
16+
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17+
# SOFTWARE.
18+
119
import json
220
import os
3-
import torch
21+
422
import numpy as np
523
import math
6-
from retention_image_score import get_image_score
7-
from retention_text_score import get_text_score
24+
25+
import pytest
26+
27+
from art.evaluations import get_retention_score_image, get_retention_score_text
28+
829

930

1031
def generate_synthetic_images(prompt, num_images=1):
1132
"""
1233
Placeholder function for generating synthetic images using Stable Diffusion
34+
1335
:param prompt: Image generation prompt
1436
:param num_images: Number of images to generate
1537
:return: List of generated image paths
@@ -22,6 +44,7 @@ def generate_synthetic_images(prompt, num_images=1):
2244
def paraphrase_text(prompts):
2345
"""
2446
Placeholder function for text paraphrasing using DiffuseQ
47+
2548
:param prompts: List of original prompts
2649
:return: List of paraphrased prompts
2750
"""
@@ -33,6 +56,7 @@ def paraphrase_text(prompts):
3356
def calculate_retention_score(output_file):
3457
"""
3558
Calculate the final retention score
59+
3660
:param output_file: Path to output file
3761
:return: retention score and standard deviation
3862
"""
@@ -62,18 +86,11 @@ def calculate_retention_score(output_file):
6286
return retention_score, score_std
6387

6488

65-
def get_score(
66-
question_file,
67-
answer_list_files,
68-
rule_file,
69-
output_file,
70-
mode="text",
71-
cfg_path=None,
72-
context_file=None,
73-
max_tokens=1024,
89+
def test_get_score_text(
7490
):
7591
"""
7692
General function for getting scores
93+
7794
:param question_file: Path to question file
7895
:param answer_list_files: List of paths to answer files
7996
:param rule_file: Path to rule file
@@ -84,46 +101,32 @@ def get_score(
84101
:param max_tokens: Maximum number of tokens
85102
:return: (retention_score, score_std)
86103
"""
104+
question_file = None
105+
answer_list_files = None
106+
rule_file = None
107+
output_file = None
108+
cfg_path = None
109+
max_tokens = 1024
110+
87111
if not cfg_path:
88112
cfg_path = "path/to/minigpt4_config.yaml"
89113

90114
# Read original prompts from question file
91115
with open(os.path.expanduser(question_file)) as f:
92116
questions = [json.loads(line)["text"] for line in f]
93117

94-
if mode == "image":
95-
# Note: Should call actual Stable Diffusion model for image generation
96-
synthetic_images = []
97-
for prompt in questions:
98-
images = generate_synthetic_images(prompt)
99-
synthetic_images.extend(images)
100-
101-
get_image_score(
102-
question_file=question_file,
103-
answer_list_files=answer_list_files,
104-
rule_file=rule_file,
105-
output_file=output_file,
106-
context_file=context_file,
107-
cfg_path=cfg_path,
108-
max_tokens=max_tokens,
109-
)
110-
111-
elif mode == "text":
112-
# Note: Should call actual DiffuseQ model for text paraphrasing
113-
paraphrased_prompts = paraphrase_text(questions)
114-
115-
get_text_score(
116-
question_file=question_file,
117-
answer_list_files=answer_list_files,
118-
rule_file=rule_file,
119-
output_file=output_file,
120-
cfg_path=cfg_path,
121-
gpu_id=0,
122-
max_tokens=max_tokens,
123-
)
124-
125-
else:
126-
raise ValueError("Mode must be either 'text' or 'image'")
118+
# Note: Should call actual DiffuseQ model for text paraphrasing
119+
paraphrased_prompts = paraphrase_text(questions)
120+
121+
get_retention_score_text(
122+
question_file=question_file,
123+
answer_list_files=answer_list_files,
124+
rule_file=rule_file,
125+
output_file=output_file,
126+
cfg_path=cfg_path,
127+
gpu_id=0,
128+
max_tokens=max_tokens,
129+
)
127130

128131
# Calculate final retention score
129132
retention_score, score_std = calculate_retention_score(output_file)
@@ -132,59 +135,55 @@ def get_score(
132135
return retention_score, score_std
133136

134137

135-
def parse_score(review):
138+
139+
def test_get_score_image(
140+
):
136141
"""
137-
Parse scores into float list, return [-1, -1] if parsing fails
138-
:param review: Review text
139-
:return: Score list
142+
General function for getting scores
143+
144+
:param question_file: Path to question file
145+
:param answer_list_files: List of paths to answer files
146+
:param rule_file: Path to rule file
147+
:param output_file: Path to output file
148+
:param mode: Evaluation mode ('text' or 'image')
149+
:param cfg_path: Path to MiniGPT config file
150+
:param context_file: Path to context file (for image mode)
151+
:param max_tokens: Maximum number of tokens
152+
:return: (retention_score, score_std)
140153
"""
141-
try:
142-
score_pair = review.split("\n")[0]
143-
score_pair = score_pair.replace(",", " ")
144-
sp = score_pair.split(" ")
145-
if len(sp) == 2:
146-
return [float(sp[0]), float(sp[1])]
147-
else:
148-
print("error", review)
149-
return [-1, -1]
150-
except Exception as e:
151-
print(e)
152-
print("error", review)
153-
return [-1, -1]
154-
155-
156-
if __name__ == "__main__":
157-
import argparse
158-
159-
parser = argparse.ArgumentParser(description="Calculate retention score for text or image")
160-
parser.add_argument("--question_file", type=str, required=True, help="Path to question file")
161-
parser.add_argument("--answer_files", nargs="+", required=True, help="Paths to answer files")
162-
parser.add_argument("--rule_file", type=str, required=True, help="Path to rule file")
163-
parser.add_argument("--output_file", type=str, required=True, help="Path to output file")
164-
parser.add_argument("--mode", type=str, choices=["text", "image"], default="text", help="Evaluation mode")
165-
parser.add_argument("--cfg_path", type=str, help="Path to MiniGPT config file")
166-
parser.add_argument("--context_file", type=str, help="Path to context file (required for image mode)")
167-
parser.add_argument("--max_tokens", type=int, default=1024, help="Maximum number of tokens")
168-
169-
args = parser.parse_args()
170-
171-
# Validate context_file requirement for image mode
172-
if args.mode == "image" and not args.context_file:
173-
parser.error("--context_file is required when mode is 'image'")
174-
175-
# Calculate retention score
176-
retention_score, std = get_score(
177-
question_file=args.question_file,
178-
answer_list_files=args.answer_files,
179-
rule_file=args.rule_file,
180-
output_file=args.output_file,
181-
mode=args.mode,
182-
cfg_path=args.cfg_path,
183-
context_file=args.context_file,
184-
max_tokens=args.max_tokens,
154+
question_file = None
155+
answer_list_files = None
156+
rule_file = None
157+
output_file = None
158+
cfg_path = None
159+
context_file = None
160+
max_tokens = 1024
161+
162+
if not cfg_path:
163+
cfg_path = "path/to/minigpt4_config.yaml"
164+
165+
# Read original prompts from question file
166+
with open(os.path.expanduser(question_file)) as f:
167+
questions = [json.loads(line)["text"] for line in f]
168+
169+
# Note: Should call actual Stable Diffusion model for image generation
170+
synthetic_images = []
171+
for prompt in questions:
172+
images = generate_synthetic_images(prompt)
173+
synthetic_images.extend(images)
174+
175+
get_retention_score_image(
176+
question_file=question_file,
177+
answer_list_files=answer_list_files,
178+
rule_file=rule_file,
179+
output_file=output_file,
180+
context_file=context_file,
181+
cfg_path=cfg_path,
182+
max_tokens=max_tokens,
185183
)
186184

187-
print(f"\nFinal Results:")
188-
print(f"Mode: {args.mode}")
189-
print(f"Retention Score: {retention_score:.4f}")
190-
print(f"Standard Deviation: {std:.4f}")
185+
# Calculate final retention score
186+
retention_score, score_std = calculate_retention_score(output_file)
187+
print(f"Retention Score: {retention_score:.4f} (std: {score_std:.4f})")
188+
189+
return retention_score, score_std

0 commit comments

Comments
 (0)