Skip to content

Commit ddb9063

Browse files
committed
feat: add rust data model
1 parent d9e329a commit ddb9063

File tree

2 files changed

+207
-0
lines changed

2 files changed

+207
-0
lines changed

cldk/models/rust/__init__.py

Whitespace-only changes.

cldk/models/rust/models.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
from enum import Enum
2+
from typing import Dict, List, Optional, Set
3+
from pydantic import BaseModel, Field
4+
5+
6+
class RustVisibility(Enum):
7+
"""Represents Rust visibility modifiers."""
8+
PUBLIC = "pub"
9+
PRIVATE = ""
10+
CRATE = "pub(crate)"
11+
SUPER = "pub(super)"
12+
IN_PATH = "pub(in path)"
13+
14+
15+
class SafetyClassification(Enum):
16+
"""Classifies the safety level of Rust code."""
17+
SAFE = "safe"
18+
UNSAFE = "unsafe"
19+
UNSAFE_CONTAINER = "unsafe_container"
20+
FFI = "ffi"
21+
22+
23+
class UnsafeReason(Enum):
24+
"""Reasons why code might be unsafe."""
25+
RAW_POINTER_DEREF = "raw_pointer_deref"
26+
MUTABLE_STATIC = "mutable_static"
27+
FFI_CALL = "ffi_call"
28+
UNION_FIELD_ACCESS = "union_field_access"
29+
INLINE_ASSEMBLY = "inline_assembly"
30+
UNSAFE_TRAIT_IMPL = "unsafe_trait_impl"
31+
CUSTOM = "custom"
32+
33+
34+
class UnsafeBlock(BaseModel):
35+
"""Represents an unsafe block within Rust code."""
36+
start_line: int
37+
end_line: int
38+
reasons: List[UnsafeReason] = Field(default_factory=list)
39+
explanation: Optional[str] = None # Documentation explaining why unsafe is needed
40+
containing_function: Optional[str] = None
41+
42+
43+
class RustType(BaseModel):
44+
"""Represents a Rust type."""
45+
name: str
46+
is_reference: bool = False
47+
is_mutable: bool = False
48+
lifetime: Optional[str] = None
49+
generic_params: List[str] = Field(default_factory=list)
50+
is_sized: bool = True
51+
is_static: bool = False
52+
contains_raw_pointers: bool = False
53+
is_union: bool = False
54+
55+
56+
class RustAttribute(BaseModel):
57+
"""Represents a Rust attribute."""
58+
name: str
59+
arguments: List[str] = Field(default_factory=list)
60+
is_inner: bool = False # #![foo] vs #[foo]
61+
62+
63+
class SafetyAnalysis(BaseModel):
64+
"""Analyzes and tracks safety-related information."""
65+
classification: SafetyClassification
66+
unsafe_blocks: List[UnsafeBlock] = Field(default_factory=list)
67+
unsafe_fn_calls: List[str] = Field(default_factory=list)
68+
raw_pointer_usage: bool = False
69+
ffi_interactions: bool = False
70+
unsafe_traits_used: List[str] = Field(default_factory=list)
71+
mutable_statics: List[str] = Field(default_factory=list)
72+
safety_comments: Optional[str] = None
73+
74+
75+
class RustCallable(BaseModel):
76+
"""Represents a Rust function or method."""
77+
name: str
78+
visibility: RustVisibility = RustVisibility.PRIVATE
79+
doc_comment: Optional[str] = None
80+
attributes: List[RustAttribute] = Field(default_factory=list)
81+
parameters: List["RustParameter"] = Field(default_factory=list)
82+
return_type: Optional[RustType] = None
83+
is_async: bool = False
84+
is_const: bool = False
85+
is_unsafe: bool = False
86+
is_extern: bool = False
87+
extern_abi: Optional[str] = None
88+
generic_params: List["RustGenericParam"] = Field(default_factory=list)
89+
lifetime_params: List["RustLifetimeParam"] = Field(default_factory=list)
90+
where_clauses: List[str] = Field(default_factory=list)
91+
code: str
92+
start_line: int
93+
end_line: int
94+
referenced_types: List[str] = Field(default_factory=list)
95+
accessed_variables: List[str] = Field(default_factory=list)
96+
call_sites: List["RustCallSite"] = Field(default_factory=list)
97+
variable_declarations: List["RustVariableDeclaration"] = Field(default_factory=list)
98+
cyclomatic_complexity: Optional[int] = None
99+
safety_analysis: SafetyAnalysis
100+
101+
def is_fully_safe(self) -> bool:
102+
"""Check if the function is completely safe (no unsafe blocks or calls)."""
103+
return (
104+
self.safety_analysis.classification == SafetyClassification.SAFE
105+
and not self.safety_analysis.unsafe_blocks
106+
and not self.safety_analysis.unsafe_fn_calls
107+
)
108+
109+
def contains_unsafe(self) -> bool:
110+
"""Check if the function contains any unsafe code."""
111+
return (
112+
self.safety_analysis.classification in [SafetyClassification.UNSAFE, SafetyClassification.UNSAFE_CONTAINER]
113+
or bool(self.safety_analysis.unsafe_blocks)
114+
or bool(self.safety_analysis.unsafe_fn_calls)
115+
)
116+
117+
def get_unsafe_reasons(self) -> Set[UnsafeReason]:
118+
"""Get all reasons for unsafe usage in this function."""
119+
reasons = set()
120+
for block in self.safety_analysis.unsafe_blocks:
121+
reasons.update(block.reasons)
122+
return reasons
123+
124+
125+
class RustModule(BaseModel):
126+
"""Represents a Rust module."""
127+
name: str
128+
doc_comment: Optional[str] = None
129+
attributes: List[RustAttribute] = Field(default_factory=list)
130+
visibility: RustVisibility = RustVisibility.PRIVATE
131+
types: Dict[str, "RustType"] = Field(default_factory=dict)
132+
functions: Dict[str, RustCallable] = Field(default_factory=dict)
133+
safe_functions: Dict[str, RustCallable] = Field(default_factory=dict)
134+
unsafe_functions: Dict[str, RustCallable] = Field(default_factory=dict)
135+
submodules: Dict[str, 'RustModule'] = Field(default_factory=dict)
136+
constants: List["RustVariableDeclaration"] = Field(default_factory=list)
137+
macros: List[str] = Field(default_factory=list)
138+
use_declarations: List[str] = Field(default_factory=list)
139+
extern_crates: List[str] = Field(default_factory=list)
140+
is_unsafe: bool = False
141+
file_path: Optional[str] = None
142+
is_mod_rs: bool = False
143+
144+
def categorize_functions(self):
145+
"""Categorize functions based on their safety analysis."""
146+
self.safe_functions.clear()
147+
self.unsafe_functions.clear()
148+
149+
for name, func in self.functions.items():
150+
if func.is_fully_safe():
151+
self.safe_functions[name] = func
152+
else:
153+
self.unsafe_functions[name] = func
154+
155+
156+
class RustCrate(BaseModel):
157+
"""Represents a complete Rust crate."""
158+
name: str
159+
version: str
160+
root_module: RustModule
161+
dependencies: List["RustDependencyEdge"] = Field(default_factory=list)
162+
edition: str = "2021"
163+
features: List[str] = Field(default_factory=list)
164+
165+
def analyze_safety(self) -> Dict[str, int]:
166+
"""Analyze safety statistics across the crate."""
167+
stats = {
168+
"total_functions": 0,
169+
"safe_functions": 0,
170+
"unsafe_functions": 0,
171+
"unsafe_blocks": 0,
172+
"ffi_functions": 0,
173+
}
174+
175+
def analyze_module(module: RustModule):
176+
for func in module.functions.values():
177+
stats["total_functions"] += 1
178+
if func.is_fully_safe():
179+
stats["safe_functions"] += 1
180+
else:
181+
stats["unsafe_functions"] += 1
182+
if func.safety_analysis.classification == SafetyClassification.FFI:
183+
stats["ffi_functions"] += 1
184+
stats["unsafe_blocks"] += len(func.safety_analysis.unsafe_blocks)
185+
186+
for submodule in module.submodules.values():
187+
analyze_module(submodule)
188+
189+
analyze_module(self.root_module)
190+
return stats
191+
192+
def get_unsafe_functions(self) -> List[tuple[str, RustCallable]]:
193+
"""Get all unsafe functions in the crate with their module paths."""
194+
unsafe_fns = []
195+
196+
def collect_unsafe(module: RustModule, path: str):
197+
for name, func in module.functions.items():
198+
if not func.is_fully_safe():
199+
full_path = f"{path}::{name}" if path else name
200+
unsafe_fns.append((full_path, func))
201+
202+
for submod_name, submod in module.submodules.items():
203+
new_path = f"{path}::{submod_name}" if path else submod_name
204+
collect_unsafe(submod, new_path)
205+
206+
collect_unsafe(self.root_module, "")
207+
return unsafe_fns

0 commit comments

Comments
 (0)