Skip to content

Commit 56dc8f3

Browse files
committed
Refactor handle shutdown
1 parent d324f9f commit 56dc8f3

File tree

6 files changed

+45
-55
lines changed

6 files changed

+45
-55
lines changed

CHANGELOG.md

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ and this project adheres to
1212

1313
- Add `Drop` implementation for `WinDivert` to automatically close the handle
1414
when it goes out of scope.
15+
- Add `ShutdownHandle` struct to improve multithreaded handle shutdown
16+
ergonomics.
17+
- Add `WinDivert::shutdown_handle(&self)` to create a `ShutdownHandle` for the
18+
current instance.
1519

1620
### Changed
1721

@@ -21,8 +25,14 @@ and this project adheres to
2125
- Bump `thiserror` to 2.0
2226
- Changed `WinDivert::shutdown()` method to use a shared reference instead of a
2327
mutable reference. (#16)
24-
- Changed `WinDivert::close()` method to be consuming instead of using a mutable
25-
reference. (#15)
28+
- Changed `WinDivert::close()` method to be consuming, and remove it's `action`
29+
parameter.
30+
31+
### Removed
32+
33+
- Removed `CloseAction`
34+
- Removed `Windivert::shutdown()` method in favor of
35+
`Windivert::shutdown_handle()`.
2636

2737
## [Unreleased-sys]
2838

windivert/src/core/winapi/mutex.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ impl InstallMutex {
1717
Ok(Self { handle })
1818
}
1919

20-
pub fn lock(&mut self) -> Result<InstallMutexGuard, windows::core::Error> {
20+
pub fn lock(&mut self) -> Result<InstallMutexGuard<'_>, windows::core::Error> {
2121
unsafe {
2222
match WaitForSingleObject(self.handle, INFINITE) {
2323
WAIT_ABANDONED | WAIT_OBJECT_0 => {}

windivert/src/divert/mod.rs

Lines changed: 22 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::{
44
marker::PhantomData,
55
mem::MaybeUninit,
66
path::Path,
7+
sync::{Arc, Weak},
78
};
89

910
use windows::{
@@ -44,11 +45,10 @@ mod socket;
4445
#[non_exhaustive]
4546
#[derive(Debug)]
4647
pub struct WinDivert<L: layer::WinDivertLayerTrait> {
47-
handle: HANDLE,
48+
handle: Arc<HANDLE>,
4849
tls_index: TlsIndex,
4950
core: SysWrapper,
5051
_layer: PhantomData<L>,
51-
_is_closed: bool,
5252
}
5353

5454
unsafe impl<L: layer::WinDivertLayerTrait> Send for WinDivert<L> {}
@@ -74,11 +74,10 @@ impl<L: layer::WinDivertLayerTrait> WinDivert<L> {
7474
Err(open_err.into())
7575
} else {
7676
Ok(Self {
77-
handle,
77+
handle: Arc::new(handle),
7878
tls_index: windivert_tls_idx,
7979
core: sys_wrapper,
8080
_layer: PhantomData::<L>,
81-
_is_closed: false,
8281
})
8382
}
8483
}
@@ -348,24 +347,14 @@ impl<L: layer::WinDivertLayerTrait> WinDivert<L> {
348347
}
349348

350349
/// Handle close function.
351-
pub fn close(mut self, action: CloseAction) -> Result<(), WinDivertError> {
352-
self.inner_close(action)
353-
}
350+
pub fn close(self) {}
354351

355-
/// Handle close function (internally, non-consuming).
356-
fn inner_close(&mut self, action: CloseAction) -> Result<(), WinDivertError> {
357-
self._is_closed = true;
358-
unsafe { BOOL(WinDivertClose(self.handle.0)) }.ok()?;
359-
match action {
360-
CloseAction::Uninstall => WinDivert::uninstall(),
361-
CloseAction::Nothing => Ok(()),
352+
/// Returns a new ShutdownHandle that can be used to remotely shutdown WinDivert handle
353+
pub fn shutdown_handle(&self) -> ShutdownHandle {
354+
ShutdownHandle {
355+
handle: Arc::downgrade(&self.handle),
362356
}
363357
}
364-
365-
/// Shutdown function.
366-
pub fn shutdown(&self, mode: WinDivertShutdownMode) -> Result<(), WinDivertError> {
367-
Ok(unsafe { BOOL(WinDivertShutdown(self.handle.0, mode)) }.ok()?)
368-
}
369358
}
370359

371360
/// Utility methods for WinDivert.
@@ -404,27 +393,24 @@ impl WinDivert<()> {
404393

405394
impl<L: layer::WinDivertLayerTrait> Drop for WinDivert<L> {
406395
fn drop(&mut self) {
407-
if !self._is_closed {
408-
// SAFETY: Internal close should only fail if:
409-
// * Handle is closed: Checked
410-
// * Handle is invalid: Impossible with current API
411-
// * Permission issues: Impossible due to admin required for open
412-
// It's safe to ignore the return value
413-
let _ = self.inner_close(CloseAction::Nothing);
414-
}
396+
let _ = unsafe { BOOL(WinDivertClose(self.handle.0)) }.ok();
415397
}
416398
}
417399

418-
/// Action parameter for [`WinDivert::close()`](`fn@WinDivert::close`)
419-
pub enum CloseAction {
420-
/// Close the handle and try to uninstall the WinDivert driver.
421-
Uninstall,
422-
/// Close the handle without uninstalling the driver.
423-
Nothing,
400+
/// Struct that allows remote signaling the shutdown on the associated handle
401+
#[derive(Debug)]
402+
pub struct ShutdownHandle {
403+
handle: Weak<HANDLE>,
424404
}
425405

426-
impl Default for CloseAction {
427-
fn default() -> Self {
428-
Self::Nothing
406+
impl ShutdownHandle {
407+
/// Shuts down the associated handle
408+
/// This will prevent any further send, as well as stopping ongoing recv
409+
/// Ongoing recv operations might not end immediately, all queued events/packets from the driver to the associated handle must be exhausted before reaching the `WinDivertRecvError::NoData` result.
410+
pub fn shutdown(&self) -> Result<(), WinDivertError> {
411+
if let Some(handle) = self.handle.upgrade() {
412+
unsafe { BOOL(WinDivertShutdown(handle.0, WinDivertShutdownMode::Both)) }.ok()?;
413+
};
414+
Ok(())
429415
}
430416
}

windivert/src/divert/network.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,10 @@ mod tests {
125125

126126
fn setup_divert(sys_wrapper: SysWrapper) -> WinDivert<NetworkLayer> {
127127
WinDivert {
128-
handle: HANDLE(1usize as *mut c_void),
128+
handle: Arc::new(HANDLE(1usize as *mut c_void)),
129129
tls_index: TlsIndex::alloc_tls().unwrap(),
130130
core: sys_wrapper,
131131
_layer: PhantomData::<NetworkLayer>,
132-
_is_closed: false
133132
}
134133
}
135134

windivert/src/lib.rs

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,43 +16,38 @@ let Ok(divert) = WinDivert::network("ip and tcp.DstPort == 443", 0, Default::def
1616
panic!("Failed to create WinDivert");
1717
};
1818
19-
let divert = std::sync::Arc::new(divert);
19+
let shutdown_handle = divert.shutdown_handle();
2020
21-
let divert_shared = divert.clone();
2221
let handle = std::thread::spawn(move || {
2322
// Do something in the background
2423
let mut buffer = [0u8; 1500];
2524
2625
loop {
27-
match divert_shared.recv(&mut buffer) {
26+
match divert.recv(&mut buffer) {
2827
Ok(packet) => {
2928
// In capture mode the packet is captured and not calling `send()` with it will prevent it from reaching the destination.
30-
divert_shared.send(&packet).expect("Failed to send packet");
31-
},
29+
divert.send(&packet).expect("Failed to send packet");
30+
}
3231
Err(WinDivertError::Recv(WinDivertRecvError::NoData)) => {
3332
// Handle was shutdown, and there is no more pending data to receive
3433
break;
35-
},
34+
}
3635
Err(e) => {
3736
// Other errors
3837
eprintln!("Error receiving packet: {}", e);
3938
}
4039
}
4140
}
41+
// The handle is implicitly closed once `divert` is dropped
4242
});
4343
4444
std::thread::sleep(std::time::Duration::from_secs(10));
4545
46-
divert
47-
.shutdown(WinDivertShutdownMode::Both)
46+
shutdown_handle
47+
.shutdown()
4848
.expect("Failed to shutdown WinDivert");
4949
5050
handle.join().unwrap();
51-
52-
std::sync::Arc::try_unwrap(divert)
53-
.expect("Thread already finished, no references remaining")
54-
.close(CloseAction::Nothing)
55-
.expect("Failed to close WinDivert");
5651
```
5752
*/
5853

windivert/src/utils.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::borrow::Cow;
22

3-
pub(crate) fn prepare_internet_slice_data(slice: &[u8]) -> (&[u8], Cow<[u8]>) {
3+
pub(crate) fn prepare_internet_slice_data(slice: &[u8]) -> (&[u8], Cow<'_, [u8]>) {
44
let headers = etherparse::SlicedPacket::from_ip(slice)
55
.expect("WinDivert can't capture anything below ip");
66
let offset = match headers.net.unwrap() {

0 commit comments

Comments
 (0)