Skip to content

Commit d2cf708

Browse files
committed
Set up initial project structure
Also add bare-bones test which can be run like so: ``` uvx pytest tests/test_postprocess.py ``` as long as `uvx` is in path.
1 parent 01b2aee commit d2cf708

File tree

15 files changed

+1199
-0
lines changed

15 files changed

+1199
-0
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
**/tests/
1111
/build
1212
*.pyc
13+
**/__pycache__
14+
*.egg-info/
1315
.vagrant
1416
**/compile_commands.json
1517
.python-version

c2rust-postprocess/README.md

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# LLM-based postprocessing of c2rust transpiler output
2+
3+
This is currently a prototype effort to gauge the extent to which LLMs can
4+
accelerate the types of translation and migration that help move C code to Rust.
5+
6+
# Prerequisites
7+
8+
- Python 3.12 or later
9+
- `uv` in path
10+
- A valid `GEMINI_API_KEY` set
11+
- A transpiled codebase with a correct `compile_commands.json`
12+
13+
# Running
14+
15+
- `c2rust-postprocess`, or
16+
- `uv run postprocess`
17+
18+
# Testing
19+
20+
## Test prerequisites
21+
22+
- `bear` and `c2rust` in path
23+
24+
```
25+
uv run pytest -v
26+
uv run pytest -v tests/test_utils.py # filter tests to run
27+
```
28+
29+
## Misc
30+
31+
- `uv run ruff check --fix .` to format & lint
32+
33+
# TODOs
34+
35+
- testable prototype
36+
- pluggable support for getting definitions
37+
- gemini api support
38+
- filtering by file and function name
39+
- file-based caching of model responses
40+
- openai model support
41+
- anthropic model support
42+
- openrouter API support?
43+
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
#!/bin/sh
# Convenience wrapper: run the `postproc` entry point through uv.
# "$@" forwards every command-line argument unchanged; without it any
# arguments given to this wrapper would be silently dropped.
# NOTE(review): the README documents `uv run postprocess` -- confirm that
# `postproc` is the script name the project actually declares.
uv run postproc "$@"
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""
2+
c2rust-postprocess: Transfer comments from C functions to Rust functions using LLMs.
3+
"""
4+
5+
6+
def transfer_comments():
    """Placeholder for the comment-transfer entry point; currently a no-op."""
    return None
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import argparse
2+
import logging
3+
import sys
4+
from collections.abc import Sequence
5+
6+
from postprocess.utils import existing_file
7+
8+
9+
def build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the tool.

    Returns:
        A parser with one positional argument, ``compile_commands``
        (converted/validated by ``existing_file``), and an optional
        ``--log-level`` flag restricted to the standard logging level names.
    """
    log_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]

    arg_parser = argparse.ArgumentParser(
        description="Transfer C function comments to Rust using LLMs.",
    )
    arg_parser.add_argument(
        "compile_commands",
        type=existing_file,
        help="Path to compile_commands.json.",
    )
    arg_parser.add_argument(
        "--log-level",
        type=str,
        required=False,
        default="INFO",
        choices=log_levels,
        help="Logging level (default: INFO)",
    )
    return arg_parser
29+
30+
31+
def main(argv: Sequence[str] | None = None) -> int:
    """Command-line entry point.

    Args:
        argv: Argument list to parse; ``None`` lets argparse fall back to
            ``sys.argv``.

    Returns:
        The process exit status (always 0 for now).
    """
    args = build_arg_parser().parse_args(argv)

    # getLevelName maps a known level name to its numeric value.
    level = logging.getLevelName(args.log_level.upper())
    logging.basicConfig(level=level)

    # TODO: implement post-processing logic using compile_commands
    _compile_commands = args.compile_commands

    return 0


if __name__ == "__main__":
    sys.exit(main())
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import json
2+
import logging
3+
from pathlib import Path
4+
from typing import Any
5+
6+
import tree_sitter_rust as tsrust
7+
from tree_sitter import Language, Parser
8+
9+
10+
def get_c_sourcefile(compile_commands, rustfile: Path) -> Path | None:
11+
c_file_guesses = [rustfile.with_suffix(".c"), rustfile.with_suffix(".C")]
12+
13+
files = [Path(d["file"]) for d in compile_commands]
14+
15+
for guess in c_file_guesses:
16+
if guess in files:
17+
return guess
18+
19+
return None
20+
21+
22+
def get_rust_function_spans(rustfile: Path) -> list[dict[str, Any]]:
    """Collect the spans of the top-level functions in a Rust source file.

    Only direct children of the parse tree's root are inspected, so
    functions nested inside other items are not reported.

    Args:
        rustfile: Path of the Rust file to parse.

    Returns:
        One dict per ``function_item`` with the keys ``name``,
        ``start_line``/``end_line`` (1-based) and ``start_byte``/``end_byte``.
        An empty list is returned if the file cannot be read.

    Raises:
        FileNotFoundError: If ``rustfile`` does not exist.
        NotADirectoryError: If ``rustfile`` exists but is not a regular file.
    """
    rust_parser = Parser(Language(tsrust.language()))

    if not rustfile.exists():
        raise FileNotFoundError(f"{rustfile} does not exist")
    if not rustfile.is_file():
        raise NotADirectoryError(f"{rustfile} is not a file")

    try:
        source_bytes = rustfile.read_bytes()
    except OSError as exc:
        logging.error(f"Failed to read Rust file {rustfile}: {exc}")
        return []

    def describe(fn_node) -> dict[str, Any]:
        """Build the span record for a single function_item node."""
        ident = fn_node.child_by_field_name("name")
        raw_name = source_bytes[ident.start_byte:ident.end_byte]  # type: ignore
        return {
            "name": raw_name.decode("utf-8"),
            # tree-sitter points are 0-indexed; report 1-based line numbers
            "start_line": fn_node.start_point[0] + 1,
            "end_line": fn_node.end_point[0] + 1,
            "start_byte": fn_node.start_byte,
            "end_byte": fn_node.end_byte,
        }

    root = rust_parser.parse(source_bytes).root_node
    return [
        describe(child)
        for child in root.children
        if child.type == "function_item"
    ]
59+
60+
61+
def get_c_functions_spans(compile_commands: list[dict[str, Any]], c_file: Path):
    """Collect the spans of the functions declared in a C translation unit.

    Args:
        compile_commands: Entries of a compilation database; each entry is
            indexed by its ``"file"`` key here.
        c_file: Path of the C source file whose entry should be used.

    Returns:
        A list of dicts with the keys ``name``, ``start_line``,
        ``start_byte``, ``end_line`` and ``end_byte``. Declarations whose
        ``loc`` lacks ``line`` or ``col`` are skipped.

    Raises:
        AssertionError: If no compile command entry matches ``c_file``.
    """
    from .clang import get_c_ast_as_json, get_functions_from_clang_ast

    # Compare as Path objects (the same way get_c_sourcefile matches entries)
    # so spellings like "./src/foo.c" and "src/foo.c" are treated as equal;
    # a plain string comparison would miss an entry the lookup already found.
    target = Path(c_file)
    entry = next(
        (c for c in compile_commands if Path(c["file"]) == target), None
    )

    assert entry is not None, f"No compile command entry for {c_file}"

    c_fn_asts = get_functions_from_clang_ast(get_c_ast_as_json(entry))

    functions = []
    for fn in c_fn_asts:
        loc = fn["loc"]
        # Only keep declarations that carry a concrete line/column location.
        if "line" in loc and "col" in loc:
            functions.append({
                "name": fn["name"],
                "start_line": loc["line"],
                "start_byte": fn["range"]["begin"]["offset"],
                "end_line": fn["range"]["end"]["line"],
                "end_byte": fn["range"]["end"]["offset"],
            })

    return functions
85+
86+
87+
def get_function_span_pairs(compile_commands: dict[str, Any], rustfile: Path) -> list[tuple[dict[str, Any], dict[str, Any]]]:
    """Pair each Rust function span with the span of its C counterpart.

    Functions are paired positionally; the counts and the names at each
    position must agree.

    Raises:
        FileNotFoundError: If no C source file corresponds to ``rustfile``.
        AssertionError: If the number of functions or any paired names differ.
    """
    rust_fn_spans = get_rust_function_spans(rustfile)

    c_file = get_c_sourcefile(compile_commands, rustfile)
    if not c_file:
        raise FileNotFoundError(f"No corresponding C source file found for {rustfile}")

    c_fn_spans = get_c_functions_spans(compile_commands, c_file)

    # TODO: handle cases where ordering or counts differ
    # A reasonable assumption is that we can still pair functions by name
    # which means that this tool needs to run fairly soon after transpilation
    assert len(c_fn_spans) == len(rust_fn_spans), "Mismatched number of functions between Rust and C source files"

    pairs = list(zip(rust_fn_spans, c_fn_spans))
    for rust_span, c_span in pairs:
        assert rust_span['name'] == c_span['name']

    return pairs
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
2+
import json
3+
import subprocess
4+
from typing import Any
5+
6+
import jq
7+
8+
9+
def get_functions_from_clang_ast(ast: dict[str, Any]) -> list[dict[str, Any]]:
    """
    Pull the function declarations out of a Clang AST JSON document.

    Only the direct ``inner`` children of ``ast`` whose ``kind`` is
    ``FunctionDecl`` are selected.

    Args:
        ast (dict): The AST JSON as a dictionary.
    Returns:
        list[dict]: One dict per function declaration holding its ``name``,
        ``loc`` and ``range``.
    """
    program = (
        '.inner[] | select(.kind =="FunctionDecl") | {name: .name, loc: .loc, range: .range}'
    )
    compiled = jq.compile(program)
    return compiled.transform(ast, multiple_output=True)
21+
22+
23+
def get_c_ast_as_json(entry: dict[str, Any]) -> dict[str, Any]:
    """
    Get AST as JSON for a translation unit identified by compile commands entry.

    Args:
        entry: A compile commands entry providing ``"file"``,
            ``"arguments"`` and ``"directory"``. The entry is not modified.
    Returns:
        The parsed JSON AST that clang writes to stdout.
    Raises:
        subprocess.CalledProcessError: If clang exits with a non-zero status.
    """
    source_file = entry["file"]

    # Build a fresh list instead of editing entry["arguments"] in place, so
    # the caller's compile commands are left untouched.
    recorded_args = entry["arguments"]
    cmd = [
        "clang",  # make sure we use clang, whatever compiler was recorded
        # drop the last four elements which are the output options
        # TODO: validate that these are the output options
        *recorded_args[1:-4],
        # add the necessary flags to dump the AST as JSON
        "-fsyntax-only",
        "-Xclang",
        "-ast-dump=json",
        "-fparse-all-comments",  # NOTE: Clang AST only includes doc comments
        source_file,
    ]
    try:
        # cwd to the directory from the compile_commands.json entry to make sure
        # relative paths in the command work correctly
        result = subprocess.run(
            cmd, capture_output=True, text=True, check=True, cwd=entry["directory"]
        )
        return json.loads(result.stdout)
    except subprocess.CalledProcessError as e:
        print(f"Error running clang on {source_file}: {e.stderr}")
        raise
51+
52+
53+
def is_entry_from_c_file(entry: dict[str, Any], c_file: str) -> bool:
    """
    Decide whether an AST entry belongs to the given C file.

    The entry's ``loc`` is checked in this order: an explicit ``file``,
    then the ``includedFrom`` file of ``spellingLoc``, then that of
    ``expansionLoc``. If none is present, the entry counts as coming from
    ``c_file`` unless ``loc`` itself carries an ``includedFrom`` key.
    """
    loc = entry["loc"]

    if "file" in loc:
        return loc["file"] == c_file

    for key in ("spellingLoc", "expansionLoc"):
        if key in loc and "includedFrom" in loc[key]:
            return loc[key]["includedFrom"]["file"] == c_file

    # entry was parsed from c_file so by default it is from that file
    return "includedFrom" not in loc
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from abc import ABC, abstractmethod
2+
from typing import List, Dict, Callable, Optional, Union, Any
3+
4+
5+
# Type alias for chat messages, e.g., [{"role": "user", "content": "hello"}]
6+
MessageList = List[Dict[str, str]]
7+
ValidationFunction = Callable[[str], bool]
8+
9+
class AbstractModel(ABC):
    """Common interface that every LLM client implementation provides."""

    def __init__(self, model_id: str, **kwargs: Any):
        """
        Args:
            model_id: The specific model identifier (e.g., 'gpt-5').
            **kwargs: Extra configuration, stored as given (e.g. temperature,
                max_tokens, api_key).
        """
        self._config = kwargs
        self._model_id = model_id

    @property
    def model_id(self) -> str:
        """The model identifier supplied at construction time."""
        return self._model_id

    @abstractmethod
    async def agenerate_response(
        self,
        messages: Union[str, MessageList],
        validate_fn: Optional[ValidationFunction] = None
    ) -> str:
        """
        Produce a response asynchronously; subclasses must implement this.

        Args:
            messages: Either one prompt string or a list of chat messages.
            validate_fn: Optional callable that receives the response string
                and returns True when it is acceptable, False otherwise.

        Returns:
            The text content of the LLM response.

        Raises:
            LLMGenerationError: If the API call fails. (Contract for
                implementations; the exception type is not defined in this
                module.)
        """

    def generate_response(
        self,
        messages: Union[str, MessageList],
        validate_fn: Optional[Callable[[str], bool]] = None
    ) -> str:
        """Blocking counterpart of agenerate_response, driven by asyncio.run."""
        import asyncio

        pending = self.agenerate_response(messages, validate_fn)
        return asyncio.run(pending)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
3+
class CommentTransfer:
    """Placeholder for the comment-transfer logic; holds no state yet."""

    def __init__(self):
        return None
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import argparse
2+
from pathlib import Path
3+
4+
5+
def existing_file(value: str) -> Path:
    """Argparse ``type=`` converter that accepts only paths to existing files.

    Args:
        value: The raw command-line string.

    Returns:
        ``value`` as a ``Path`` when it refers to a regular file.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not an existing file.
    """
    candidate = Path(value)
    if not candidate.is_file():
        raise argparse.ArgumentTypeError(f"{value!r} is not a readable file")
    return candidate
10+
11+
12+
def get_rust_files(path: Path) -> list[Path]:
    """Recursively list every ``*.rs`` file beneath a directory.

    Args:
        path: Directory to search.

    Returns:
        The matching paths, in the order the glob yields them.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        NotADirectoryError: If ``path`` exists but is not a directory.
    """
    if not path.exists():
        raise FileNotFoundError(f"{path} does not exist")
    if not path.is_dir():
        raise NotADirectoryError(f"{path} is not a directory")

    return list(path.glob("**/*.rs"))
23+
24+
25+
def read_chunk(filepath: Path, start_offset: int, end_offset: int, encoding: str = 'utf-8') -> str:
    """Read and decode the bytes of a file between two offsets, inclusive.

    Args:
        filepath: File to read from.
        start_offset: Byte offset of the first byte to read (>= 0).
        end_offset: Byte offset of the last byte to read (>= start_offset).
        encoding: Text encoding used to decode the bytes.

    Returns:
        The decoded text; shorter than requested if the file ends early.

    Raises:
        ValueError: If the offsets do not describe a valid range.
    """
    if start_offset < 0 or end_offset < start_offset:
        # Keep the two offsets visually separate in the message.
        raise ValueError(f"Invalid range: [{start_offset}, {end_offset}]")

    length = end_offset - start_offset + 1  # inclusive range

    with open(filepath, 'rb') as f:  # Only byte mode supports seeking to byte offset
        f.seek(start_offset)
        return f.read(length).decode(encoding)

0 commit comments

Comments
 (0)