|
5 | 5 | #include <mp/test/foo.capnp.h> |
6 | 6 | #include <mp/test/foo.capnp.proxy.h> |
7 | 7 |
|
| 8 | +#include <atomic> |
8 | 9 | #include <capnp/capability.h> |
9 | 10 | #include <capnp/rpc.h> |
| 11 | +#include <condition_variable> |
10 | 12 | #include <cstring> |
| 13 | +#include <exception> |
11 | 14 | #include <functional> |
12 | 15 | #include <future> |
13 | 16 | #include <iostream> |
14 | 17 | #include <kj/async.h> |
15 | 18 | #include <kj/async-io.h> |
16 | 19 | #include <kj/common.h> |
17 | 20 | #include <kj/debug.h> |
| 21 | +#include <kj/exception.h> |
18 | 22 | #include <kj/memory.h> |
| 23 | +#include <kj/string.h> |
19 | 24 | #include <kj/test.h> |
20 | 25 | #include <memory> |
21 | 26 | #include <mp/proxy.h> |
| 27 | +#include "mp/proxy.capnp.h" |
22 | 28 | #include <mp/proxy-io.h> |
| 29 | +#include "mp/util.h" |
23 | 30 | #include <optional> |
24 | 31 | #include <set> |
25 | 32 | #include <stdexcept> |
@@ -297,5 +304,71 @@ KJ_TEST("Calling IPC method, disconnecting and blocking during the call") |
297 | 304 | signal.set_value(); |
298 | 305 | } |
299 | 306 |
|
// Regression test: issue several IPC calls at once that all share the SAME
// request_thread and callback_thread. Because a server thread can only run
// one call at a time, the extra simultaneous calls are expected to fail with
// a "thread busy" remote exception, which this test asserts is observed.
KJ_TEST("Make simultaneous IPC callbacks with same request_thread and callback_thread")
{
    TestSetup setup;
    ProxyClient<messages::FooInterface>* foo = setup.client.get();
    // Handshake object: the server-side callback blocks on this future until
    // the client-side error handler below calls signal.set_value().
    std::promise<void> signal;

    foo->initThreadMap();
    // Use callFnAsync() to get the client to setup the request_thread
    // that will be used for the test.
    setup.server->m_impl->m_fn = [&] {};
    foo->callFnAsync();
    ThreadContext& tc{g_thread_context};
    // Capture the Thread capabilities created by the warm-up call above, so
    // every request sent later can be pinned to the same pair of threads.
    std::optional<Thread::Client> callback_thread, request_thread;
    {
        // Lock guards the thread maps, which are shared with the event loop.
        Lock lock(tc.waiter->m_mutex);
        callback_thread = tc.callback_threads.at(foo->m_context.connection)->m_client;
        request_thread = tc.request_threads.at(foo->m_context.connection)->m_client;
    }

    // Server-side body for the real test calls: block until signaled. If the
    // callback runs more than once, later get_future() calls throw
    // "Future already retrieved" (std::promise allows only one future), which
    // is tolerated and checked here rather than blocking again.
    setup.server->m_impl->m_fn = [&] {
        try
        {
            signal.get_future().get();
        }
        catch(const std::exception& e)
        {
            KJ_EXPECT(e.what() == std::string("Future already retrieved"));
        }
    };

    auto client{foo->m_client};
    bool caught_thread_busy = false;
    // NOTE: '3' was chosen because it was the lowest number
    // of simultaneous calls required to reliably catch a "thread busy" error
    std::atomic<size_t> running{3};
    // All requests are created and sent from inside the event loop (sync())
    // so they are queued back-to-back before any response can arrive.
    foo->m_context.loop->sync([&]
    {
        for (size_t i = 0; i < running; i++)
        {
            auto request{client.callFnAsyncRequest()};
            auto context{request.initContext()};
            // Force every call onto the same server threads — this collision
            // is what provokes the "thread busy" error being tested.
            context.setCallbackThread(*callback_thread);
            context.setThread(*request_thread);
            foo->m_context.loop->m_task_set->add(request.send().then(
                [&](auto&& results) {
                    // Success path: one call completed; wake the main thread
                    // waiting below so it can re-check the running count.
                    running -= 1;
                    tc.waiter->m_cv.notify_all();
                },
                [&](kj::Exception&& e) {
                    // Error path: verify the expected remote error, then
                    // unblock the server callback so the in-progress call can
                    // finish. NOTE(review): signal.set_value() would throw
                    // std::future_error if this error branch ran twice —
                    // presumably only one call fails this way; confirm.
                    KJ_EXPECT(std::string_view{e.getDescription().cStr()} ==
                        "remote exception: std::exception: thread busy");
                    caught_thread_busy = true;
                    running -= 1;
                    signal.set_value();
                    tc.waiter->m_cv.notify_all();
                }
            ));
        }
    });
    {
        // Block the test thread until every call has completed or failed.
        Lock lock(tc.waiter->m_mutex);
        tc.waiter->wait(lock, [&running] { return running == 0; });
    }
    // The test only passes if at least one call actually hit "thread busy".
    KJ_EXPECT(caught_thread_busy);
}
| 372 | + |
300 | 373 | } // namespace test |
301 | 374 | } // namespace mp |
0 commit comments