Skip to content

Commit 43f6769

Browse files
committed
Add Attack Type
1 parent 7e1ae2c commit 43f6769

File tree

4 files changed

+151
-24
lines changed

4 files changed

+151
-24
lines changed

_types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@ class Focus(str, Enum):
1111
Safety = "Safety"
1212
Other = "Other"
1313

14+
class AttackType(str, Enum):
15+
ModelEvasion = "Evasion"
16+
ModelExtraction = "Extraction"
17+
ModelInversion = "Inversion"
18+
ModelPoisoning = "Poisoning"
19+
PromptInjection = "Prompt Injection"
20+
Other = "Other"
21+
1422

1523
@dataclass
1624
class Paper:
@@ -21,6 +29,7 @@ class Paper:
2129
title: str | None = None
2230
url: str | None = None
2331
focus: Focus | None = None
32+
attack_type: AttackType | None = None
2433
summary: str | None = None
2534
abstract: str | None = None
2635
authors: list[str] = field(default_factory=list)

notion_utils.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from notion_client.helpers import async_collect_paginated_api
77
from tqdm import tqdm # type: ignore
88

9-
from _types import Paper, Focus
9+
from _types import AttackType, Paper, Focus
1010

1111
NotionClient = AsyncClient
1212

@@ -15,10 +15,14 @@ def get_notion_client(token: str) -> NotionClient:
1515
return NotionClient(auth=token)
1616

1717

18-
async def get_papers_from_notion(client: NotionClient, database_id: str) -> list[Paper]:
19-
results = await async_collect_paginated_api(
20-
client.databases.query, database_id=database_id
21-
)
18+
async def get_papers_from_notion(client: NotionClient, database_id: str, *, max: int | None = None) -> list[Paper]:
19+
if max:
20+
results = await client.databases.query(database_id=database_id, page_size=max)
21+
results = results['results']
22+
else:
23+
results = await async_collect_paginated_api(
24+
client.databases.query, database_id=database_id
25+
)
2226

2327
papers: list[Paper] = []
2428
for result in results:
@@ -35,6 +39,8 @@ async def get_papers_from_notion(client: NotionClient, database_id: str) -> list
3539
published = datetime.fromisoformat(published["start"]) if published else None
3640
focus = properties["Focus"]["select"]
3741
focus = Focus(focus["name"]) if focus else None
42+
attack_type = properties["Attack Type"]["select"]
43+
attack_type = AttackType(attack_type["name"]) if attack_type else None
3844
explored = properties["Explored"]["checkbox"]
3945

4046
if not any([url, title]):
@@ -46,6 +52,7 @@ async def get_papers_from_notion(client: NotionClient, database_id: str) -> list
4652
title=title,
4753
url=url,
4854
focus=focus,
55+
attack_type=attack_type,
4956
summary=summary,
5057
authors=authors,
5158
published=published,
@@ -62,23 +69,25 @@ async def write_papers_to_notion(
6269
) -> None:
6370
for paper in tqdm(papers):
6471
properties: dict[str, t.Any] = {}
65-
if paper.title:
72+
if paper.title and paper._original_state["title"] != paper.title:
6673
properties["Title"] = {"title": [{"text": {"content": paper.title}}]}
67-
if paper.url:
74+
if paper.url and paper._original_state["url"] != paper.url:
6875
properties["URL"] = {"url": paper.url}
69-
if paper.summary:
76+
if paper.summary and paper._original_state["summary"] != paper.summary:
7077
properties["Summary"] = {
7178
"rich_text": [{"text": {"content": paper.summary}}]
7279
}
73-
if paper.authors:
80+
if paper.authors and paper._original_state["authors"] != paper.authors:
7481
properties["Authors"] = {
75-
"multi_select": [{"name": author} for author in paper.authors]
82+
"multi_select": [{"name": author} for author in paper.authors[:5]] # Limit to 5 authors
7683
}
77-
if paper.published:
84+
if paper.published and paper._original_state["published"] != paper.published:
7885
properties["Published"] = {"date": {"start": paper.published.isoformat()}}
79-
if paper.focus:
86+
if paper.focus and paper._original_state["focus"] != paper.focus:
8087
properties["Focus"] = {"select": {"name": paper.focus.value}}
81-
if paper.explored:
88+
if paper.attack_type and paper._original_state["attack_type"] != paper.attack_type:
89+
properties["Attack Type"] = {"select": {"name": paper.attack_type.value}}
90+
if paper.explored and paper._original_state["explored"] != paper.explored:
8291
properties["Explored"] = {"checkbox": paper.explored}
8392

8493
if paper.page_id:

openai_utils.py

Lines changed: 105 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1+
from _types import AttackType, Focus
12
from openai import OpenAI
23

3-
from _types import Focus
4-
54
OpenAIClient = OpenAI
65

76
SUMMARIZE_ABSTRACT_PROMPT = """\
@@ -20,13 +19,87 @@
2019
Respond with ONLY ONE of the labels above. Do not include anything else in your response.
2120
"""
2221

22+
# Attack Type descriptions
23+
24+
EVASION_DESCRIPTION = """\
25+
Model Evasion is an adversarial attack aimed at bypassing or evading a machine
26+
learning model's defenses, usually to make it produce incorrect outputs or behave
27+
in ways that favor the attacker. In this context, the adversary doesn't try to
28+
"break" the model or extract data from it (like in model inversion) but instead
29+
seeks to manipulate the model's behavior in a way that allows them to achieve a
30+
desired outcome, such as bypassing detection systems or generating misleading predictions.
31+
"""
32+
33+
EXTRACTION_DESCRIPTION = """\
34+
Model Extraction refers to an attack where an adversary tries to replicate or steal
35+
the functionality of a machine learning model by querying it and using the outputs
36+
to build a copy of the original model. This type of attack doesn't necessarily involve
37+
extracting sensitive data used for training, as in model inversion, but instead focuses
38+
on how the model behaves—its predictions and outputs—in order to create a surrogate or
39+
shadow model that behaves similarly to the original.
40+
"""
41+
42+
INVERSION_DESCRIPTION = """\
43+
Model inversion refers to a set of techniques in machine learning where an attacker
44+
tries to extract confidential information from a trained AI model by interacting with
45+
it in specific ways, often through extensive querying. By doing so, the attacker may
46+
be able to infer details about the data used to train the model. These details can
47+
range from personal information to the reconstruction of private or sensitive datasets,
48+
potentially revealing confidential information.
49+
"""
50+
51+
POISONING_DESCRIPTION = """\
52+
Model Poisoning is an attack on machine learning models where an adversary intentionally
53+
manipulates data in the training set to impact how a model behaves. Unlike attacks like
54+
model inversion or model extraction, which focus on extracting information from the model,
55+
model poisoning targets the model during its training phase. By introducing misleading,
56+
incorrect, or adversarial data, attackers can manipulate a model's behavior, often without
57+
detection, leading to significant security, reliability, and ethical risks.
58+
"""
59+
60+
PROMPT_INJECTION_DESCRIPTION = """\
61+
Prompt injection is a critical vulnerability in Large Language Models (LLMs), where malicious
62+
users manipulate model behavior by crafting inputs that override, bypass, or exploit how the
63+
model follows instructions. This vulnerability has become more pronounced with the widespread
64+
use of generative AI systems, enabling attackers to induce unintended responses that may lead
65+
to data leakage, misinformation, or system disruptions.
66+
"""
67+
68+
69+
ATTACK_TYPE_DESCRIPTIONS: dict[AttackType, str] = {
70+
AttackType.ModelEvasion: EVASION_DESCRIPTION,
71+
AttackType.ModelExtraction: EXTRACTION_DESCRIPTION,
72+
AttackType.ModelInversion: INVERSION_DESCRIPTION,
73+
AttackType.ModelPoisoning: POISONING_DESCRIPTION,
74+
AttackType.PromptInjection: PROMPT_INJECTION_DESCRIPTION,
75+
AttackType.Other: "None of the above",
76+
}
77+
78+
ASSIGN_ATTACK_TYPE_PROMPT = """\
79+
You will be provided with an abstract of a scientific paper. \
80+
Assess the most applicable attack type label based on the \
81+
research focus, produced materials, and key outcomes.
82+
83+
{types}
84+
85+
If you feel like none of the types apply, you can respond with "Other".
86+
87+
Respond with ONLY ONE of the labels above. Do not include anything else in your response.
88+
"""
89+
90+
# Model Evasion
91+
# Model Extraction
92+
# Model Inversion
93+
# Model Poisoning
94+
# Prompt Injection
95+
2396
def get_openai_client(token: str) -> OpenAIClient:
2497
return OpenAI(api_key=token)
2598

2699

27100
def summarize_abstract_with_openai(client: OpenAIClient, abstract: str) -> str:
28101
response = client.chat.completions.create(
29-
model="gpt-3.5-turbo",
102+
model="gpt-4o-mini",
30103
messages=[
31104
{"role": "system", "content": SUMMARIZE_ABSTRACT_PROMPT},
32105
{"role": "user", "content": f"{abstract}"},
@@ -35,7 +108,8 @@ def summarize_abstract_with_openai(client: OpenAIClient, abstract: str) -> str:
35108
max_tokens=100,
36109
)
37110

38-
return response.choices[0].message.content.strip() # type: ignore
111+
return response.choices[0].message.content.strip() # type: ignore
112+
39113

40114
def get_focus_label_from_abstract(client: OpenAIClient, abstract: str) -> Focus | None:
41115
system_prompt = ASSIGN_LABEL_PROMPT.format(
@@ -52,8 +126,32 @@ def get_focus_label_from_abstract(client: OpenAIClient, abstract: str) -> Focus
52126
max_tokens=10,
53127
)
54128

55-
content = response.choices[0].message.content.strip() # type: ignore
129+
content = response.choices[0].message.content.strip() # type: ignore
56130
if content not in [f.value for f in Focus]:
57131
return None
58-
59-
return Focus(content)
132+
133+
return Focus(content)
134+
135+
def get_attack_type_from_abstract(client: OpenAIClient, abstract: str) -> AttackType | None:
136+
system_prompt = ASSIGN_ATTACK_TYPE_PROMPT.format(
137+
types="\n".join([f"- `{t.value}`: {ATTACK_TYPE_DESCRIPTIONS[t]}" for t in AttackType])
138+
)
139+
140+
response = client.chat.completions.create(
141+
model="gpt-3.5-turbo",
142+
messages=[
143+
{"role": "system", "content": system_prompt},
144+
{"role": "user", "content": f"{abstract}"},
145+
],
146+
temperature=0.5,
147+
max_tokens=10,
148+
)
149+
150+
content = response.choices[0].message.content.strip() # type: ignore
151+
content = content.strip("`")
152+
153+
if content not in [t.value for t in AttackType]:
154+
print(f"Invalid attack type: {content}")
155+
return None
156+
157+
return AttackType(content)

paperstack.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
write_papers_to_notion,
1111
)
1212
from openai_utils import (
13+
get_attack_type_from_abstract,
1314
get_focus_label_from_abstract,
1415
get_openai_client,
1516
summarize_abstract_with_openai,
@@ -48,7 +49,6 @@ async def main():
4849
parser.add_argument("--search-semantic-scholar", action="store_true", default=False)
4950

5051
args = parser.parse_args()
51-
5252
print("[+] Paperstack")
5353

5454
notion_client = get_notion_client(args.notion_token)
@@ -62,6 +62,9 @@ async def main():
6262
if p.published < datetime.fromisoformat("2024-07-01 00:00:00+00:00"):
6363
p.explored = True
6464

65+
if len(p.authors) > 5:
66+
p.authors = p.authors[:5]
67+
6568
if not all([p.has_arxiv_props() for p in papers]):
6669
print(" |- Filling in missing data from arXiv")
6770
papers = fill_papers_with_arxiv(papers)
@@ -70,7 +73,7 @@ async def main():
7073
print(" |- Searching arXiv for new papers")
7174
existing_titles = [paper.title for paper in papers]
7275
for searched_paper in search_arxiv_as_paper(
73-
args.arxiv_search_query, max_results=50
76+
args.arxiv_search_query, max_results=500
7477
):
7578
if searched_paper.title not in existing_titles:
7679
print(f" |- {searched_paper.title[:50]}...")
@@ -96,10 +99,18 @@ async def main():
9699

97100
if not all([paper.focus for paper in papers]):
98101
print(" |- Assigning focus labels with OpenAI")
99-
for paper in [p for p in papers if not p.focus and p.abstract]:
100-
paper.focus = get_focus_label_from_abstract(openai_client, paper.abstract)
102+
for paper in [p for p in papers if not p.focus and (p.abstract or p.summary)]:
103+
reference = paper.abstract or paper.summary
104+
paper.focus = get_focus_label_from_abstract(openai_client, reference)
101105
print(f" |- {paper.focus}")
102106

107+
if not all([paper.attack_type for paper in papers]):
108+
print(" |- Assigning attack types with OpenAI")
109+
for paper in [p for p in papers if not p.attack_type and (p.abstract or p.summary)]:
110+
reference = paper.abstract or paper.summary
111+
paper.attack_type = get_attack_type_from_abstract(openai_client, reference)
112+
print(f" |- {paper.attack_type}")
113+
103114
to_write = [p for p in papers if p.has_changed()]
104115
if to_write:
105116
print(f" |- Writing {len(to_write)} updates back to Notion")

0 commit comments

Comments
 (0)