1
1
import argparse
2
2
import asyncio
3
+ import datetime
3
4
import logging
4
5
import os
5
6
import pathlib
6
7
import sys
8
+ from typing import Optional
7
9
8
10
import requests
9
11
from azure .ai .evaluation import AzureAIProject
@@ -52,7 +54,7 @@ async def callback(
52
54
return {"messages" : messages + [message ]}
53
55
54
56
55
- async def run_simulator (target_url : str , max_simulations : int ):
57
+ async def run_simulator (target_url : str , max_simulations : int , scan_name : Optional [ str ] = None ):
56
58
credential = get_azure_credential ()
57
59
azure_ai_project : AzureAIProject = {
58
60
"subscription_id" : os .getenv ("AZURE_SUBSCRIPTION_ID" ),
@@ -64,26 +66,25 @@ async def run_simulator(target_url: str, max_simulations: int):
64
66
credential = credential ,
65
67
risk_categories = [
66
68
RiskCategory .Violence ,
67
- # RiskCategory.HateUnfairness,
68
- # RiskCategory.Sexual,
69
- # RiskCategory.SelfHarm,
69
+ RiskCategory .HateUnfairness ,
70
+ RiskCategory .Sexual ,
71
+ RiskCategory .SelfHarm ,
70
72
],
71
73
num_objectives = 1 ,
72
74
)
75
+ if scan_name is None :
76
+ timestamp = datetime .datetime .now ().strftime ("%Y-%m-%d_%H-%M-%S" )
77
+ scan_name = f"Safety evaluation { timestamp } "
73
78
await model_red_team .scan (
74
79
target = lambda messages , stream = False , session_state = None , context = None : callback (messages , target_url ),
75
- scan_name = "Advanced-Callback-Scan" ,
80
+ scan_name = scan_name ,
76
81
attack_strategies = [
77
- AttackStrategy .EASY , # Group of easy complexity attacks
78
- # AttackStrategy.MODERATE, # Group of moderate complexity attacks
79
- # AttackStrategy.CharacterSpace, # Add character spaces
80
- # AttackStrategy.ROT13, # Use ROT13 encoding
81
- # AttackStrategy.UnicodeConfusable, # Use confusable Unicode characters
82
- # AttackStrategy.CharSwap, # Swap characters in prompts
83
- # AttackStrategy.Morse, # Encode prompts in Morse code
84
- # AttackStrategy.Leetspeak, # Use Leetspeak
85
- # AttackStrategy.Url, # Use URLs in prompts
86
- # AttackStrategy.Binary, # Encode prompts in binary
82
+ AttackStrategy .DIFFICULT ,
83
+ AttackStrategy .Baseline ,
84
+ AttackStrategy .UnicodeConfusable , # Use confusable Unicode characters
85
+ AttackStrategy .Morse , # Encode prompts in Morse code
86
+ AttackStrategy .Leetspeak , # Use Leetspeak
87
+ AttackStrategy .Url , # Use URLs in prompts
87
88
],
88
89
output_path = "Advanced-Callback-Scan.json" ,
89
90
)
@@ -97,28 +98,29 @@ async def run_simulator(target_url: str, max_simulations: int):
97
98
parser .add_argument (
98
99
"--max_simulations" , type = int , default = 200 , help = "Maximum number of simulations (question/response pairs)."
99
100
)
101
+ # argument for the name
102
+ parser .add_argument ("--scan_name" , type = str , default = None , help = "Name of the safety evaluation (optional)." )
100
103
args = parser .parse_args ()
101
104
102
105
# Configure logging to show tracebacks for warnings and above
103
106
logging .basicConfig (
104
- level = logging .DEBUG ,
107
+ level = logging .WARNING ,
105
108
format = "%(message)s" ,
106
109
datefmt = "[%X]" ,
107
110
handlers = [RichHandler (rich_tracebacks = False , show_path = True )],
108
111
)
109
112
110
113
# Set urllib3 and azure libraries to WARNING level to see connection issues
111
114
logging .getLogger ("urllib3" ).setLevel (logging .WARNING )
112
- logging .getLogger ("azure" ).setLevel (logging .DEBUG )
113
- logging .getLogger ("RedTeamLogger" ).setLevel (logging .DEBUG )
115
+ logging .getLogger ("azure" ).setLevel (logging .WARNING )
114
116
115
117
# Set our application logger to INFO level
116
118
logger .setLevel (logging .INFO )
117
119
118
120
load_azd_env ()
119
121
120
122
try :
121
- asyncio .run (run_simulator (args .target_url , args .max_simulations ))
123
+ asyncio .run (run_simulator (args .target_url , args .max_simulations , args . scan_name ))
122
124
except Exception :
123
125
logging .exception ("Unhandled exception in safety evaluation" )
124
126
sys .exit (1 )
0 commit comments