Skip to content

Commit 18ca935

Browse files
committed
Improvement: Check ALL foreign function calls for a sync variable
The idea is to perform the check in `link_return_value_if_sync_variable` for all foreign function calls. This saves us having to maintain a list of "supported functions" function from the std where we actually track if a mutex/lock guard/condvar/join handle is being passed around. It also simplifies the number of cases in FunctionCall that we need to handle. The new solution is more robust since it always performs the check, making it "future-proof" in a certain way. In the process, we needed to handle extracting the arguments in a different way that ignores constants and skips the check if the function has zero arguments.
1 parent 8f7872b commit 18ca935

File tree

4 files changed

+76
-44
lines changed

4 files changed

+76
-44
lines changed

src/translator/function_call.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@ pub enum FunctionCall {
2121
/// Abridged function call.
2222
/// Non-recursive call for the translation process.
2323
Foreign,
24-
/// Abridged function call that involves a synchronization primitive.
25-
/// Non-recursive call for the translation process.
26-
ForeignWithSyncPrimitive,
2724
/// MIR function call (the "default" case).
2825
/// Recursive call for the translation process.
2926
MirFunction,
@@ -62,11 +59,6 @@ impl FunctionCall {
6259
/// Returns the corresponding variant for the function or `None` otherwise.
6360
fn is_supported_function(function_name: &str) -> Option<Self> {
6461
match function_name {
65-
"std::clone::Clone::clone"
66-
| "std::ops::Deref::deref"
67-
| "std::ops::DerefMut::deref_mut"
68-
| "std::result::Result::<T, E>::unwrap"
69-
| "std::sync::Arc::<T>::new" => Some(Self::ForeignWithSyncPrimitive),
7062
"std::sync::Condvar::new" => Some(Self::CondVarNew),
7163
"std::sync::Condvar::notify_one" => Some(Self::CondVarNotifyOne),
7264
"std::sync::Condvar::wait" => Some(Self::CondVarWait),

src/translator/function_call_handler.rs

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,27 @@ impl<'tcx> Translator<'tcx> {
2222
let function_name = self.tcx.def_path_str(function_def_id);
2323

2424
match function_call {
25-
FunctionCall::ForeignWithSyncPrimitive => {
26-
self.call_foreign_function_with_sync_primitive(
25+
FunctionCall::CondVarNew => {
26+
self.call_condvar_new(&function_name, args, destination, &function_call_places);
27+
}
28+
FunctionCall::CondVarNotifyOne => {
29+
self.call_condvar_notify_one(
2730
&function_name,
2831
args,
2932
destination,
3033
&function_call_places,
3134
);
3235
}
33-
FunctionCall::CondVarNew => {
34-
self.call_condvar_new(&function_name, destination, &function_call_places);
35-
}
36-
FunctionCall::CondVarNotifyOne => {
37-
self.call_condvar_notify_one(&function_name, args, &function_call_places);
38-
}
3936
FunctionCall::CondVarWait => {
4037
self.call_condvar_wait(args, destination, &function_call_places);
4138
}
4239
FunctionCall::Foreign => {
43-
self.call_foreign_function(&function_name, &function_call_places);
40+
self.call_foreign_function(
41+
&function_name,
42+
args,
43+
destination,
44+
&function_call_places,
45+
);
4446
}
4547
FunctionCall::MirFunction => {
4648
let (start_place, end_place, _) = function_call_places;
@@ -52,10 +54,10 @@ impl<'tcx> Translator<'tcx> {
5254
self.call_mutex_lock(&function_name, args, destination, &function_call_places);
5355
}
5456
FunctionCall::MutexNew => {
55-
self.call_mutex_new(&function_name, destination, &function_call_places);
57+
self.call_mutex_new(&function_name, args, destination, &function_call_places);
5658
}
5759
FunctionCall::ThreadJoin => {
58-
self.call_thread_join(&function_name, args, &function_call_places);
60+
self.call_thread_join(&function_name, args, destination, &function_call_places);
5961
}
6062
FunctionCall::ThreadSpawn => {
6163
self.call_thread_spawn(&function_name, args, destination, &function_call_places);
@@ -72,34 +74,25 @@ impl<'tcx> Translator<'tcx> {
7274
/// A separate counter is incremented every time that
7375
/// the function is called to generate a unique label.
7476
///
77+
/// Performs a check to keep track of synchronization primitives.
78+
/// In case the first argument is a mutex, lock guard, join handle or condition variable,
79+
/// it links the first argument of the function to its return value.
80+
///
7581
/// Returns the transition that represents the function call.
7682
fn call_foreign_function(
7783
&mut self,
7884
function_name: &str,
85+
args: &[rustc_middle::mir::Operand<'tcx>],
86+
destination: rustc_middle::mir::Place<'tcx>,
7987
function_call_places: &FunctionPlaces,
8088
) -> TransitionRef {
8189
let index = self.function_counter.get_count(function_name);
8290
self.function_counter.increment(function_name);
83-
call_foreign_function(
91+
let function_transition = call_foreign_function(
8492
function_call_places,
8593
&foreign_call_transition_labels(function_name, index),
8694
&mut self.net,
87-
)
88-
}
89-
90-
/// Handler for the case `FunctionCall::ForeignWithSyncPrimitive`.
91-
/// It is an extension of `call_foreign_function` that performs a check
92-
/// to keep track of synchronization primitives.
93-
/// The goal is to link the first argument of the function to its return value
94-
/// in case the first argument is a mutex, lock guard, join handle or condition variable.
95-
fn call_foreign_function_with_sync_primitive(
96-
&mut self,
97-
function_name: &str,
98-
args: &[rustc_middle::mir::Operand<'tcx>],
99-
destination: rustc_middle::mir::Place<'tcx>,
100-
function_call_places: &FunctionPlaces,
101-
) {
102-
self.call_foreign_function(function_name, function_call_places);
95+
);
10396

10497
let current_function = self.call_stack.peek_mut();
10598
link_return_value_if_sync_variable(
@@ -109,16 +102,19 @@ impl<'tcx> Translator<'tcx> {
109102
current_function.def_id,
110103
self.tcx,
111104
);
105+
106+
function_transition
112107
}
113108

114109
/// Handler for the case `FunctionCall::CondvarNew`.
115110
fn call_condvar_new(
116111
&mut self,
117112
function_name: &str,
113+
args: &[rustc_middle::mir::Operand<'tcx>],
118114
destination: rustc_middle::mir::Place<'tcx>,
119115
function_call_places: &FunctionPlaces,
120116
) {
121-
self.call_foreign_function(function_name, function_call_places);
117+
self.call_foreign_function(function_name, args, destination, function_call_places);
122118

123119
let current_function = self.call_stack.peek_mut();
124120
self.condvar_manager.translate_side_effects_new(
@@ -133,9 +129,11 @@ impl<'tcx> Translator<'tcx> {
133129
&mut self,
134130
function_name: &str,
135131
args: &[rustc_middle::mir::Operand<'tcx>],
132+
destination: rustc_middle::mir::Place<'tcx>,
136133
function_call_places: &FunctionPlaces,
137134
) {
138-
let notify_one_transition = self.call_foreign_function(function_name, function_call_places);
135+
let notify_one_transition =
136+
self.call_foreign_function(function_name, args, destination, function_call_places);
139137

140138
let current_function = self.call_stack.peek_mut();
141139
self.condvar_manager.translate_side_effects_notify_one(
@@ -177,7 +175,7 @@ impl<'tcx> Translator<'tcx> {
177175
function_call_places: &FunctionPlaces,
178176
) {
179177
let transition_function_call =
180-
self.call_foreign_function(function_name, function_call_places);
178+
self.call_foreign_function(function_name, args, destination, function_call_places);
181179

182180
let current_function = self.call_stack.peek_mut();
183181
self.mutex_manager.translate_side_effects_lock(
@@ -193,10 +191,11 @@ impl<'tcx> Translator<'tcx> {
193191
fn call_mutex_new(
194192
&mut self,
195193
function_name: &str,
194+
args: &[rustc_middle::mir::Operand<'tcx>],
196195
destination: rustc_middle::mir::Place<'tcx>,
197196
function_call_places: &FunctionPlaces,
198197
) {
199-
self.call_foreign_function(function_name, function_call_places);
198+
self.call_foreign_function(function_name, args, destination, function_call_places);
200199

201200
let current_function = self.call_stack.peek_mut();
202201
self.mutex_manager.translate_side_effects_new(
@@ -211,10 +210,11 @@ impl<'tcx> Translator<'tcx> {
211210
&mut self,
212211
function_name: &str,
213212
args: &[rustc_middle::mir::Operand<'tcx>],
213+
destination: rustc_middle::mir::Place<'tcx>,
214214
function_call_places: &FunctionPlaces,
215215
) {
216216
let transition_function_call =
217-
self.call_foreign_function(function_name, function_call_places);
217+
self.call_foreign_function(function_name, args, destination, function_call_places);
218218

219219
let current_function = self.call_stack.peek();
220220
self.thread_manager.translate_side_effects_join(
@@ -233,7 +233,7 @@ impl<'tcx> Translator<'tcx> {
233233
function_call_places: &FunctionPlaces,
234234
) {
235235
let transition_function_call =
236-
self.call_foreign_function(function_name, function_call_places);
236+
self.call_foreign_function(function_name, args, destination, function_call_places);
237237

238238
let current_function = self.call_stack.peek_mut();
239239
self.thread_manager.translate_side_effects_spawn(

src/translator/sync.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ mod thread;
1111
mod thread_manager;
1212

1313
use crate::translator::mir_function::Memory;
14-
use crate::utils::{check_substring_in_place_type, extract_nth_argument};
14+
use crate::utils::{check_substring_in_place_type, extract_nth_argument_as_place};
1515

1616
pub use condvar::Condvar;
1717
pub use condvar_manager::{CondvarManager, CondvarRef};
@@ -145,6 +145,17 @@ fn generalized_link_place_if_sync_variable<'tcx>(
145145
/// Checks if the first argument for a function call contains a mutex, a lock guard,
146146
/// a join handle or a condition variable, i.e. a synchronization variable.
147147
/// If the first argument contains a synchronization variable, links it to the return value.
148+
/// If there is no first argument or it is a constant,
149+
/// then there is nothing to check, therefore the function simply returns.
150+
///
151+
/// Why check only the first argument?
152+
/// Because most function in the standard library involving synchronization primitives
153+
/// receive it through the first argument. For instance:
154+
/// * `std::clone::Clone::clone`
155+
/// * `std::ops::Deref::deref`
156+
/// * `std::ops::DerefMut::deref_mut`
157+
/// * `std::result::Result::<T, E>::unwrap`
158+
/// * `std::sync::Arc::<T>::new`
148159
///
149160
/// Receives a reference to the memory of the caller function to
150161
/// link the return local variable to the synchronization variable.
@@ -155,7 +166,10 @@ pub fn link_return_value_if_sync_variable<'tcx>(
155166
caller_function_def_id: rustc_hir::def_id::DefId,
156167
tcx: rustc_middle::ty::TyCtxt<'tcx>,
157168
) {
158-
let first_argument = extract_nth_argument(args, 0);
169+
let Some(first_argument) = extract_nth_argument_as_place(args, 0) else {
170+
// Nothing to check: Either the first argument is not present or it is a constant.
171+
return;
172+
};
159173
link_if_sync_variable(
160174
&return_value,
161175
&first_argument,

src/utils.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,32 @@ pub fn extract_nth_argument<'tcx>(
6969
}
7070
}
7171

72+
/// Extracts the n-th argument from the arguments for the function call.
73+
/// Returns the place corresponding to that argument.
74+
///
75+
/// This is also useful for obtaining the self reference for method calls.
76+
/// For example: The call `mutex.lock()` desugars to `std::sync::Mutex::lock(&mutex)`
77+
/// where `&self` is the first argument.
78+
///
79+
/// If the argument can not be found (the array is shorter than the `index` argument)
80+
/// or the argument is a constant (which does not have a `Place` representation),
81+
/// then the function returns `None`.
82+
pub fn extract_nth_argument_as_place<'tcx>(
83+
args: &[rustc_middle::mir::Operand<'tcx>],
84+
index: usize,
85+
) -> Option<rustc_middle::mir::Place<'tcx>> {
86+
let Some(operand) = args.get(index) else {
87+
return None;
88+
};
89+
90+
match operand {
91+
rustc_middle::mir::Operand::Move(place) | rustc_middle::mir::Operand::Copy(place) => {
92+
Some(*place)
93+
}
94+
rustc_middle::mir::Operand::Constant(_) => None,
95+
}
96+
}
97+
7298
/// Extracts the closure passed as the 0-th argument to `std::thread::spawn`.
7399
/// Returns the place corresponding to that argument.
74100
///

0 commit comments

Comments
 (0)