Skip to content

Commit 218d2b6

Browse files
Fix unix socket cleanup.
1 parent 5c0c1cd commit 218d2b6

File tree

2 files changed

+68
-22
lines changed

2 files changed

+68
-22
lines changed

ext/hyper_ruby/src/lib.rs

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ struct Server {
5151
config: RefCell<ServerConfig>,
5252
work_rx: RefCell<Option<crossbeam_channel::Receiver<RequestWithCompletion>>>,
5353
work_tx: RefCell<Option<Arc<crossbeam_channel::Sender<RequestWithCompletion>>>>,
54+
runtime: RefCell<Option<Arc<tokio::runtime::Runtime>>>,
5455
}
5556

5657
impl Server {
@@ -62,6 +63,7 @@ impl Server {
6263
config: RefCell::new(ServerConfig::new()),
6364
work_rx: RefCell::new(Some(work_rx)),
6465
work_tx: RefCell::new(Some(Arc::new(work_tx))),
66+
runtime: RefCell::new(None),
6567
}
6668
}
6769

@@ -143,10 +145,13 @@ impl Server {
143145
.ok_or_else(|| MagnusError::new(magnus::exception::runtime_error(), "Work channel not initialized"))?
144146
.clone();
145147

146-
let rt = tokio::runtime::Builder::new_multi_thread()
148+
let rt = Arc::new(tokio::runtime::Builder::new_multi_thread()
147149
.enable_all()
148150
.build()
149-
.map_err(|e| MagnusError::new(magnus::exception::runtime_error(), e.to_string()))?;
151+
.map_err(|e| MagnusError::new(magnus::exception::runtime_error(), e.to_string()))?);
152+
153+
// Store the runtime
154+
*self.runtime.borrow_mut() = Some(rt.clone());
150155

151156
rt.block_on(async {
152157
let work_tx = work_tx.clone();
@@ -167,7 +172,7 @@ impl Server {
167172
let incoming = UnixListenerStream::new(listener);
168173
warp::serve(any_route)
169174
.run_incoming(incoming)
170-
.await
175+
.await;
171176
} else {
172177
let addr: SocketAddr = config.bind_address.parse()
173178
.expect("invalid address format");
@@ -183,31 +188,31 @@ impl Server {
183188
Ok::<(), MagnusError>(())
184189
})?;
185190

186-
// Keep the runtime alive
187-
std::thread::spawn(move || {
188-
rt.block_on(async {
189-
loop {
190-
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
191-
}
192-
});
193-
});
194-
195191
Ok(())
196192
}
197193

198194
pub fn stop(&self) -> Result<(), MagnusError> {
199-
let rt = tokio::runtime::Runtime::new()
200-
.map_err(|e| MagnusError::new(magnus::exception::runtime_error(), e.to_string()))?;
201-
202-
rt.block_on(async {
203-
let mut handle = self.server_handle.lock().await;
204-
if let Some(task) = handle.take() {
205-
task.abort();
206-
}
207-
});
195+
// Use the stored runtime instead of creating a new one
196+
if let Some(rt) = self.runtime.borrow().as_ref() {
197+
rt.block_on(async {
198+
let mut handle = self.server_handle.lock().await;
199+
if let Some(task) = handle.take() {
200+
task.abort();
201+
}
202+
});
203+
}
208204

209-
// Drop the channel to signal workers to shut down
205+
// Drop the channel and runtime
210206
self.work_tx.borrow_mut().take();
207+
self.runtime.borrow_mut().take();
208+
209+
let bind_address = self.config.borrow().bind_address.clone();
210+
if bind_address.starts_with("unix:") {
211+
let path = bind_address.trim_start_matches("unix:");
212+
std::fs::remove_file(path).unwrap_or_else(|e| {
213+
println!("Failed to remove socket file: {:?}", e);
214+
});
215+
}
211216

212217
Ok(())
213218
}

test/test_hyper_ruby.rb

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,22 @@ def test_large_post
3535
end
3636
end
3737

38+
def test_unix_socket_cleans_up_socket
39+
with_unix_socket_server(-> (request) { handler_simple(request) }) do |client|
40+
response = client.get("/")
41+
assert_equal 200, response.status
42+
assert_equal "text/plain", response.headers["content-type"]
43+
assert_equal 'GET', response.body
44+
end
45+
46+
with_unix_socket_server(-> (request) { handler_simple(request) }) do |client|
47+
response = client.get("/")
48+
assert_equal 200, response.status
49+
assert_equal "text/plain", response.headers["content-type"]
50+
assert_equal 'GET', response.body
51+
end
52+
end
53+
3854
# def test_blocking
3955
# with_server(-> (request) { handler_simple(request) }) do |client|
4056
# gets
@@ -65,6 +81,31 @@ def with_server(request_handler, &block)
6581
worker.join if worker
6682
end
6783

84+
def with_unix_socket_server(request_handler, &block)
85+
server = HyperRuby::Server.new
86+
server.configure({ bind_address: "unix:/tmp/hyper_ruby_test.sock" })
87+
server.start
88+
89+
# Create ruby worker threads that process requests;
90+
# 1 is usually enough, and generally handles better than multiple threads
91+
# if there's no IO (because of the GIL)
92+
worker = Thread.new do
93+
server.run_worker do |request|
94+
# Process the request in Ruby
95+
# request is a hash with :method, :path, :headers, and :body keys
96+
request_handler.call(request)
97+
end
98+
end
99+
100+
client = HTTPX.with(transport: "unix", addresses: ["/tmp/hyper_ruby_test.sock"], origin: "http://host")
101+
102+
block.call(client)
103+
104+
ensure
105+
server.stop if server
106+
worker.join if worker
107+
end
108+
68109
def handler_simple(request)
69110
HyperRuby::Response.new(200, { 'Content-Type' => 'text/plain' }, request.http_method)
70111
end

0 commit comments

Comments
 (0)