diff --git a/benchmark/Cpp/toy/RACE/race.c b/benchmark/Cpp/toy/RACE/race.c
new file mode 100644
index 0000000..9cb08a7
--- /dev/null
+++ b/benchmark/Cpp/toy/RACE/race.c
@@ -0,0 +1,106 @@
+//gcc chall.c -m32 -pie -fstack-protector-all -o chall
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <unistd.h>
+#include <pthread.h>
+
+unsigned int a = 0;
+unsigned int b = 0;
+unsigned int a_sleep = 0;
+int flag = 1;
+int pstr1 = 1;
+int ret1;
+pthread_t th1;
+void * th_ret = NULL;
+
+void menu_go(){
+    if(a_sleep == 0){
+        a = a + 5;
+    }else{
+        a_sleep = 0;
+    }
+
+    b = b + 2;
+}
+
+int *menu_chance(){
+    if(a<=b){
+        puts("No");
+        return 0;
+    }
+
+    if(flag == 1){
+        a_sleep = 1;
+        sleep(1);
+        flag = 0;
+    }
+    else{
+        puts("Only have one chance");
+    }
+    return 0;
+}
+
+
+void menu_test(){
+    if( b>a ){
+        puts("Win!");
+        system("/bin/sh");
+        exit(0);
+    }else{
+        puts("Lose!");
+        exit(0);
+    }
+}
+
+void menu_exit(){
+    puts("Bye");
+    exit(0);
+}
+
+void menu(){
+    printf("***** race *****\n");
+    printf("*** 1:Go\n*** 2:Chance\n*** 3:Test\n*** 4:Exit \n");
+    printf("*************************************\n");
+    printf("Choice> ");
+    int choose;
+    scanf("%d",&choose);
+    switch(choose)
+    {
+        case 1:
+            menu_go();
+            break;
+        case 2:
+            ret1 = pthread_create(&th1, NULL, menu_chance, &pstr1);
+            break;
+        case 3:
+            menu_test();
+            break;
+        case 4:
+            menu_exit();
+            break;
+        default:
+            return;
+    }
+    return;
+
+}
+
+
+void init(){
+    setbuf(stdin, 0LL);
+    setbuf(stdout, 0LL);
+    setbuf(stderr, 0LL);
+
+    while (1)
+    {
+        menu();
+    }
+
+}
+
+int main(){
+    init();
+    return 0;
+}
diff --git a/lib/build.py b/lib/build.py
index f563469..db69e4f 100644
--- a/lib/build.py
+++ b/lib/build.py
@@ -1,8 +1,8 @@
 import os
-
-from tree_sitter import Language, Parser
 from pathlib import Path
 
+from tree_sitter import Language
+
 cwd = Path(__file__).resolve().parent.absolute()
 
 # clone tree-sitter if necessary
diff --git a/requirements.txt b/requirements.txt
index 78dfb30..832e6cb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,7 +5,7 @@ torch
 tiktoken
 replicate
 openai
-google-generativeai
+google-genai
 tqdm
 networkx
 streamlit
diff --git a/src/agent/dfbscan.py b/src/agent/dfbscan.py
index 41bc244..4310078 100644
--- a/src/agent/dfbscan.py
+++ b/src/agent/dfbscan.py
@@ -17,6 +17,7 @@
 from tstool.dfbscan_extractor.Cpp.Cpp_MLK_extractor import *
 from tstool.dfbscan_extractor.Cpp.Cpp_NPD_extractor import *
 from tstool.dfbscan_extractor.Cpp.Cpp_UAF_extractor import *
+from tstool.dfbscan_extractor.Cpp.Cpp_RACE_extractor import *
 from tstool.dfbscan_extractor.Java.Java_NPD_extractor import *
 from tstool.dfbscan_extractor.Python.Python_NPD_extractor import *
 from tstool.dfbscan_extractor.Go.Go_NPD_extractor import *
@@ -103,6 +104,8 @@ def __obtain_extractor(self) -> DFBScanExtractor:
                 return Cpp_NPD_Extractor(self.ts_analyzer)
             elif self.bug_type == "UAF":
                 return Cpp_UAF_Extractor(self.ts_analyzer)
+            elif self.bug_type == "RACE":
+                return Cpp_Race_Extractor(self.ts_analyzer)
         elif self.language == "Java":
             if self.bug_type == "NPD":
                 return Java_NPD_Extractor(self.ts_analyzer)
diff --git a/src/llmtool/LLM_utils.py b/src/llmtool/LLM_utils.py
index 843c2db..bcb3f14 100644
--- a/src/llmtool/LLM_utils.py
+++ b/src/llmtool/LLM_utils.py
@@ -2,7 +2,8 @@
 from openai import *
 from pathlib import Path
 from typing import Tuple
-import google.generativeai as genai
+from google import genai
+from google.genai import types
 import anthropic
 import signal
 import sys
@@ -87,27 +88,53 @@ def run_with_timeout(self, func, timeout):
                 ("Operation timed out")
             return ""
         except Exception as e:
+            self.logger.print_console(f"Operation failed: {e}")
             self.logger.print_log(f"Operation failed: {e}")
             return ""
 
     def infer_with_gemini(self, message: str) -> str:
-        """Infer using the Gemini model from Google Generative AI"""
-        gemini_model = genai.GenerativeModel("gemini-pro")
+        """Infer using the latest Gemini SDK (google-genai)"""
+        api_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY")
+
+        if not api_key:
+            raise EnvironmentError(
+                "Please set the GOOGLE_API_KEY or GEMINI_API_KEY environment variable."
+            )
+
+        client = genai.Client(api_key=api_key)
+
+        model_name = self.online_model_name
+        if model_name == "gemini-pro":
+            model_name = "gemini-2.0-flash"
 
         def call_api():
-            message_with_role = self.systemRole + "\n" + message
             safety_settings = [
-                {
-                    "category": "HARM_CATEGORY_DANGEROUS",
-                    "threshold": "BLOCK_NONE",
-                },
-                # ...existing safety settings...
+                types.SafetySetting(
+                    category=types.HarmCategory.HARM_CATEGORY_HARASSMENT,
+                    threshold=types.HarmBlockThreshold.BLOCK_NONE,
+                ),
+                types.SafetySetting(
+                    category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+                    threshold=types.HarmBlockThreshold.BLOCK_NONE,
+                ),
+                types.SafetySetting(
+                    category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+                    threshold=types.HarmBlockThreshold.BLOCK_NONE,
+                ),
+                types.SafetySetting(
+                    category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                    threshold=types.HarmBlockThreshold.BLOCK_NONE,
+                ),
             ]
-            response = gemini_model.generate_content(
-                message_with_role,
-                safety_settings=safety_settings,
-                generation_config=genai.types.GenerationConfig(
-                    temperature=self.temperature
+
+            response = client.models.generate_content(
+                model=model_name,
+                contents=message,
+                config=types.GenerateContentConfig(
+                    system_instruction=self.systemRole,
+                    temperature=self.temperature,
+                    max_output_tokens=self.max_output_length,
+                    safety_settings=safety_settings,
                 ),
             )
             return response.text
@@ -118,14 +145,13 @@ def call_api():
         try:
             output = self.run_with_timeout(call_api, timeout=50)
             if output:
-                self.logger.print_log("Inference succeeded...")
+                self.logger.print_log(f"Gemini ({model_name}) inference succeeded...")
                 return output
         except Exception as e:
-            self.logger.print_log(f"API error: {e}")
+            self.logger.print_log(f"Gemini API error: {e}")
             time.sleep(2)
         return ""
 
-
     def infer_with_openai_model(self, message):
         """Infer using the OpenAI model"""
         api_key = os.environ.get("OPENAI_API_KEY").split(":")[0]
@@ -188,7 +214,8 @@ def infer_with_deepseek_model(self, message):
         """
         Infer using the DeepSeek model
         """
-        api_key = os.environ.get("DEEPSEEK_API_KEY2")
+        self.logger.print_console(f"Calling DeepSeek API ({self.online_model_name})...")
+        api_key = os.environ.get("DEEPSEEK_API_KEY") or os.environ.get("OPENAI_API_KEY2")
         model_input = [
             {
                 "role": "system",
@@ -298,6 +325,7 @@ def call_api():
 
     def infer_with_claude_key(self, message):
         """Infer using the Claude model via API key, with thinking mode for 3.7"""
+        self.logger.print_console(f"Calling Claude API ({self.online_model_name})...")
        api_key = os.environ.get("ANTHROPIC_API_KEY")
         if not api_key:
             raise EnvironmentError(
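For reviewers migrating other call sites, here is a minimal, self-contained sketch of the new google-genai call pattern that `infer_with_gemini` now relies on. It assumes `pip install google-genai`, a `GEMINI_API_KEY` in the environment, and the `gemini-2.0-flash` model used as the fallback above; the prompt text itself is purely illustrative.

```python
# Minimal sketch of the google-genai call pattern used in infer_with_gemini above.
# Assumptions: `pip install google-genai`, GEMINI_API_KEY set in the environment,
# and the "gemini-2.0-flash" model name mirroring the fallback in the diff.
import os

from google import genai
from google.genai import types

client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])

response = client.models.generate_content(
    model="gemini-2.0-flash",
    contents="Summarize what a data race is in one sentence.",
    config=types.GenerateContentConfig(
        system_instruction="You are a concise static-analysis assistant.",
        temperature=0.0,
        max_output_tokens=128,
    ),
)
print(response.text)
```

The key difference from the old SDK is that requests go through a `genai.Client` instance and `client.models.generate_content`, with the system role and sampling options carried in `types.GenerateContentConfig` instead of a `GenerativeModel` object.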
diff --git a/src/prompt/Cpp/dfbscan/path_validator.json b/src/prompt/Cpp/dfbscan/path_validator.json
index 48f94cf..86d8929 100644
--- a/src/prompt/Cpp/dfbscan/path_validator.json
+++ b/src/prompt/Cpp/dfbscan/path_validator.json
@@ -10,6 +10,7 @@
     "- If the function exits or returns before reaching the sink or relevant propagation sites (such as call sites), then the path is unreachable, so answer No.",
     "- Analyze the conditions on each sub-path within a function. You should infer the outcome of these conditions from branch details and then check whether the conditions across sub-paths conflict. If they do, then the overall path is unreachable.",
     "- Examine the values of relevant variables. If those values contradict the related branch conditions necessary to trigger the bug, the path is unreachable and you should answer No.",
+    "- For RACE detection, if the shared resource is accessed within a critical section (protected by a lock such as a mutex) or the access is atomic, consider the path safe (unreachable for the bug) and answer No.",
     "In summary, evaluate the condition of each sub-path, verify possible conflicts, and then decide whether the entire propagation path is reachable."
   ],
   "question_template": [
@@ -132,6 +133,31 @@
     "2. In the 'flag' branch, the condition at line 5 checks if p is not NULL.",
     "3. Since p remains NULL, the condition fails and the else branch at line 7 is executed, preventing any dereference at line 6.",
     "Therefore, this guarded path is unreachable and does not cause the NPD bug.",
+    "Answer: No.",
+    "",
+    "Example 5:",
+    "User:",
+    "Consider the following program which updates a shared resource based on a condition:",
+    "```cpp",
+    "1. int g_resource = 0;",
+    "2. std::mutex resource_mutex;",
+    "3. void process_resource() {",
+    "4. // Start of critical section",
+    "5. resource_mutex.lock();",
+    "6. if (g_resource < 100) { // Read access",
+    "7. g_resource += 5; // Write access",
+    "8. }",
+    "9. resource_mutex.unlock();",
+    "10. // End of critical section",
+    "11. }",
+    "```",
+    "Does the following propagation path cause the RACE bug?",
+    "`g_resource` at line 1 --> `g_resource += 5` at line 7",
+    "Explanation:",
+    "1. The code performs a 'Check-Then-Act' operation (Read at line 6, Write at line 7) on the shared variable `g_resource`.",
+    "2. Both the Read and Write operations are strictly enclosed between `resource_mutex.lock()` (line 5) and `resource_mutex.unlock()` (line 9).",
+    "3. This constitutes a valid critical section: the mutex ensures mutual exclusion, so no other thread can modify `g_resource` between the check (line 6) and the update (line 7).",
+    "4. Since the access is fully protected by a lock, it is thread-safe.",
     "Answer: No."
   ],
   "additional_fact": [
diff --git a/src/repoaudit.py b/src/repoaudit.py
index 24d3639..746ec48 100644
--- a/src/repoaudit.py
+++ b/src/repoaudit.py
@@ -14,7 +14,7 @@
 from typing import List
 
 default_dfbscan_checkers = {
-    "Cpp": ["MLK", "NPD", "UAF"],
+    "Cpp": ["MLK", "NPD", "UAF", "RACE"],
     "Java": ["NPD"],
     "Python": ["NPD"],
     "Go": ["NPD"],
diff --git a/src/run_repoaudit.sh b/src/run_repoaudit.sh
index c17ead8..bf96bc4 100755
--- a/src/run_repoaudit.sh
+++ b/src/run_repoaudit.sh
@@ -3,10 +3,11 @@ set -euo pipefail
 IFS=$'\n\t'
 
 # --- Defaults ---
-LANGUAGE="Python"
-MODEL="claude-3.7"
+LANGUAGE="Cpp"
+MODEL="deepseek-chat"
+# MODEL="claude-3.7"
 DEFAULT_PROJECT_NAME="toy"
-DEFAULT_BUG_TYPE="NPD"   # allowed: MLK, NPD, UAF
+DEFAULT_BUG_TYPE="RACE"  # allowed: MLK, NPD, UAF, RACE
 SCAN_TYPE="dfbscan"
 
 # Construct the default project *path* from LANGUAGE + DEFAULT_PROJECT_NAME
@@ -19,12 +20,13 @@ Usage: run_scan.sh [PROJECT_PATH] [BUG_TYPE]
 Arguments:
   PROJECT_PATH    Optional absolute/relative path to the subject project.
                   Defaults to: ../benchmark/Python/toy
-  BUG_TYPE        Optional bug type. One of: MLK, NPD, UAF. Defaults to: NPD
+  BUG_TYPE        Optional bug type. One of: MLK, NPD, UAF, RACE. Defaults to: RACE
 
 Bug type meanings:
   MLK  - Memory Leak
   NPD  - Null Pointer Dereference
   UAF  - Use After Free
+  RACE - Race Condition
 
 Examples:
   ./run_scan.sh
@@ -48,10 +50,10 @@ BUG_TYPE="$(echo "$BUG_TYPE_RAW" | tr '[:lower:]' '[:upper:]')"
 
 # --- Validate BUG_TYPE ---
 case "$BUG_TYPE" in
-  MLK|NPD|UAF) : ;;
+  MLK|NPD|UAF|RACE) : ;;
   *)
-    echo "Error: BUG_TYPE must be one of: MLK, NPD, UAF (got '$BUG_TYPE_RAW')." >&2
-    echo "       MLK = Memory Leak; NPD = Null Pointer Dereference; UAF = Use After Free." >&2
+    echo "Error: BUG_TYPE must be one of: MLK, NPD, UAF, RACE (got '$BUG_TYPE_RAW')." >&2
+    echo "       MLK = Memory Leak; NPD = Null Pointer Dereference; UAF = Use After Free; RACE = Race Condition." >&2
     exit 1
     ;;
 esac
diff --git a/src/tstool/analyzer/TS_analyzer.py b/src/tstool/analyzer/TS_analyzer.py
index 31118ab..0905eaf 100644
--- a/src/tstool/analyzer/TS_analyzer.py
+++ b/src/tstool/analyzer/TS_analyzer.py
@@ -3,6 +3,7 @@
 from pathlib import Path
 import copy
 import concurrent.futures
+import ctypes
 from typing import List, Optional, Tuple, Dict, Set
 
 from abc import ABC, abstractmethod
@@ -148,18 +149,41 @@ def __init__(
         # Initialize tree-sitter parser
         self.parser = Parser()
         self.language_name = language_name
-        if language_name == "C":
-            self.language = Language(str(language_path), "c")
-        elif language_name == "Cpp":
-            self.language = Language(str(language_path), "cpp")
-        elif language_name == "Java":
-            self.language = Language(str(language_path), "java")
-        elif language_name == "Python":
-            self.language = Language(str(language_path), "python")
-        elif language_name == "Go":
-            self.language = Language(str(language_path), "go")
-        else:
-            raise ValueError("Invalid language setting")
+
+        # Load the language library.
+        # Note: Language(path, name) is deprecated in tree-sitter 0.21.x.
+        # We use ctypes to load the library and get the language pointer to avoid the warning.
+        try:
+            lib = ctypes.cdll.LoadLibrary(str(language_path))
+            lang_map = {
+                "C": ("tree_sitter_c", "c"),
+                "Cpp": ("tree_sitter_cpp", "cpp"),
+                "Java": ("tree_sitter_java", "java"),
+                "Python": ("tree_sitter_python", "python"),
+                "Go": ("tree_sitter_go", "go"),
+            }
+            if language_name in lang_map:
+                func_name, lang_id = lang_map[language_name]
+                func = getattr(lib, func_name)
+                func.restype = ctypes.c_void_p
+                self.language = Language(func(), lang_id)
+            else:
+                raise ValueError(f"Unsupported language: {language_name}")
+        except Exception:
+            # Fall back to the deprecated constructor if ctypes loading fails, to preserve stability.
+            if language_name == "C":
+                self.language = Language(str(language_path), "c")
+            elif language_name == "Cpp":
+                self.language = Language(str(language_path), "cpp")
+            elif language_name == "Java":
+                self.language = Language(str(language_path), "java")
+            elif language_name == "Python":
+                self.language = Language(str(language_path), "python")
+            elif language_name == "Go":
+                self.language = Language(str(language_path), "go")
+            else:
+                raise ValueError("Invalid language setting")
+
         self.parser.set_language(self.language)
 
         # Results of parsing
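As a quick sanity check for the ctypes-based loading introduced above, the following standalone sketch loads the C++ grammar and parses a one-line snippet. Assumptions: py-tree-sitter 0.21.x, and a grammar shared library at the hypothetical path below (wherever `lib/build.py` placed the compiled grammars).

```python
# Minimal standalone check of the ctypes loading path used in TS_analyzer above.
# LANGUAGE_SO is a hypothetical path to the grammar library built by lib/build.py.
import ctypes

from tree_sitter import Language, Parser

LANGUAGE_SO = "lib/build/my-languages.so"  # hypothetical path

lib = ctypes.cdll.LoadLibrary(LANGUAGE_SO)
lib.tree_sitter_cpp.restype = ctypes.c_void_p  # compiled grammars export tree_sitter_<lang>()
cpp_language = Language(lib.tree_sitter_cpp(), "cpp")

parser = Parser()
parser.set_language(cpp_language)

tree = parser.parse(b"int shared = 0;\n")
print(tree.root_node.sexp())  # (translation_unit (declaration ...)) if loading succeeded
```

If the symbol lookup fails (for example, a library that does not export `tree_sitter_cpp`), the constructor above falls back to the deprecated `Language(path, name)` form in its `except` branch.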
diff --git a/src/tstool/dfbscan_extractor/Cpp/Cpp_RACE_extractor.py b/src/tstool/dfbscan_extractor/Cpp/Cpp_RACE_extractor.py
new file mode 100644
index 0000000..d58bb69
--- /dev/null
+++ b/src/tstool/dfbscan_extractor/Cpp/Cpp_RACE_extractor.py
@@ -0,0 +1,133 @@
+from tstool.analyzer.TS_analyzer import *
+from tstool.analyzer.Cpp_TS_analyzer import *
+from ..dfbscan_extractor import *
+
+
+class Cpp_Race_Extractor(DFBScanExtractor):
+    def extract_sources(self, function: Function) -> List[Value]:
+        """
+        Extract potential shared resources or thread creation points as sources.
+        1. Global variables (shared across threads).
+        2. Static variables (shared across function calls/threads).
+        3. Arguments passed to thread creation functions (std::thread, pthread_create).
+        """
+        root_node = function.parse_tree_root_node
+        source_code = self.ts_analyzer.code_in_files[function.file_path]
+        file_path = function.file_path
+
+        sources = []
+
+        # 1. Find global variables (defined at the top level of the file).
+        # Note: function.parse_tree_root_node is usually the function body. Finding
+        # globals requires the root node of the file's AST, but the current architecture
+        # only passes a 'Function' object, so we re-parse the whole file content here
+        # (TSAnalyzer may already have parsed it).
+        parser = Parser()
+        parser.set_language(self.ts_analyzer.language)
+        tree = parser.parse(bytes(source_code, "utf8"))
+        file_root_node = tree.root_node
+
+        declarations = find_nodes_by_type(file_root_node, "declaration")
+        for decl in declarations:
+            # Keep only top-level declarations (parent is the translation_unit)
+            if decl.parent.type == "translation_unit":
+                init_declarators = find_nodes_by_type(decl, "init_declarator")
+                for init_decl in init_declarators:
+                    declarator = init_decl.child_by_field_name("declarator")
+                    while declarator and declarator.type in ["pointer_declarator", "reference_declarator"]:
+                        declarator = declarator.child_by_field_name("declarator")
+
+                    if declarator and declarator.type == "identifier":
+                        name = source_code[declarator.start_byte:declarator.end_byte]
+                        line_number = source_code[:declarator.start_byte].count("\n") + 1
+                        sources.append(Value(name, line_number, ValueLabel.SRC, file_path))
+
+        # 2. Find static variables within the function
+        func_declarations = find_nodes_by_type(root_node, "declaration")
+        for decl in func_declarations:
+            is_static = False
+            for child in decl.children:
+                if child.type == "storage_class_specifier" and source_code[child.start_byte:child.end_byte] == "static":
+                    is_static = True
+                    break
+
+            if is_static:
+                init_declarators = find_nodes_by_type(decl, "init_declarator")
+                for init_decl in init_declarators:
+                    declarator = init_decl.child_by_field_name("declarator")
+                    while declarator and declarator.type in ["pointer_declarator", "reference_declarator"]:
+                        declarator = declarator.child_by_field_name("declarator")
+
+                    if declarator and declarator.type == "identifier":
+                        name = source_code[declarator.start_byte:declarator.end_byte]
+                        line_number = source_code[:declarator.start_byte].count("\n") + 1
+                        sources.append(Value(name, line_number, ValueLabel.SRC, file_path))
+
+        # 3. Find arguments to thread creation
+        call_expressions = find_nodes_by_type(root_node, "call_expression")
+        for call in call_expressions:
+            func_node = call.child_by_field_name("function")
+            if func_node:
+                func_name = source_code[func_node.start_byte:func_node.end_byte]
+                if "thread" in func_name or "async" in func_name or "pthread_create" in func_name:
+                    args = call.child_by_field_name("arguments")
+                    if args:
+                        for arg in args.children:
+                            if arg.type == "identifier":
+                                name = source_code[arg.start_byte:arg.end_byte]
+                                line_number = source_code[:arg.start_byte].count("\n") + 1
+                                sources.append(Value(name, line_number, ValueLabel.SRC, file_path))
+                            elif arg.type == "reference_expression":  # std::ref(x)
+                                for child in arg.children:
+                                    if child.type == "identifier":
+                                        name = source_code[child.start_byte:child.end_byte]
+                                        line_number = source_code[:child.start_byte].count("\n") + 1
+                                        sources.append(Value(name, line_number, ValueLabel.SRC, file_path))
+                            elif arg.type == "unary_expression":  # &x
+                                for child in arg.children:
+                                    if child.type == "identifier":
+                                        name = source_code[child.start_byte:child.end_byte]
+                                        line_number = source_code[:child.start_byte].count("\n") + 1
+                                        sources.append(Value(name, line_number, ValueLabel.SRC, file_path))
+
+        return sources
+
+    def extract_sinks(self, function: Function) -> List[Value]:
+        """
+        Extract potential sinks for Race Condition.
+        We consider ANY access (Read or Write) to a variable as a potential sink.
+        This allows the LLM to detect Read-Write races.
+        """
+        root_node = function.parse_tree_root_node
+        source_code = self.ts_analyzer.code_in_files[function.file_path]
+        file_path = function.file_path
+
+        sinks = []
+
+        # Extract all identifiers that are used in expressions.
+        # This is a broad extraction, but necessary to catch reads.
+        # We filter out function calls and declarations to focus on variable usage.
+        identifiers = find_nodes_by_type(root_node, "identifier")
+        for ident in identifiers:
+            # Filter out declarations (we only want usage)
+            parent = ident.parent
+            if parent.type in ["function_declarator", "init_declarator", "declaration", "parameter_declaration"]:
+                continue
+
+            # Filter out function calls (the function name itself)
+            if parent.type == "call_expression" and parent.child_by_field_name("function") == ident:
+                continue
+
+            # Filter out field access (member variables) - simplistic handling
+            # if parent.type == "field_expression" and parent.child_by_field_name("field") == ident:
+            #     continue
+
+            name = source_code[ident.start_byte:ident.end_byte]
+            line_number = source_code[:ident.start_byte].count("\n") + 1
+            sinks.append(Value(name, line_number, ValueLabel.SINK, file_path))
+
+        return sinks
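To make the source/sink heuristics in `extract_sources` and `extract_sinks` concrete, here is a standalone sketch that applies the same ideas to a fragment of `race.c` outside the `DFBScanExtractor` machinery. Assumptions: py-tree-sitter 0.21.x, a C grammar library at the hypothetical path below, and a local `walk` helper standing in for the repository's `find_nodes_by_type`.

```python
# Standalone sketch of the Cpp_Race_Extractor heuristics on a fragment of race.c.
# LANGUAGE_SO is a hypothetical path to the grammar library built by lib/build.py.
import ctypes

from tree_sitter import Language, Parser

LANGUAGE_SO = "lib/build/my-languages.so"  # hypothetical path
lib = ctypes.cdll.LoadLibrary(LANGUAGE_SO)
lib.tree_sitter_c.restype = ctypes.c_void_p
parser = Parser()
parser.set_language(Language(lib.tree_sitter_c(), "c"))

src = b"""
unsigned int a = 0;
unsigned int a_sleep = 0;
void menu_go() {
    if (a_sleep == 0) { a = a + 5; } else { a_sleep = 0; }
}
"""
tree = parser.parse(src)

def text(node):
    return src[node.start_byte:node.end_byte].decode()

def walk(node):
    # Depth-first traversal, a stand-in for find_nodes_by_type in the extractor.
    yield node
    for child in node.children:
        yield from walk(child)

# Sources: identifiers declared at file scope (shared globals), as in extract_sources.
sources = []
for decl in (n for n in tree.root_node.children if n.type == "declaration"):
    for init_decl in (n for n in walk(decl) if n.type == "init_declarator"):
        declarator = init_decl.child_by_field_name("declarator")
        if declarator is not None and declarator.type == "identifier":
            sources.append(text(declarator))

# Sinks: every identifier use outside declarations, as in extract_sinks.
sinks = []
for node in walk(tree.root_node):
    if node.type != "identifier":
        continue
    if node.parent.type in ("function_declarator", "init_declarator", "declaration", "parameter_declaration"):
        continue
    sinks.append((text(node), node.start_point[0] + 1))

print("sources:", sources)  # expected: ['a', 'a_sleep']
print("sinks:", sinks)      # reads/writes of a and a_sleep inside menu_go, with line numbers
```

Run this way, the globals `a` and `a_sleep` surface as sources and every read or write of them inside `menu_go` becomes a candidate sink, which is the over-approximation the updated path_validator prompt is then asked to prune.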