203 changes: 51 additions & 152 deletions tiledb/common/thread_pool/test/unit_thread_pool.cc

Large diffs are not rendered by default.

31 changes: 0 additions & 31 deletions tiledb/common/thread_pool/test/unit_thread_pool.h

This file was deleted.

74 changes: 10 additions & 64 deletions tiledb/common/thread_pool/thread_pool.cc
@@ -121,30 +121,18 @@ void ThreadPool::shutdown() {
threads_.clear();
}

Status ThreadPool::wait_all(std::vector<Task>& tasks) {
auto statuses = wait_all_status(tasks);
for (auto& st : statuses) {
if (!st.ok()) {
return st;
}
}
return Status::Ok();
}

// Return a vector of Status. If any task returns an error value or throws an
// exception, we save an error code in the corresponding location in the Status
// vector. All tasks are waited on before return. Multiple error statuses may
// be saved. We may call logger here because thread pool will not be used until
// context is fully constructed (which will include logger).
// Unfortunately, C++ does not have the notion of an aggregate exception, so we
// don't throw in the case of errors/exceptions.
std::vector<Status> ThreadPool::wait_all_status(std::vector<Task>& tasks) {
std::vector<Status> statuses(tasks.size());

void ThreadPool::wait_all(std::vector<Task>& tasks) {
std::queue<size_t> pending_tasks;

// Create queue of ids of all the pending tasks for processing
for (size_t i = 0; i < statuses.size(); ++i) {
for (size_t i = 0; i < tasks.size(); ++i) {
pending_tasks.push(i);
}

@@ -155,33 +143,12 @@ std::vector<Status> ThreadPool::wait_all_status(std::vector<Task>& tasks) {
auto& task = tasks[task_id];

if (!task.valid()) {
statuses[task_id] = Status_ThreadPoolError("Invalid task future");
LOG_STATUS_NO_RETURN_VALUE(statuses[task_id]);
throw TaskException("Invalid task future");
} else if (
task.wait_for(std::chrono::milliseconds(0)) ==
std::future_status::ready) {
// Task is completed, get result, handling possible exceptions

Status st = [&task] {
try {
return task.get();
} catch (const std::exception& e) {
return Status_TaskError(
"Caught std::exception: " + std::string(e.what()));
} catch (const std::string& msg) {
return Status_TaskError("Caught msg: " + msg);
} catch (const Status& stat) {
return stat;
} catch (...) {
return Status_TaskError("Unknown exception");
}
}();

if (!st.ok()) {
LOG_STATUS_NO_RETURN_VALUE(st);
}
statuses[task_id] = st;

// Task is completed, throw possible exception
task.get();
} else {
// If the task is not completed, try again later
pending_tasks.push(task_id);
@@ -201,39 +168,18 @@ std::vector<Status> ThreadPool::wait_all_status(std::vector<Task>& tasks) {
}
}
}

return statuses;
}

Status ThreadPool::wait(Task& task) {
void ThreadPool::wait(Task& task) {
while (true) {
if (!task.valid()) {
return Status_ThreadPoolError("Invalid task future");
throw TaskException("Invalid task future");
} else if (
task.wait_for(std::chrono::milliseconds(0)) ==
std::future_status::ready) {
// Task is completed, get result, handling possible exceptions

Status st = [&task] {
try {
return task.get();
} catch (const std::exception& e) {
return Status_TaskError(
"Caught std::exception: " + std::string(e.what()));
} catch (const std::string& msg) {
return Status_TaskError("Caught msg: " + msg);
} catch (const Status& stat) {
return stat;
} catch (...) {
return Status_TaskError("Unknown exception");
}
}();

if (!st.ok()) {
LOG_STATUS_NO_RETURN_VALUE(st);
}

return st;
// Task is completed, throw possible exception
task.get();
return;
} else {
// In the meantime, try to do something useful to make progress (and avoid
// deadlock)
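For reference, a minimal caller-side sketch of the new contract (the include path, run_batch, and the simulated failure are illustrative assumptions, not part of this change): tasks return void, and wait_all() surfaces a failure as an exception instead of a Status vector.

#include <cstdio>
#include <stdexcept>
#include <vector>

#include "tiledb/common/thread_pool/thread_pool.h"

using tiledb::common::ThreadPool;

void run_batch(ThreadPool& tp) {
  std::vector<ThreadPool::Task> tasks;
  for (int i = 0; i < 4; ++i) {
    // Tasks are void-returning; a failure is reported by throwing.
    tasks.emplace_back(tp.execute([i]() {
      if (i == 2) {
        throw std::runtime_error("simulated failure");
      }
    }));
  }
  try {
    tp.wait_all(tasks);  // rethrows the first task exception it encounters
  } catch (const std::exception& e) {
    // One exception path replaces per-task Status inspection.
    std::fprintf(stderr, "batch failed: %s\n", e.what());
    throw;
  }
}

Unlike the removed wait_all_status(), which collected a Status per task, this call reports only the first failure it reaches; later failures in the same batch are not surfaced by it.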
65 changes: 27 additions & 38 deletions tiledb/common/thread_pool/thread_pool.h
@@ -35,19 +35,27 @@

#include "producer_consumer_queue.h"

#include <concepts>
#include <functional>
#include <future>

#include "tiledb/common/common.h"
#include "tiledb/common/logger_public.h"
#include "tiledb/common/macros.h"
#include "tiledb/common/status.h"

namespace tiledb::common {

/** Class for Task status exceptions. */
class TaskException : public StatusException {
public:
explicit TaskException(const std::string& msg)
: StatusException("Task", msg) {
}
};

class ThreadPool {
public:
using Task = std::future<Status>;
using Task = std::future<void>;

/* ********************************* */
/* CONSTRUCTORS & DESTRUCTORS */
@@ -92,26 +100,19 @@ class ThreadPool {
*/

template <class Fn, class... Args>
auto async(Fn&& f, Args&&... args) {
Task async(Fn&& f, Args&&... args)
requires std::same_as<std::invoke_result_t<Fn, std::decay_t<Args>...>, void>
{
if (concurrency_level_ == 0) {
Task invalid_future;
LOG_ERROR("Cannot execute task; thread pool uninitialized.");
return invalid_future;
}

using R = std::invoke_result_t<std::decay_t<Fn>, std::decay_t<Args>...>;

auto task = make_shared<std::packaged_task<R()>>(
HERE(),
[f = std::forward<Fn>(f),
args = std::make_tuple(std::forward<Args>(args)...)]() mutable {
return std::apply(std::move(f), std::move(args));
});

std::future<R> future = task->get_future();

auto task = make_shared<std::packaged_task<void()>>(
HERE(), std::bind(std::forward<Fn>(f), std::forward<Args>(args)...));
auto future = task->get_future();
task_queue_.push(task);

return future;
}

@@ -123,7 +124,9 @@
* @return std::future referring to the shared state created by this call
*/
template <class Fn, class... Args>
auto execute(Fn&& f, Args&&... args) {
Task execute(Fn&& f, Args&&... args)
requires std::same_as<std::invoke_result_t<Fn, std::decay_t<Args>...>, void>
{
return async(std::forward<Fn>(f), std::forward<Args>(args)...);
}

@@ -133,36 +136,22 @@
* waiting.
*
* @param tasks Task list to wait on.
* @return Status::Ok if all tasks returned Status::Ok, otherwise the first
* error status is returned
*/
Status wait_all(std::vector<Task>& tasks);

/**
* Wait on all the given tasks to complete, returning a vector of their return
* Status. Exceptions caught while waiting are returned as Status_TaskError.
* Status are saved at the same index in the return vector as the
* corresponding task in the input vector. The status vector may contain more
* than one error Status.
*
* This function is safe to call recursively and may execute pending tasks
* with the calling thread while waiting.
*
* @param tasks Task list to wait on
* @return Vector of each task's Status.
* @throws This function will throw the first exception thrown by one of the
* tasks.
*/
std::vector<Status> wait_all_status(std::vector<Task>& tasks);
void wait_all(std::vector<Task>& tasks);

/**
* Wait on a single task to complete. This function is safe to call
* recursively and may execute pending tasks on the calling thread while
* waiting.
*
* @param task Task to wait on.
* @return Status::Ok if the task returned Status::Ok, otherwise the error
* status is returned
*
* @throws This function will throw the exception thrown by the task.
*/
Status wait(Task& task);
void wait(Task& task);

/* ********************************* */
/* PRIVATE ATTRIBUTES */
@@ -177,8 +166,8 @@

/** Producer-consumer queue where functions to be executed are kept */
ProducerConsumerQueue<
shared_ptr<std::packaged_task<Status()>>,
std::deque<shared_ptr<std::packaged_task<Status()>>>>
shared_ptr<std::packaged_task<void()>>,
std::deque<shared_ptr<std::packaged_task<void()>>>>
task_queue_;

/** The worker threads */
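As a migration sketch for a call site whose callable previously returned Status (do_work and submit_one are hypothetical names used only for illustration): the requires clause above rejects non-void callables at compile time, so errors move from return values to thrown exceptions.

#include "tiledb/common/thread_pool/thread_pool.h"

using tiledb::common::TaskException;
using tiledb::common::ThreadPool;

bool do_work();  // hypothetical helper, declared only for the example

void submit_one(ThreadPool& tp) {
  // A lambda returning Status no longer satisfies the constraint on
  // async()/execute(); the body must return void and throw on failure.
  ThreadPool::Task t = tp.execute([]() {
    if (!do_work()) {
      throw TaskException("do_work failed");
    }
  });
  tp.wait(t);  // now returns void and rethrows the task's exception, if any
}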
44 changes: 18 additions & 26 deletions tiledb/sm/array/array.cc
@@ -679,16 +679,14 @@ void Array::delete_fragments(

// Delete fragments and commits
auto vfs = &(resources.vfs());
throw_if_not_ok(parallel_for(
&resources.compute_tp(), 0, fragment_uris.size(), [&](size_t i) {
throw_if_not_ok(vfs->remove_dir(fragment_uris[i].uri_));
bool is_file = false;
throw_if_not_ok(vfs->is_file(commit_uris_to_delete[i], &is_file));
if (is_file) {
throw_if_not_ok(vfs->remove_file(commit_uris_to_delete[i]));
}
return Status::Ok();
}));
parallel_for(&resources.compute_tp(), 0, fragment_uris.size(), [&](size_t i) {
throw_if_not_ok(vfs->remove_dir(fragment_uris[i].uri_));
bool is_file = false;
throw_if_not_ok(vfs->is_file(commit_uris_to_delete[i], &is_file));
if (is_file) {
throw_if_not_ok(vfs->remove_file(commit_uris_to_delete[i]));
}
});
}

void Array::delete_fragments(
@@ -1711,7 +1709,7 @@ std::unordered_map<std::string, uint64_t> Array::get_average_var_cell_sizes()

// Load all metadata for tile var sizes among fragments.
for (const auto& var_name : var_names) {
throw_if_not_ok(parallel_for(
parallel_for(
&resources_.compute_tp(),
0,
fragment_metadata.size(),
Expand All @@ -1720,17 +1718,16 @@ std::unordered_map<std::string, uint64_t> Array::get_average_var_cell_sizes()
// evolution that do not exists in this fragment.
const auto& schema = fragment_metadata[f]->array_schema();
if (!schema->is_field(var_name)) {
return Status::Ok();
return;
}

fragment_metadata[f]->loaded_metadata()->load_tile_var_sizes(
*encryption_key(), var_name);
return Status::Ok();
}));
});
}

// Now compute for each var size names, the average cell size.
throw_if_not_ok(parallel_for(
parallel_for(
&resources_.compute_tp(), 0, var_names.size(), [&](const uint64_t n) {
uint64_t total_size = 0;
uint64_t cell_num = 0;
Expand All @@ -1756,9 +1753,7 @@ std::unordered_map<std::string, uint64_t> Array::get_average_var_cell_sizes()

uint64_t average_cell_size = total_size / cell_num;
ret[var_name] = std::max<uint64_t>(average_cell_size, 1);

return Status::Ok();
}));
});

return ret;
}
@@ -1988,15 +1983,12 @@ void Array::do_load_metadata() {

auto metadata_num = array_metadata_to_load.size();
std::vector<shared_ptr<Tile>> metadata_tiles(metadata_num);
throw_if_not_ok(
parallel_for(&resources_.compute_tp(), 0, metadata_num, [&](size_t m) {
const auto& uri = array_metadata_to_load[m].uri_;

metadata_tiles[m] = GenericTileIO::load(
resources_, uri, 0, *encryption_key(), memory_tracker_);
parallel_for(&resources_.compute_tp(), 0, metadata_num, [&](size_t m) {
const auto& uri = array_metadata_to_load[m].uri_;

return Status::Ok();
}));
metadata_tiles[m] = GenericTileIO::load(
resources_, uri, 0, *encryption_key(), memory_tracker_);
});

// Compute array metadata size for the statistics
uint64_t meta_size = 0;
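The array.cc updates above all follow the same shape. Expressed directly against the ThreadPool API (process_uri and process_all are illustrative, and parallel_for's own declaration is outside this diff, so its exact new signature is not assumed here), the pattern is: the body returns void, so the trailing return Status::Ok() and the outer throw_if_not_ok wrapper both disappear.

#include <cstddef>
#include <vector>

#include "tiledb/common/thread_pool/thread_pool.h"

using tiledb::common::ThreadPool;

// Hypothetical stand-in for the per-fragment work done at the call sites above.
void process_uri(size_t i);

void process_all(ThreadPool& tp, size_t n) {
  std::vector<ThreadPool::Task> tasks;
  tasks.reserve(n);
  for (size_t i = 0; i < n; ++i) {
    // No trailing return Status::Ok() in the body anymore.
    tasks.emplace_back(tp.execute([i]() { process_uri(i); }));
  }
  tp.wait_all(tasks);  // propagates the first failure as an exception
}

Inside the bodies, calls that still return Status (for example the VFS operations) are adapted with throw_if_not_ok, exactly as in the hunks above.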