Skip to content

Commit 5b517cb

Browse files
committed
Set up initial project structure
Also add bare-bones test which can be run like so: ``` uvx pytest tests/test_postprocess.py ``` as long as `uvx` is in path.
1 parent 01b2aee commit 5b517cb

File tree

16 files changed

+1348
-0
lines changed

16 files changed

+1348
-0
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
**/tests/
1111
/build
1212
*.pyc
13+
**/__pycache__
14+
*.egg-info/
1315
.vagrant
1416
**/compile_commands.json
1517
.python-version

c2rust-postprocess/README.md

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# LLM-based postprocessing of c2rust transpiler output
2+
3+
This is currently a prototype effort to gauge the extent to which LLMs can
4+
accelerate the types of translation and migration that help move C code to Rust.
5+
6+
# Prerequisites
7+
8+
- Python 3.12 or later
9+
- `uv` in path
10+
- A valid `GEMINI_API_KEY` set
11+
- A transpiled codebase with a correct `compile_commands.json`
12+
13+
# Running
14+
15+
- `c2rust-postprocess`, or
16+
- `uv run postprocess`
17+
18+
# Testing
19+
20+
## Test prerequisites
21+
22+
- `bear` and `c2rust` in path
23+
24+
```
25+
uv run pytest -v
26+
uv run pytest -v tests/test_utils.py # filter tests to run
27+
```
28+
29+
## Misc
30+
31+
- `uv run ruff check --fix .` to format & lint
32+
33+
# TODOs
34+
35+
- testable prototype
36+
- pluggable support for getting definitions
37+
- gemini api support
38+
- filtering by file and function name
39+
- file-based caching of model responses
40+
- openai model support
41+
- anthropic model support
42+
- openrouter API support?
43+
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
#!/bin/sh
# Wrapper that runs the postprocessing CLI through uv.
# "$@" forwards the caller's arguments unchanged; without it, nothing passed to
# this script (for example a path to compile_commands.json) reaches the command.
# NOTE(review): the README documents `uv run postprocess`; confirm that
# `postproc` is the entry-point name that is actually registered.
uv run postproc "$@"
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""
2+
c2rust-postprocess: Transfer comments from C functions to Rust functions using LLMs.
3+
"""
4+
5+
6+
from pathlib import Path
7+
from typing import Any
8+
from postprocess.definitions import get_function_span_pairs
9+
from postprocess.utils import read_chunk, get_rust_files, get_compile_commands
10+
11+
# TODO: could also include
12+
# - validation function to check result
13+
# - list of comments to check for
14+
class CommentTransferPrompt:
    """One C/Rust function pair together with the instruction text for the model.

    Instances only have the three attributes named in ``__slots__``; no
    per-instance ``__dict__`` is created.
    """

    __slots__ = ("c_function", "rust_function", "prompt_text")

    # Text of the C function, as supplied by the caller.
    c_function: str
    # Text of the Rust function, as supplied by the caller.
    rust_function: str
    # Instruction text that accompanies the two functions.
    prompt_text: str

    def __init__(self, c_function: str, rust_function: str, prompt_text: str) -> None:
        """Store the three components exactly as given."""
        self.prompt_text = prompt_text
        self.rust_function = rust_function
        self.c_function = c_function

    def __str__(self) -> str:
        """Placeholder: string rendering is not available yet and always raises."""
        raise NotImplementedError("String conversion not implemented yet.")
28+
29+
30+
31+
def generate_prompts(
    compile_commands: list[dict[str, Any]], rust_file: Path
) -> list[CommentTransferPrompt]:
    """Build one comment-transfer prompt per paired Rust/C function.

    Args:
        compile_commands: Entries of the compilation database. The helpers
            this function relies on iterate it as a sequence of dicts that
            each carry a ``"file"`` key, hence the list annotation.
        rust_file: Rust source file whose functions should be paired with
            their C counterparts.

    Returns:
        The prompts, in the order the pairs are produced by
        ``get_function_span_pairs``.
    """
    # TODO: make this function take a model and get prompt from model
    # The instruction is the same for every pair, so build it once.
    prompt_text = (
        "Transfer the comments from the following C function to the "
        "corresponding Rust function."
    )

    prompts = []
    for rust_fn, c_fn in get_function_span_pairs(compile_commands, rust_file):
        # TODO: log both definitions on verbose level
        c_def = read_chunk(c_fn["file"], c_fn["start_byte"], c_fn["end_byte"])
        rust_def = read_chunk(
            rust_fn["file"], rust_fn["start_byte"], rust_fn["end_byte"]
        )
        prompts.append(
            CommentTransferPrompt(
                c_function=c_def,
                rust_function=rust_def,
                prompt_text=prompt_text,
            )
        )

    return prompts
56+
57+
58+
def transfer_comments(compile_commands_path: Path):
    """Generate comment-transfer prompts for every Rust file of a transpiled project.

    Rust files are looked up via ``get_rust_files`` starting from the directory
    that contains ``compile_commands_path``; the compilation database itself is
    loaded with ``get_compile_commands``. The prompts are generated but not
    used yet (see the TODO below).

    Args:
        compile_commands_path: Path to the ``compile_commands.json`` file.
    """
    project_dir = compile_commands_path.parent
    rust_sources = get_rust_files(project_dir)
    compile_commands = get_compile_commands(compile_commands_path)

    for rust_file in rust_sources:
        # TODO: call the LLM API with each prompt, then process the response to
        # extract comments and associate them with the Rust functions.
        _prompts = generate_prompts(compile_commands, rust_file)
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import argparse
2+
import logging
3+
import sys
4+
from collections.abc import Sequence
5+
6+
from postprocess.utils import existing_file
7+
8+
9+
def build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser.

    The parser accepts one positional argument, ``compile_commands`` (converted
    with ``existing_file``), and an optional ``--log-level`` restricted to the
    standard logging level names, defaulting to ``INFO``.

    Returns:
        The configured parser.
    """
    level_names = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]

    cli = argparse.ArgumentParser(
        description="Transfer C function comments to Rust using LLMs.",
    )
    cli.add_argument(
        "compile_commands",
        help="Path to compile_commands.json.",
        type=existing_file,
    )
    cli.add_argument(
        "--log-level",
        choices=level_names,
        default="INFO",
        required=False,
        type=str,
        help="Logging level (default: INFO)",
    )
    return cli
29+
30+
31+
def main(argv: Sequence[str] | None = None) -> int:
    """Command-line entry point.

    Args:
        argv: Arguments to parse; ``None`` lets argparse read ``sys.argv``.

    Returns:
        ``0``, which is currently the only exit status produced here.
    """
    args = build_arg_parser().parse_args(argv)

    # Translate the level name into its numeric value before configuring logging.
    level = logging.getLevelName(args.log_level.upper())
    logging.basicConfig(level=level)

    # TODO: implement post-processing logic using compile_commands
    _compile_commands = args.compile_commands

    return 0
43+
44+
45+
if __name__ == "__main__":
    # Run the CLI and hand its return value to the interpreter as exit status.
    exit_status = main()
    sys.exit(exit_status)
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import logging
2+
from pathlib import Path
3+
from typing import Any
4+
5+
import tree_sitter_rust as tsrust
6+
from tree_sitter import Language, Parser
7+
8+
9+
def get_c_sourcefile(compile_commands, rustfile: Path) -> Path | None:
10+
c_file_guesses = [rustfile.with_suffix(".c"), rustfile.with_suffix(".C")]
11+
12+
files = [Path(d["file"]) for d in compile_commands]
13+
14+
for guess in c_file_guesses:
15+
if guess in files:
16+
return guess
17+
18+
return None
19+
20+
21+
def get_rust_function_spans(rustfile: Path) -> list[dict[str, Any]]:
    """Return the name and source span of each top-level function in a Rust file.

    Only ``function_item`` nodes that are direct children of the syntax-tree
    root are reported; functions nested inside other items are not visited.

    Args:
        rustfile: Path to the Rust source file to parse.

    Returns:
        One dict per function with the keys ``name``, ``start_line`` and
        ``end_line`` (1-based), and ``start_byte`` and ``end_byte`` (byte
        offsets reported by tree-sitter). An empty list is returned when the
        file exists but cannot be read.

    Raises:
        FileNotFoundError: If ``rustfile`` does not exist.
        NotADirectoryError: If ``rustfile`` exists but is not a regular file.
    """
    # Validate the path first so bad input fails before any parser is built.
    if not rustfile.exists():
        raise FileNotFoundError(f"{rustfile} does not exist")
    if not rustfile.is_file():
        # NOTE: the exception type is kept for backward compatibility, even
        # though it does not describe the situation precisely.
        raise NotADirectoryError(f"{rustfile} is not a file")

    try:
        source_bytes = rustfile.read_bytes()
    except OSError as exc:
        logging.error("Failed to read Rust file %s: %s", rustfile, exc)
        return []

    parser = Parser(Language(tsrust.language()))
    tree = parser.parse(source_bytes)

    functions = []
    for node in tree.root_node.children:
        if node.type != "function_item":
            continue

        name_node = node.child_by_field_name("name")
        if name_node is None:
            # Without a name the function cannot be paired with its C
            # counterpart, so skip it instead of failing on ``None``.
            logging.warning(
                "Skipping unnamed function at byte %d in %s",
                node.start_byte,
                rustfile,
            )
            continue

        func_name = source_bytes[name_node.start_byte:name_node.end_byte].decode(
            "utf-8"
        )
        functions.append({
            "name": func_name,
            # Rows in start_point/end_point are 0-based; report 1-based lines.
            "start_line": node.start_point[0] + 1,
            "end_line": node.end_point[0] + 1,
            "start_byte": node.start_byte,
            "end_byte": node.end_byte,
        })

    return functions
58+
59+
60+
def get_c_functions_spans(
    compile_commands: list[dict[str, Any]], c_file: Path
) -> list[dict[str, Any]]:
    """Return the name and source span of each function clang reports for a C file.

    The first compile-commands entry whose ``"file"`` refers to ``c_file`` is
    used to obtain the clang AST, from which the function declarations are
    extracted.

    Args:
        compile_commands: Entries of the compilation database; each entry is
            a dict with at least a ``"file"`` key.
        c_file: Path of the C source file, as listed in ``compile_commands``.

    Returns:
        One dict per function declaration with the keys ``name``,
        ``start_line``, ``start_byte``, ``end_line`` and ``end_byte``.
        Declarations whose ``loc`` lacks ``line`` or ``col`` are skipped.

    Raises:
        AssertionError: If no entry in ``compile_commands`` refers to ``c_file``.
    """
    from .clang import get_c_ast_as_json, get_functions_from_clang_ast

    # Compare as paths, like get_c_sourcefile does, so spellings such as
    # "./a.c" and "a.c" select the same entry in both functions.
    wanted = Path(c_file)
    entry = next(
        (c for c in compile_commands if Path(c["file"]) == wanted), None
    )
    if entry is None:
        # Raised explicitly (same exception type as before) so the check is
        # not stripped when Python runs with -O.
        raise AssertionError(f"No compile command entry for {c_file}")

    functions = []
    for fn in get_functions_from_clang_ast(get_c_ast_as_json(entry)):
        loc = fn["loc"]
        if "line" not in loc or "col" not in loc:
            continue

        begin = fn["range"]["begin"]
        end = fn["range"]["end"]
        functions.append({
            "name": fn["name"],
            "start_line": loc["line"],
            "start_byte": begin["offset"],
            # clang's JSON dump leaves out "line" when it equals the line of
            # the previously printed location (loc, then begin, then end), so
            # fall back in that order instead of raising KeyError.
            "end_line": end.get("line", begin.get("line", loc["line"])),
            # NOTE(review): clang ranges appear to end at the start of the
            # last token, so end_byte may exclude the closing brace; confirm
            # how read_chunk treats the end offset.
            "end_byte": end["offset"],
        })

    return functions
84+
85+
86+
def get_function_span_pairs(compile_commands: dict[str, Any], rustfile: Path) -> list[tuple[dict[str, Any], dict[str, Any]]]:
    """Pair each Rust function span in ``rustfile`` with its C counterpart.

    Functions are paired by position. Both sides must contain the same number
    of functions and the names must agree pairwise; both conditions are
    checked with assertions. Each span dict gets a ``"file"`` key holding the
    path of the file it came from.

    Args:
        compile_commands: Compilation database entries.
        rustfile: Rust source file to pair with its C source.

    Returns:
        A list of ``(rust_span, c_span)`` tuples.

    Raises:
        FileNotFoundError: If no C source file can be found for ``rustfile``.
    """
    rust_fn_spans = get_rust_function_spans(rustfile)

    c_file = get_c_sourcefile(compile_commands, rustfile)
    if c_file is None:
        raise FileNotFoundError(f"No corresponding C source file found for {rustfile}")

    c_fn_spans = get_c_functions_spans(compile_commands, c_file)

    # TODO: handle cases where ordering or counts differ
    # A reasonable assumption is that we can still pair functions by name
    # which means that this tool needs to run fairly soon after transpilation
    assert len(c_fn_spans) == len(rust_fn_spans), "Mismatched number of functions between Rust and C source files"

    pairs = []
    for rust_fn, c_fn in zip(rust_fn_spans, c_fn_spans):
        assert rust_fn['name'] == c_fn['name']
        # Record where each span came from so it can be read back later.
        rust_fn['file'] = rustfile
        c_fn['file'] = c_file
        pairs.append((rust_fn, c_fn))

    return pairs
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
2+
import json
3+
import subprocess
4+
from typing import Any
5+
6+
import jq
7+
8+
9+
def get_functions_from_clang_ast(ast: dict[str, Any]) -> list[dict[str, Any]]:
    """
    Collect the function declarations found directly under the AST root.

    Only the elements of the root's ``inner`` list are examined; those whose
    ``kind`` is ``FunctionDecl`` are reduced to their ``name``, ``loc`` and
    ``range`` fields.

    Args:
        ast (dict): The Clang AST JSON as a dictionary.
    Returns:
        list[dict]: One dictionary per function declaration.
    """
    jq_program = '.inner[] | select(.kind =="FunctionDecl") | {name: .name, loc: .loc, range: .range}'
    compiled = jq.compile(jq_program)
    return compiled.transform(ast, multiple_output=True)
21+
22+
23+
def get_c_ast_as_json(entry: dict[str, Any]) -> dict[str, Any]:
    """
    Get AST as JSON for a translation unit identified by compile commands entry.

    The entry's compile command is rewritten to invoke ``clang`` in
    syntax-only mode with a JSON AST dump, and is run from the entry's
    ``directory`` so relative paths in the command resolve correctly.

    Args:
        entry (dict): Compilation database entry with ``file``, ``directory``
            and either ``arguments`` (list) or ``command`` (shell string).
    Returns:
        dict: The parsed JSON AST printed by clang.
    Raises:
        subprocess.CalledProcessError: If clang exits with a non-zero status.
    """
    import shlex  # local import: only needed to split the "command" form

    source_file = entry["file"]

    if "arguments" in entry:
        # Work on a copy so the caller's entry is not modified in place.
        cmd = list(entry["arguments"])
    else:
        # The compilation database format also allows a single shell-escaped
        # "command" string instead of an "arguments" list.
        cmd = shlex.split(entry["command"])

    cmd[0] = "clang"  # make sure we use clang
    # drop the last four elements which are the output options
    cmd = cmd[:-4]  # TODO: validate that these are the output options
    # add the necessary flags to dump the AST as JSON
    cmd += [
        "-fsyntax-only",
        "-Xclang",
        "-ast-dump=json",
        "-fparse-all-comments",  # NOTE: Clang AST only includes doc comments
        source_file,
    ]
    try:
        # cwd to the directory from the compile_commands.json entry to make sure
        # relative paths in the command work correctly
        result = subprocess.run(
            cmd, capture_output=True, text=True, check=True, cwd=entry["directory"]
        )
        return json.loads(result.stdout)
    except subprocess.CalledProcessError as e:
        print(f"Error running clang on {source_file}: {e.stderr}")
        raise
51+
52+
53+
def is_entry_from_c_file(entry: dict[str, Any], c_file: str) -> bool:
    """
    Decide whether an AST entry belongs to the given C file.

    The entry's ``loc`` is inspected in this order:

    1. ``loc["file"]``, if present, is compared with ``c_file``.
    2. Otherwise the ``includedFrom`` file of ``spellingLoc``, then of
       ``expansionLoc``, is compared with ``c_file`` (first one present wins).
    3. Otherwise the entry counts as coming from ``c_file`` unless ``loc``
       itself has an ``includedFrom`` key.
    """
    loc = entry["loc"]

    if "file" in loc:
        return loc["file"] == c_file

    for nested_key in ("spellingLoc", "expansionLoc"):
        if nested_key in loc and "includedFrom" in loc[nested_key]:
            return loc[nested_key]["includedFrom"]["file"] == c_file

    # entry was parsed from c_file so by default it is from that file
    return "includedFrom" not in loc
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from abc import ABC, abstractmethod
2+
from typing import List, Dict, Callable, Optional, Any
3+
4+
5+
class AbstractGenAIModel(ABC):
    """
    Base class for LLM clients that use native function calling.

    Subclasses implement ``agenerate_with_tools``; the synchronous
    ``generate_with_tools`` is provided on top of it.
    """

    def __init__(self, model_id: str, **kwargs: Any):
        """Remember the model identifier and any extra keyword configuration."""
        self._config = kwargs
        self._model_id = model_id

    @property
    def model_id(self) -> str:
        """Identifier of the model this client was created for."""
        return self._model_id

    @abstractmethod
    async def agenerate_with_tools(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Callable]] = None,
        max_tool_loops: int = 5
    ) -> Any:
        """
        Generate a response using native automatic function calling.

        Args:
            messages: Chat history.
            tools: Python functions the model may call; converting them to a
                schema is the concrete class's job.
            max_tool_loops: Upper bound on consecutive tool calls by the
                model (guards against infinite loops).

        Returns:
            The final natural language response from the model.
        """

    def generate_with_tools(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Callable]] = None,
        max_tool_loops: int = 5
    ) -> str:
        """Synchronous wrapper for agenerate_with_tools.

        The coroutine is driven to completion with ``asyncio.run``.
        """
        import asyncio

        pending = self.agenerate_with_tools(messages, tools, max_tool_loops)
        return asyncio.run(pending)
48+

0 commit comments

Comments
 (0)