Skip to content

Commit d5fcbee

Browse files
authored
Add timeout for distributed tests. (dmlc#10315)
1 parent b8a7773 commit d5fcbee

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

src/collective/coll.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const>
191191
for (std::int32_t r = 0; r < comm->World(); ++r) {
192192
auto as_bytes = sizes[r];
193193
auto rc = stub->Broadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes,
194-
ncclInt8, r, comm->Handle(), dh::DefaultStream());
194+
ncclInt8, r, comm->Handle(), comm->Stream());
195195
if (!rc.OK()) {
196196
return rc;
197197
}

tests/cpp/collective/test_worker.h

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,8 @@ inline auto MakeDistributedTestConfig(std::string host, std::int32_t port,
147147
}
148148

149149
template <typename WorkerFn>
150-
void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need_finalize = true) {
150+
void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need_finalize = true,
151+
std::chrono::seconds test_timeout = std::chrono::seconds{30}) {
151152
system::SocketStartup();
152153
std::chrono::seconds timeout{1};
153154

@@ -163,12 +164,17 @@ void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need
163164

164165
for (std::int32_t i = 0; i < n_workers; ++i) {
165166
workers.emplace_back([=] {
166-
auto config = MakeDistributedTestConfig(host, port, timeout, i);
167-
Init(config);
168-
worker_fn();
169-
if (need_finalize) {
170-
Finalize();
171-
}
167+
auto fut = std::async(std::launch::async, [=] {
168+
auto config = MakeDistributedTestConfig(host, port, timeout, i);
169+
Init(config);
170+
worker_fn();
171+
if (need_finalize) {
172+
Finalize();
173+
}
174+
});
175+
auto status = fut.wait_for(test_timeout);
176+
CHECK(status == std::future_status::ready) << "Test timeout";
177+
fut.get();
172178
});
173179
}
174180

0 commit comments

Comments
 (0)