Skip to content

Commit 0b954e9

Browse files
Merge branch 'develop' into feat/progress_bar
2 parents d8224a9 + 093369c commit 0b954e9

File tree

2 files changed

+139
-1
lines changed

2 files changed

+139
-1
lines changed

.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ jobs:
5050
haskell: true
5151
large-packages: true
5252
docker-images: true
53-
swap-storage: true
53+
swap-storage: true
5454

5555
- name: Set up Python
5656
uses: actions/setup-python@v6.0.0

tests/test_adversarial.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# Copyright 2024 CVS Health and/or one of its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import json
17+
import os
18+
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
19+
20+
import pytest
21+
22+
from langfair.generator.redteaming import (
23+
INSTRUCTION_DICT,
24+
AdversarialGenerator,
25+
)
26+
27+
28+
@pytest.fixture
def mock_llm():
    """Provide a stand-in LLM object exposing only a `temperature` attribute."""
    fake_model = MagicMock()
    fake_model.temperature = 0.7
    return fake_model
33+
34+
35+
@pytest.fixture
def generator(mock_llm):
    """Build an AdversarialGenerator wired to the mocked LLM, forcing count=1."""
    adv_gen = AdversarialGenerator(langchain_llm=mock_llm)
    adv_gen.count = 1
    return adv_gen
40+
41+
42+
@pytest.mark.asyncio
async def test_generate_from_template_valid(generator):
    """A valid style yields one `<style>_response` list, sized like the input prompts."""
    canned_reply = {"data": {"response": ["Alice is a teacher", "Bob is a doctor"]}}
    generator.generate_responses = AsyncMock(return_value=canned_reply)
    template_prompts = {"text": ["Alice is a", "Bob is a"]}

    result = await generator._generate_from_template(
        template_prompts, ["benign"], count=1
    )

    assert "benign_response" in result
    assert len(result["benign_response"]) == 2
51+
52+
53+
@pytest.mark.asyncio
async def test_generate_from_template_invalid_style(generator):
    """An unrecognized response style must be rejected with ValueError."""
    unknown_styles = ["invalid"]
    with pytest.raises(ValueError):
        await generator._generate_from_template(
            {"text": ["test"]}, unknown_styles, count=1
        )
59+
60+
61+
def test_format_result_structure(generator):
    """_format_result nests prompts under `data` and per-style non-completion
    rates ("Unable to get response" fraction) under `metadata`."""
    responses = {
        "benign_response": ["ok", "Unable to get response"],
        "adversarial_response": ["Unable to get response", "Unable to get response"],
    }
    templates = {
        "text": ["X is Y", "Z is W"],
        "group": ["A", "B"],
        "group_category": ["Race", "Gender"],
    }
    wanted_keys = ["text", "group", "group_category"]

    formatted = generator._format_result(responses, templates, keys=wanted_keys)

    assert "data" in formatted
    assert "metadata" in formatted
    assert formatted["data"]["prompt"] == ["X is Y", "Z is W"]
    # 1 of 2 benign responses failed; both adversarial responses failed.
    assert formatted["metadata"]["benign_response_non_completion_rate"] == 0.5
    assert formatted["metadata"]["adversarial_response_non_completion_rate"] == 1.0
79+
80+
81+
def test_read_counterfactual_data_valid():
    """Bundled counterfactual JSON is parsed into parallel column lists."""
    records = [
        {
            "text": "X is Y",
            "group": "Homosexuals",
            "group_category": "Sexual orientation",
        }
    ]
    pkg_stub = MagicMock(__file__=os.path.join("langfair", "__init__.py"))

    with (
        patch("builtins.open", mock_open(read_data=json.dumps(records))),
        patch("pkgutil.resolve_name", return_value=pkg_stub),
    ):
        columns = AdversarialGenerator._read_counterfactual_data(
            ["Sexual orientation"]
        )
        assert "text" in columns
        assert columns["group"] == ["Homosexuals"]
        assert columns["group_category"] == ["Sexual orientation"]
102+
103+
104+
def test_read_counterfactual_data_invalid():
    """Asking for a group category that does not exist must raise ValueError."""
    bogus_categories = ["InvalidGroup"]
    with pytest.raises(ValueError):
        AdversarialGenerator._read_counterfactual_data(bogus_categories)
107+
108+
109+
def test_read_toxicity_data_toxic():
    """A challenging, high-toxicity prompt is sampled and prefixed with the
    requested instruction from INSTRUCTION_DICT."""
    jsonl_payload = json.dumps(
        {"prompt": {"text": "test", "toxicity": 0.9}, "challenging": True}
    )
    pkg_stub = MagicMock(__file__=os.path.join("langfair", "__init__.py"))

    with (
        patch("builtins.open", mock_open(read_data=jsonl_payload)),
        patch("pkgutil.resolve_name", return_value=pkg_stub),
        patch("random.sample", return_value=["test"]),
    ):
        sampled = AdversarialGenerator._read_toxicity_data("toxic", "benign", 1, 0.1)
        assert sampled[0].startswith(INSTRUCTION_DICT["benign"])
123+
124+
125+
def test_read_toxicity_data_nontoxic():
    """A non-challenging, low-toxicity prompt is sampled and prefixed with the
    requested instruction from INSTRUCTION_DICT."""
    jsonl_payload = json.dumps(
        {"prompt": {"text": "test", "toxicity": 0.05}, "challenging": False}
    )
    pkg_stub = MagicMock(__file__=os.path.join("langfair", "__init__.py"))

    with (
        patch("builtins.open", mock_open(read_data=jsonl_payload)),
        patch("pkgutil.resolve_name", return_value=pkg_stub),
        patch("random.sample", return_value=["test"]),
    ):
        sampled = AdversarialGenerator._read_toxicity_data(
            "nontoxic", "benign", 1, 0.1
        )
        assert sampled[0].startswith(INSTRUCTION_DICT["benign"])

0 commit comments

Comments
 (0)