Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 59 additions & 21 deletions libs/core/include/cuda-qx/core/extension_point.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/****************************************************************-*- C++ -*-****
* Copyright (c) 2024 NVIDIA Corporation & Affiliates. *
* Copyright (c) 2024 - 2025 NVIDIA Corporation & Affiliates. *
* All rights reserved. *
* *
* This source code and the accompanying materials are made available under *
Expand All @@ -10,6 +10,7 @@

#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <unordered_map>

Expand Down Expand Up @@ -88,7 +89,9 @@ class extension_point {
/// @return A reference to the static registry map.
/// See INSTANTIATE_REGISTRY() macros below for sample implementations that
/// need to be included in C++ source files.
static std::unordered_map<std::string, CreatorFunction> &get_registry();
static std::pair<std::recursive_mutex &,
std::unordered_map<std::string, CreatorFunction> &>
get_registry();

public:
/// @brief Create an instance of a registered extension.
Expand All @@ -97,7 +100,8 @@ class extension_point {
/// @return A unique pointer to the created instance.
/// @throws std::runtime_error if the extension is not found.
static std::unique_ptr<T> get(const std::string &name, CtorArgs... args) {
auto &registry = get_registry();
auto [mutex, registry] = get_registry();
std::lock_guard<std::recursive_mutex> lock(mutex);
auto iter = registry.find(name);
if (iter == registry.end())
throw std::runtime_error("Cannot find extension with name = " + name);
Expand All @@ -109,7 +113,8 @@ class extension_point {
/// @return A vector of registered extension names.
static std::vector<std::string> get_registered() {
std::vector<std::string> names;
auto &registry = get_registry();
auto [mutex, registry] = get_registry();
std::lock_guard<std::recursive_mutex> lock(mutex);
for (auto &[k, v] : registry)
names.push_back(k);
return names;
Expand All @@ -119,17 +124,29 @@ class extension_point {
/// @param name The identifier of the extension to check.
/// @return True if the extension is registered, false otherwise.
static bool is_registered(const std::string &name) {
auto &registry = get_registry();
auto [mutex, registry] = get_registry();
std::lock_guard<std::recursive_mutex> lock(mutex);
return registry.find(name) != registry.end();
}

/// @brief Unregister an extension.
/// @param name The identifier of the extension to unregister.
static void unregister(const std::string &name) {
auto [mutex, registry] = get_registry();
std::lock_guard<std::recursive_mutex> lock(mutex);
auto iter = registry.find(name);
if (iter != registry.end())
registry.erase(iter);
}
};

/// @brief Macro for defining a creator function for an extension.
/// @param BASE The base class of the extension.
/// @param TYPE The derived class implementing the extension.
#define CUDAQ_EXTENSION_CREATOR_FUNCTION(BASE, TYPE) \
static inline bool register_type() { \
auto &registry = get_registry(); \
auto [mutex, registry] = get_registry(); \
std::lock_guard<std::recursive_mutex> lock(mutex); \
registry[TYPE::class_identifier] = TYPE::create; \
return true; \
} \
Expand All @@ -142,7 +159,8 @@ class extension_point {
/// @param ... Custom implementation of the create function.
#define CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION(TYPE, ...) \
static inline bool register_type() { \
auto &registry = get_registry(); \
auto [mutex, registry] = get_registry(); \
std::lock_guard<std::recursive_mutex> lock(mutex); \
registry[TYPE::class_identifier] = TYPE::create; \
return true; \
} \
Expand All @@ -152,7 +170,8 @@ class extension_point {

#define CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION_WITH_NAME(TYPE, NAME, ...) \
static inline bool register_type() { \
auto &registry = TYPE::get_registry(); \
auto [mutex, registry] = TYPE::get_registry(); \
std::lock_guard<std::recursive_mutex> lock(mutex); \
registry.insert({NAME, TYPE::create}); \
return true; \
} \
Expand All @@ -163,7 +182,13 @@ class extension_point {
/// @brief Macro for registering an extension type.
/// @param TYPE The class to be registered as an extension.
#define CUDAQ_REGISTER_TYPE(TYPE) \
const bool TYPE::registered_ = TYPE::register_type();
const bool TYPE::registered_ = TYPE::register_type(); \
/* We must ALSO provide a destructor to clean up the registry so that when a \
* dlcose happens, the parent registry no longer holds references to code \
* that has been unloaded. */ \
__attribute__((destructor)) void cudaq_extension_point_cleanup_##TYPE() { \
TYPE::unregister(TYPE::class_identifier); \
}

/// In order to support building CUDA-QX libraries with g++ and building
/// application code with nvq++ (which uses clang++ under the hood), you must
Expand All @@ -173,30 +198,43 @@ class extension_point {
///
/// Use this version of the helper macro if the only template argument to
/// extension_point<> is the derived class (with no additional creator args).
///
/// Similar to cudaq::qec::get_plugin_handles(), we must create the static mutex
/// and registry as pointers to avoid issues with destructor ordering when
/// performing cleanup operations (like when one dlclose's a library). This
/// creates a small memory leak, but it prevents bigger problems.
#define INSTANTIATE_REGISTRY_NO_ARGS(FULL_TYPE_NAME) \
template <> \
std::unordered_map<std::string, \
std::function<std::unique_ptr<FULL_TYPE_NAME>()>> & \
std::pair< \
std::recursive_mutex &, \
std::unordered_map<std::string, \
std::function<std::unique_ptr<FULL_TYPE_NAME>()>> &> \
cudaqx::extension_point<FULL_TYPE_NAME>::get_registry() { \
static std::recursive_mutex *mutex = new std::recursive_mutex(); \
static std::unordered_map< \
std::string, std::function<std::unique_ptr<FULL_TYPE_NAME>()>> \
registry; \
return registry; \
*registry = new std::unordered_map< \
std::string, std::function<std::unique_ptr<FULL_TYPE_NAME>()>>(); \
return {*mutex, *registry}; \
}

/// Use this variadic version of the helper macro if there are additional
/// arguments for the creator function.
#define INSTANTIATE_REGISTRY(FULL_TYPE_NAME, ...) \
template <> \
std::unordered_map< \
std::string, \
std::function<std::unique_ptr<FULL_TYPE_NAME>(__VA_ARGS__)>> & \
std::pair<std::recursive_mutex &, \
std::unordered_map<std::string, \
std::function<std::unique_ptr<FULL_TYPE_NAME>( \
__VA_ARGS__)>> &> \
cudaqx::extension_point<FULL_TYPE_NAME, __VA_ARGS__>::get_registry() { \
static std::unordered_map< \
std::string, \
std::function<std::unique_ptr<FULL_TYPE_NAME>(__VA_ARGS__)>> \
registry; \
return registry; \
static std::recursive_mutex *mutex = new std::recursive_mutex(); \
static std::unordered_map<std::string, \
std::function<std::unique_ptr<FULL_TYPE_NAME>( \
__VA_ARGS__)>> *registry = \
new std::unordered_map< \
std::string, \
std::function<std::unique_ptr<FULL_TYPE_NAME>(__VA_ARGS__)>>(); \
return {*mutex, *registry}; \
}

} // namespace cudaqx
9 changes: 6 additions & 3 deletions libs/core/include/cuda-qx/core/tensor_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ class tensor_impl
/// invalid
static std::unique_ptr<tensor_impl<Scalar>>
get(const std::string &name, const std::vector<std::size_t> &shape) {
auto &registry = BaseExtensionPoint::get_registry();
auto [mutex, registry] = BaseExtensionPoint::get_registry();
std::lock_guard<std::recursive_mutex> lock(mutex);
auto iter = registry.find(name);
if (iter == registry.end())
throw std::runtime_error("invalid tensor_impl requested: " + name);
Expand All @@ -58,7 +59,8 @@ class tensor_impl
/// invalid
static std::unique_ptr<tensor_impl<Scalar>>
get(const std::string &name, const std::vector<std::string> &data) {
auto &registry = BaseExtensionPoint::get_registry();
auto [mutex, registry] = BaseExtensionPoint::get_registry();
std::lock_guard<std::recursive_mutex> lock(mutex);
auto iter = registry.find(name);
if (iter == registry.end())
throw std::runtime_error("invalid tensor_impl requested: " + name);
Expand Down Expand Up @@ -92,7 +94,8 @@ class tensor_impl
static std::unique_ptr<tensor_impl<Scalar>>
get(const std::string &name, const scalar_type *data,
const std::vector<std::size_t> &shape) {
auto &registry = BaseExtensionPoint::get_registry();
auto [mutex, registry] = BaseExtensionPoint::get_registry();
std::lock_guard<std::recursive_mutex> lock(mutex);
auto iter = registry.find(name);
if (iter == registry.end())
throw std::runtime_error("invalid tensor_impl requested: " + name);
Expand Down
6 changes: 4 additions & 2 deletions libs/qec/lib/code.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ std::unique_ptr<code>
code::get(const std::string &name,
const std::vector<cudaq::spin_op_term> &_stabilizers,
const heterogeneous_map options) {
auto &registry = get_registry();
auto [mutex, registry] = get_registry();
std::lock_guard<std::recursive_mutex> lock(mutex);
auto iter = registry.find(name);
if (iter == registry.end())
throw std::runtime_error("invalid qec_code requested: " + name);
Expand All @@ -28,7 +29,8 @@ code::get(const std::string &name,

std::unique_ptr<code> code::get(const std::string &name,
const heterogeneous_map options) {
auto &registry = get_registry();
auto [mutex, registry] = get_registry();
std::lock_guard<std::recursive_mutex> lock(mutex);
auto iter = registry.find(name);
if (iter == registry.end())
throw std::runtime_error("invalid qec_code requested: " + name);
Expand Down
3 changes: 2 additions & 1 deletion libs/qec/lib/decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ decoder::decode_async(const std::vector<float_t> &syndrome) {
std::unique_ptr<decoder>
decoder::get(const std::string &name, const cudaqx::tensor<uint8_t> &H,
const cudaqx::heterogeneous_map &param_map) {
auto &registry = get_registry();
auto [mutex, registry] = get_registry();
std::lock_guard<std::recursive_mutex> lock(mutex);
auto iter = registry.find(name);
if (iter == registry.end())
throw std::runtime_error(
Expand Down
2 changes: 1 addition & 1 deletion libs/qec/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ add_dependencies(CUDAQXQECUnitTests test_decoders_yaml)
gtest_discover_tests(test_decoders_yaml)

add_executable(test_qec test_qec.cpp)
target_link_libraries(test_qec PRIVATE GTest::gtest_main cudaq-qec cudaq::cudaq)
target_link_libraries(test_qec PRIVATE GTest::gtest_main cudaq-qec cudaq::cudaq-stim-target)
add_dependencies(CUDAQXQECUnitTests test_qec)
gtest_discover_tests(test_qec)

Expand Down
3 changes: 2 additions & 1 deletion libs/solvers/include/cudaq/solvers/observe_gradient.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ class observe_gradient
static std::unique_ptr<observe_gradient>
get(const std::string &name, NonStdKernel &&kernel, const spin_op &op,
ArgTranslator &&translator) {
auto &registry = get_registry();
auto [mutex, registry] = get_registry();
std::lock_guard<std::recursive_mutex> lock(mutex);
auto iter = registry.find(name);
if (iter == registry.end())
throw std::runtime_error("Cannot find extension with name = " + name);
Expand Down