+# MIT License
+#
+# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2025
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
+# persons to whom the Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
+# Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
+# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
 import json
 import os
-import torch
+
 import numpy as np
 import math
-from retention_image_score import get_image_score
-from retention_text_score import get_text_score
+
+import pytest
+
+from art.evaluations import get_retention_score_image, get_retention_score_text
+


 def generate_synthetic_images(prompt, num_images=1):
     """
     Placeholder function for generating synthetic images using Stable Diffusion
+
     :param prompt: Image generation prompt
     :param num_images: Number of images to generate
     :return: List of generated image paths
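(The hunk above only shows the placeholder's docstring. As a hedged illustration of what a concrete implementation could look like, the sketch below uses the `diffusers` StableDiffusionPipeline; the checkpoint id, output directory, and helper name are assumptions and are not part of this PR.)

# Illustrative sketch only -- not part of this diff; assumes the `diffusers` package is installed.
import os
import torch
from diffusers import StableDiffusionPipeline

def generate_synthetic_images_sketch(prompt, num_images=1, out_dir="generated_images"):
    """Generate `num_images` images for `prompt` and return their file paths."""
    # "runwayml/stable-diffusion-v1-5" is an assumed public checkpoint id.
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    # One call with the prompt repeated yields `num_images` PIL images.
    for i, image in enumerate(pipe([prompt] * num_images).images):
        path = os.path.join(out_dir, f"synthetic_{i}.png")
        image.save(path)
        paths.append(path)
    return paths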
@@ -22,6 +44,7 @@ def generate_synthetic_images(prompt, num_images=1):
 def paraphrase_text(prompts):
     """
     Placeholder function for text paraphrasing using DiffuseQ
+
     :param prompts: List of original prompts
     :return: List of paraphrased prompts
     """
@@ -33,6 +56,7 @@ def paraphrase_text(prompts):
 def calculate_retention_score(output_file):
     """
     Calculate the final retention score
+
     :param output_file: Path to output file
     :return: retention score and standard deviation
     """
@@ -62,18 +86,11 @@ def calculate_retention_score(output_file):
     return retention_score, score_std


-def get_score(
-    question_file,
-    answer_list_files,
-    rule_file,
-    output_file,
-    mode="text",
-    cfg_path=None,
-    context_file=None,
-    max_tokens=1024,
+def test_get_score_text(
 ):
     """
     General function for getting scores
+
     :param question_file: Path to question file
     :param answer_list_files: List of paths to answer files
     :param rule_file: Path to rule file
@@ -84,46 +101,32 @@ def get_score(
     :param max_tokens: Maximum number of tokens
     :return: (retention_score, score_std)
     """
+    question_file = None
+    answer_list_files = None
+    rule_file = None
+    output_file = None
+    cfg_path = None
+    max_tokens = 1024
+
     if not cfg_path:
         cfg_path = "path/to/minigpt4_config.yaml"

     # Read original prompts from question file
     with open(os.path.expanduser(question_file)) as f:
         questions = [json.loads(line)["text"] for line in f]

-    if mode == "image":
-        # Note: Should call actual Stable Diffusion model for image generation
-        synthetic_images = []
-        for prompt in questions:
-            images = generate_synthetic_images(prompt)
-            synthetic_images.extend(images)
-
-        get_image_score(
-            question_file=question_file,
-            answer_list_files=answer_list_files,
-            rule_file=rule_file,
-            output_file=output_file,
-            context_file=context_file,
-            cfg_path=cfg_path,
-            max_tokens=max_tokens,
-        )
-
-    elif mode == "text":
-        # Note: Should call actual DiffuseQ model for text paraphrasing
-        paraphrased_prompts = paraphrase_text(questions)
-
-        get_text_score(
-            question_file=question_file,
-            answer_list_files=answer_list_files,
-            rule_file=rule_file,
-            output_file=output_file,
-            cfg_path=cfg_path,
-            gpu_id=0,
-            max_tokens=max_tokens,
-        )
-
-    else:
-        raise ValueError("Mode must be either 'text' or 'image'")
+    # Note: Should call actual DiffuseQ model for text paraphrasing
+    paraphrased_prompts = paraphrase_text(questions)
+
+    get_retention_score_text(
+        question_file=question_file,
+        answer_list_files=answer_list_files,
+        rule_file=rule_file,
+        output_file=output_file,
+        cfg_path=cfg_path,
+        gpu_id=0,
+        max_tokens=max_tokens,
+    )

     # Calculate final retention score
     retention_score, score_std = calculate_retention_score(output_file)
@@ -132,59 +135,55 @@ def get_score(
     return retention_score, score_std


-def parse_score(review):
+
+def test_get_score_image(
+):
     """
-    Parse scores into float list, return [-1, -1] if parsing fails
-    :param review: Review text
-    :return: Score list
+    General function for getting scores
+
+    :param question_file: Path to question file
+    :param answer_list_files: List of paths to answer files
+    :param rule_file: Path to rule file
+    :param output_file: Path to output file
+    :param mode: Evaluation mode ('text' or 'image')
+    :param cfg_path: Path to MiniGPT config file
+    :param context_file: Path to context file (for image mode)
+    :param max_tokens: Maximum number of tokens
+    :return: (retention_score, score_std)
     """
-    try:
-        score_pair = review.split("\n")[0]
-        score_pair = score_pair.replace(",", " ")
-        sp = score_pair.split(" ")
-        if len(sp) == 2:
-            return [float(sp[0]), float(sp[1])]
-        else:
-            print("error", review)
-            return [-1, -1]
-    except Exception as e:
-        print(e)
-        print("error", review)
-        return [-1, -1]
-
-
-if __name__ == "__main__":
-    import argparse
-
-    parser = argparse.ArgumentParser(description="Calculate retention score for text or image")
-    parser.add_argument("--question_file", type=str, required=True, help="Path to question file")
-    parser.add_argument("--answer_files", nargs="+", required=True, help="Paths to answer files")
-    parser.add_argument("--rule_file", type=str, required=True, help="Path to rule file")
-    parser.add_argument("--output_file", type=str, required=True, help="Path to output file")
-    parser.add_argument("--mode", type=str, choices=["text", "image"], default="text", help="Evaluation mode")
-    parser.add_argument("--cfg_path", type=str, help="Path to MiniGPT config file")
-    parser.add_argument("--context_file", type=str, help="Path to context file (required for image mode)")
-    parser.add_argument("--max_tokens", type=int, default=1024, help="Maximum number of tokens")
-
-    args = parser.parse_args()
-
-    # Validate context_file requirement for image mode
-    if args.mode == "image" and not args.context_file:
-        parser.error("--context_file is required when mode is 'image'")
-
-    # Calculate retention score
-    retention_score, std = get_score(
-        question_file=args.question_file,
-        answer_list_files=args.answer_files,
-        rule_file=args.rule_file,
-        output_file=args.output_file,
-        mode=args.mode,
-        cfg_path=args.cfg_path,
-        context_file=args.context_file,
-        max_tokens=args.max_tokens,
+    question_file = None
+    answer_list_files = None
+    rule_file = None
+    output_file = None
+    cfg_path = None
+    context_file = None
+    max_tokens = 1024
+
+    if not cfg_path:
+        cfg_path = "path/to/minigpt4_config.yaml"
+
+    # Read original prompts from question file
+    with open(os.path.expanduser(question_file)) as f:
+        questions = [json.loads(line)["text"] for line in f]
+
+    # Note: Should call actual Stable Diffusion model for image generation
+    synthetic_images = []
+    for prompt in questions:
+        images = generate_synthetic_images(prompt)
+        synthetic_images.extend(images)
+
+    get_retention_score_image(
+        question_file=question_file,
+        answer_list_files=answer_list_files,
+        rule_file=rule_file,
+        output_file=output_file,
+        context_file=context_file,
+        cfg_path=cfg_path,
+        max_tokens=max_tokens,
     )

-    print(f"\nFinal Results:")
-    print(f"Mode: {args.mode}")
-    print(f"Retention Score: {retention_score:.4f}")
-    print(f"Standard Deviation: {std:.4f}")
+    # Calculate final retention score
+    retention_score, score_std = calculate_retention_score(output_file)
+    print(f"Retention Score: {retention_score:.4f} (std: {score_std:.4f})")
+
+    return retention_score, score_std
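(As written, both tests leave question_file and the other inputs as None, so open(os.path.expanduser(question_file)) would raise a TypeError. A hedged sketch of how the placeholders could be guarded so pytest skips the test until real files are supplied; the guard and its name are illustrative, not part of this PR.)

# Illustrative sketch only -- not part of this diff.
import pytest

def test_get_score_text_guarded():
    question_file = None  # point this at a real JSON-lines question file to enable the test
    if question_file is None:
        pytest.skip("retention-score inputs are not configured")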