diff --git a/projects/rocprofiler-sdk/source/lib/common/container/CMakeLists.txt b/projects/rocprofiler-sdk/source/lib/common/container/CMakeLists.txt index f1ab95775fd..69f82447d9d 100644 --- a/projects/rocprofiler-sdk/source/lib/common/container/CMakeLists.txt +++ b/projects/rocprofiler-sdk/source/lib/common/container/CMakeLists.txt @@ -2,8 +2,16 @@ # add container sources and headers to common library target # set(containers_headers - ring_buffer.hpp c_array.hpp operators.hpp record_header_buffer.hpp ring_buffer.hpp - small_vector.hpp stable_vector.hpp static_vector.hpp) + ring_buffer.hpp + c_array.hpp + operators.hpp + pool.hpp + pool_object.hpp + record_header_buffer.hpp + ring_buffer.hpp + small_vector.hpp + stable_vector.hpp + static_vector.hpp) set(containers_sources ring_buffer.cpp record_header_buffer.cpp ring_buffer.cpp small_vector.cpp) diff --git a/projects/rocprofiler-sdk/source/lib/common/container/pool.hpp b/projects/rocprofiler-sdk/source/lib/common/container/pool.hpp new file mode 100644 index 00000000000..7bf410d41af --- /dev/null +++ b/projects/rocprofiler-sdk/source/lib/common/container/pool.hpp @@ -0,0 +1,204 @@ +// MIT License +// +// Copyright (c) 2026 Advanced Micro Devices, Inc. All Rights Reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#pragma once + +#include "lib/common/container/pool_object.hpp" +#include "lib/common/container/stable_vector.hpp" +#include "lib/common/defines.hpp" +#include "lib/common/demangle.hpp" +#include "lib/common/logging.hpp" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace rocprofiler +{ +namespace common +{ +namespace container +{ +template +struct pool +{ + using size_type = size_t; + + // template + // explicit pool(Args&&... args) + // : m_pool{std::forward(args)...} + // {} + + template + explicit pool(std::piecewise_construct_t, size_type count, FuncT&& ctor, Args&&... args) + : m_count{count} + { + m_function = [this, + _ctor = std::forward(ctor), + _args_tuple = std::make_tuple(std::forward(args)...)]() { + for(size_type i = 0; i < m_count; ++i) + { + auto idx = m_pool.size(); + m_pool.emplace_back(idx, false, this); + std::apply( + [&](auto&&... unpacked_args) { + _ctor(m_pool[idx].get(), + std::forward(unpacked_args)...); + }, + _args_tuple); + m_available.push(idx); + } + }; + + m_function(); + } + + pool() = default; + ~pool() = default; + pool(const pool&) = delete; + pool(pool&&) noexcept = default; + pool& operator=(const pool&) = delete; + pool& operator=(pool&&) noexcept = default; + + // get an object from the pool. if all objects are in use, a new one will be created and added + // to the pool + pool_object& acquire(); + void release(size_type idx); + + template + pool_object& acquire(FuncT&& ctor, Args&&... args); + + void report_reuse() + { + ROCP_WARNING << fmt::format("Pool of type {}: Total pool size: {}. Reused objects: {}. " + "Released objects: {}. New batches: {}.", + cxx_demangle(typeid(Tp).name()), + m_pool.size(), + m_reused.load(), + m_released.load(), + m_new_batch.load()); + } + +private: + size_type m_count = 256; + std::function m_function = nullptr; + mutable std::shared_mutex m_pool_mtx = {}; + stable_vector, 32> m_pool = {}; + mutable std::shared_mutex m_available_mtx = {}; + std::queue m_available = {}; + std::atomic m_released = 0; + std::atomic m_reused = 0; + std::atomic m_new_batch = 0; +}; + +template +pool_object& +pool::acquire() +{ + auto _idx = std::optional{}; + { + auto _read_lk = std::shared_lock{m_available_mtx}; + if(!m_available.empty()) + { + _read_lk.unlock(); + auto _write_lk = std::unique_lock{m_available_mtx}; + _idx = m_available.front(); + m_available.pop(); + if(m_released > 0) + { + m_reused++; + m_released--; + } + } + } + + if(_idx.has_value()) + { + auto _read_lk = std::shared_lock{m_available_mtx}; + auto& _obj = m_pool.at(_idx.value()); + ROCP_FATAL_IF(!_obj.acquire()) << fmt::format( + "Pool object at index {} was expected to be available but was not", _idx.value()); + return _obj; + } + + // add a new batch + { + auto _write_pool_lk = std::unique_lock{m_pool_mtx}; + auto _write_avail_lk = std::unique_lock{m_available_mtx}; + ROCP_WARNING << fmt::format( + "Pool of type {} exhausted. Creating new batch of {} objects. New pool size: {}", + cxx_demangle(typeid(Tp).name()), + m_count, + m_pool.size() + m_count); + m_new_batch++; + m_function(); + } + + return acquire(); + // auto _idx_v = m_pool.size(); + // auto& _ref = m_pool.emplace_back(_idx_v, true, this); + // ROCP_INFO << fmt::format("Pool of type {} exhausted. Creating new object. New pool size: {}", + // typeid(Tp).name(), + // m_pool.size()); + // return _ref; +} + +template +void +pool::release(size_type idx) +{ + if(idx < m_pool.size()) + { + auto _write_lk = std::unique_lock{m_available_mtx}; + ROCP_FATAL_IF(m_pool.at(idx).in_use()) + << fmt::format("Pool object at index {} was expected to be not in use", idx); + // ROCP_WARNING << fmt::format( + // "Releasing object at index {} back to pool of type {}", idx, typeid(Tp).name()); + m_available.push(idx); + m_released++; + } +} + +// get an object from the pool. if all objects are in use, a new one will be created and added to +// the pool +template +template +pool_object& +pool::acquire(FuncT&& ctor, Args&&... args) +{ + auto& _ref = acquire(); + ctor(_ref.get(), std::forward(args)...); + return _ref; +} +} // namespace container +} // namespace common +} // namespace rocprofiler diff --git a/projects/rocprofiler-sdk/source/lib/common/container/pool_object.hpp b/projects/rocprofiler-sdk/source/lib/common/container/pool_object.hpp new file mode 100644 index 00000000000..e475c7d5957 --- /dev/null +++ b/projects/rocprofiler-sdk/source/lib/common/container/pool_object.hpp @@ -0,0 +1,121 @@ +// MIT License +// +// Copyright (c) 2026 Advanced Micro Devices, Inc. All Rights Reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#pragma once + +#include "lib/common/defines.hpp" +#include "lib/common/logging.hpp" +#include "rocprofiler-sdk/cxx/utility.hpp" + +#include +#include + +#include +#include +#include + +namespace rocprofiler +{ +namespace common +{ +namespace container +{ +template +struct pool; + +template +struct pool_object +{ + using pool_type = pool; + + pool_object(size_t idx, bool in_use, pool_type* pool) + : m_in_use{in_use} + , m_index{idx} + , m_pool{pool} + {} + + pool_object() = default; + ~pool_object() = default; + pool_object(pool_object&&) noexcept = default; + pool_object& operator=(pool_object&&) noexcept = default; + + // pool_object(const pool_object& rhs) = delete; + // pool_object& operator=(const pool_object& rhs) = delete; + + pool_object(const pool_object& rhs) + : m_object{rhs.m_object} + , m_in_use{rhs.m_in_use.load(std::memory_order_relaxed)} + , m_index{rhs.m_index} + , m_pool{rhs.m_pool} + {} + + pool_object& operator=(const pool_object& rhs) + { + if(this != &rhs) + { + m_object = rhs.m_object; + m_in_use.store(rhs.m_in_use.load(std::memory_order_relaxed), std::memory_order_relaxed); + m_index = rhs.m_index; + m_pool = rhs.m_pool; + } + return *this; + } + + bool acquire(); + bool release(); + bool in_use() const { return m_in_use.load(std::memory_order_relaxed); } + + Tp& get() { return m_object; } + const Tp& get() const { return m_object; } + + auto index() const { return m_index; } + auto index(size_t index) { m_index = index; } + +private: + Tp m_object = {}; + std::atomic m_in_use = false; + size_t m_index = 0; + pool_type* m_pool = nullptr; +}; + +template +bool +pool_object::acquire() +{ + bool expected = false; + return m_in_use.compare_exchange_strong(expected, true); +} + +template +bool +pool_object::release() +{ + bool expected = true; + auto val = m_in_use.compare_exchange_strong(expected, false); + + if(m_pool) m_pool->release(m_index); + + return val; +} +} // namespace container +} // namespace common +} // namespace rocprofiler diff --git a/projects/rocprofiler-sdk/source/lib/common/container/record_header_buffer.cpp b/projects/rocprofiler-sdk/source/lib/common/container/record_header_buffer.cpp index 0e2cd2b8832..bc2eef0a7f2 100644 --- a/projects/rocprofiler-sdk/source/lib/common/container/record_header_buffer.cpp +++ b/projects/rocprofiler-sdk/source/lib/common/container/record_header_buffer.cpp @@ -26,6 +26,7 @@ #include #include +#include #include namespace rocprofiler::common::container @@ -106,12 +107,13 @@ record_header_buffer::clear() { auto _sz = m_buffer.capacity(); if(!m_buffer.clear(std::nothrow_t{})) return 0; - std::for_each(m_headers.begin(), m_headers.end(), [](auto& itr) { - rocprofiler_record_header_t record = {}; - record.hash = 0; - record.payload = nullptr; - itr = record; - }); + // Only clear the used portion of m_headers (first _n elements) + // m_index is atomically incremented during every emplace, so it should + // indicate the number of used elements. + if(_n > 0) + { + std::memset(m_headers.data(), 0, _n * sizeof(rocprofiler_record_header_t)); + } rocprofiler_record_header_t record = {}; record.hash = 0; record.payload = nullptr; diff --git a/projects/rocprofiler-sdk/source/lib/common/container/small_vector.hpp b/projects/rocprofiler-sdk/source/lib/common/container/small_vector.hpp index c16938cfc6b..d0f12d8e97a 100644 --- a/projects/rocprofiler-sdk/source/lib/common/container/small_vector.hpp +++ b/projects/rocprofiler-sdk/source/lib/common/container/small_vector.hpp @@ -38,6 +38,7 @@ */ #include "lib/common/defines.hpp" +#include "lib/common/mpl.hpp" #include #include @@ -297,6 +298,8 @@ class small_vector_template_common : public small_vector_base::first_type; // will be void if not pair + using mapped_type = typename mpl::is_pair::second_type; // will be void if not pair using const_reverse_iterator = std::reverse_iterator; using reverse_iterator = std::reverse_iterator; @@ -378,6 +381,49 @@ class small_vector_template_common : public small_vector_base + typename std::enable_if_t::value && std::is_convertible_v, + pointer> + find(KeyT&& key) + { + return std::find_if(begin(), end(), [&key](const auto& itr) { + return itr.first == std::forward(key); + }); + } + template + typename std::enable_if_t::value && std::is_convertible_v, + const_pointer> + find(KeyT&& key) const + { + return std::find_if(begin(), end(), [&key](const auto& itr) { + return itr.first == std::forward(key); + }); + } + + template ::value && std::is_convertible_v && + !std::is_integral_v, + int> = 0> + auto& at(KeyT&& key) + { + auto* val = find(std::forward(key)); + if(val == end()) throw std::out_of_range{"small_vector::at(key_type)"}; + return val->second; + } + template ::value && std::is_convertible_v && + !std::is_integral_v, + int> = 0> + const auto& at(KeyT&& key) const + { + const auto* val = find(std::forward(key)); + if(val == end()) throw std::out_of_range{"small_vector::at(key_type)"}; + return val->second; + } }; /// small_vector_template_base - this is where we put @@ -1063,6 +1109,27 @@ class small_vector_impl : public small_vector_template_base return this->back(); } + // Specializations for small_vector of pairs, allows emplacement of pair + template + typename std::enable_if_t::value, std::pair> emplace( + Args&&... args) + { + auto key_value_pair = T{std::forward(args)...}; + + // Search for existing element by key + iterator itr = + std::find_if(this->begin(), this->end(), [&key_value_pair](const auto& existing) { + return existing.first == key_value_pair.first; + }); + + // If key already exists, return iterator to existing and false + if(itr != this->end()) return std::make_pair(itr, false); + + // Key not found, insert it and return iterator to new element and true + auto& ref = emplace_back(std::move(key_value_pair)); + return std::make_pair(&ref, true); + } + small_vector_impl& operator=(const small_vector_impl& RHS); small_vector_impl& operator=(small_vector_impl&& RHS) noexcept; diff --git a/projects/rocprofiler-sdk/source/lib/common/container/stable_vector.hpp b/projects/rocprofiler-sdk/source/lib/common/container/stable_vector.hpp index 78e538a2877..cd4a11ed509 100644 --- a/projects/rocprofiler-sdk/source/lib/common/container/stable_vector.hpp +++ b/projects/rocprofiler-sdk/source/lib/common/container/stable_vector.hpp @@ -47,7 +47,7 @@ struct reserve_size : value{_v} {} - size_t value; + size_t value = 0; }; template @@ -237,7 +237,7 @@ class stable_vector void add_chunk(); chunk_type& last_chunk(); - storage_type m_chunks; + storage_type m_chunks = {}; }; template diff --git a/projects/rocprofiler-sdk/source/lib/common/container/static_vector.hpp b/projects/rocprofiler-sdk/source/lib/common/container/static_vector.hpp index 075ced5d4f4..c2db32692b1 100644 --- a/projects/rocprofiler-sdk/source/lib/common/container/static_vector.hpp +++ b/projects/rocprofiler-sdk/source/lib/common/container/static_vector.hpp @@ -204,10 +204,7 @@ static_vector::emplace_back(Args&&... _v) if constexpr(sizeof...(Args) > 0) { - if constexpr(std::is_assignable(_v))...>::value) - m_data[_idx] = {std::forward(_v)...}; - else - m_data[_idx] = Tp{std::forward(_v)...}; + m_data[_idx] = Tp{std::forward(_v)...}; } else if constexpr(std::is_move_assignable::value || std::is_copy_assignable::value) { diff --git a/projects/rocprofiler-sdk/source/lib/common/mpl.hpp b/projects/rocprofiler-sdk/source/lib/common/mpl.hpp index 97b4d9e61b3..bed474b2d05 100644 --- a/projects/rocprofiler-sdk/source/lib/common/mpl.hpp +++ b/projects/rocprofiler-sdk/source/lib/common/mpl.hpp @@ -96,12 +96,16 @@ template struct is_pair_impl { static constexpr auto value = false; + using first_type = void; + using second_type = void; }; template struct is_pair_impl> { static constexpr auto value = true; + using first_type = LhsT; + using second_type = RhsT; }; template diff --git a/projects/rocprofiler-sdk/source/lib/common/utility.hpp b/projects/rocprofiler-sdk/source/lib/common/utility.hpp index 4c3591187a6..dbd4a8bf579 100644 --- a/projects/rocprofiler-sdk/source/lib/common/utility.hpp +++ b/projects/rocprofiler-sdk/source/lib/common/utility.hpp @@ -102,20 +102,22 @@ get_process_start_time_ns(pid_t _pid); std::vector read_command_line(pid_t _pid); -template +// supports all STL containers with find method and small_vector +template const auto* -get_val(const Container& map, const Key& key) +get_val(const ContainerT& data, const KeyT& key) { - auto pos = map.find(key); - return (pos != map.end() ? &pos->second : nullptr); + auto pos = data.find(key); + return (pos != data.end() ? &pos->second : nullptr); } -template +// supports all STL containers with find method and small_vector +template auto* -get_val(Container& map, const Key& key) +get_val(ContainerT& data, const KeyT& key) { - auto pos = map.find(key); - return (pos != map.end() ? &pos->second : nullptr); + auto pos = data.find(key); + return (pos != data.end() ? &pos->second : nullptr); } template diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/buffer.cpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/buffer.cpp index 31575c87a93..19bb8b843a1 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/buffer.cpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/buffer.cpp @@ -90,13 +90,11 @@ get_buffer(rocprofiler_buffer_id_t buffer_id) { if(is_valid_buffer_id(buffer_id) && get_buffers()) { - for(auto& itr : *get_buffers()) - { - if(itr && itr->buffer_id == buffer_id.handle) - { - return itr.get(); - } - } + // Use direct indexing instead of linear search (same pattern as destroy_buffer) + // See allocate_buffer below that the idx is assigned based on the size + address + auto idx = buffer_id.handle - get_buffer_offset(); + auto& buf = get_buffers()->at(idx); + return buf ? buf.get() : nullptr; } return nullptr; } diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/code_object/code_object.cpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/code_object/code_object.cpp index 475508379f8..6b4f3631a1a 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/code_object/code_object.cpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/code_object/code_object.cpp @@ -62,7 +62,7 @@ namespace { using context_t = context::context; using context_array_t = common::container::small_vector; -using external_corr_id_map_t = std::unordered_map; +using external_corr_id_map_t = tracing::external_correlation_id_map_t; template struct code_object_info; diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/core.cpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/core.cpp index b1a0e6d70bf..74dbfc46d53 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/core.cpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/core.cpp @@ -168,13 +168,13 @@ start_context(const context::context* ctx) if(cb->queue_id != rocprofiler::hsa::ClientID{-1}) continue; cb->queue_id = controller->add_callback( std::nullopt, - [=](const hsa::Queue& q, - const hsa::rocprofiler_packet& kern_pkt, - rocprofiler_kernel_id_t kernel_id, - rocprofiler_dispatch_id_t dispatch_id, - rocprofiler_user_data_t* user_data, - const hsa::Queue::queue_info_session_t::external_corr_id_map_t& extern_corr_ids, - const context::correlation_id* correlation_id) { + [=](const hsa::Queue& q, + const hsa::rocprofiler_packet& kern_pkt, + rocprofiler_kernel_id_t kernel_id, + rocprofiler_dispatch_id_t dispatch_id, + rocprofiler_user_data_t* user_data, + const hsa::queue_info_session_t::external_corr_id_map_t& extern_corr_ids, + const context::correlation_id* correlation_id) { return queue_cb(ctx, cb, q, @@ -188,10 +188,11 @@ start_context(const context::context* ctx) // Completion CB [=](const hsa::Queue& /* q */, hsa::rocprofiler_packet /* kern_pkt */, - std::shared_ptr& session, - inst_pkt_t& aql, - kernel_dispatch::profiling_time dispatch_time) { - completed_cb(ctx, cb, session, aql, dispatch_time); + std::shared_ptr& session, + hsa::packet_data_t& packet, + inst_pkt_t& aql, + kernel_dispatch::profiling_time dispatch_time) { + completed_cb(ctx, cb, session, packet, aql, dispatch_time); }); } } diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/dispatch_handlers.cpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/dispatch_handlers.cpp index 1b7cb3301c7..5fd6b8db18b 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/dispatch_handlers.cpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/dispatch_handlers.cpp @@ -47,15 +47,15 @@ namespace counters * We return an AQLPacket containing the start/stop/read packets for injection. */ hsa::Queue::pkt_and_serialize_t -queue_cb(const context::context* ctx, - const std::shared_ptr& info, - const hsa::Queue& queue, - const hsa::rocprofiler_packet& pkt, - rocprofiler_kernel_id_t kernel_id, - rocprofiler_dispatch_id_t dispatch_id, - rocprofiler_user_data_t* user_data, - const hsa::Queue::queue_info_session_t::external_corr_id_map_t& extern_corr_ids, - const context::correlation_id* correlation_id) +queue_cb(const context::context* ctx, + const std::shared_ptr& info, + const hsa::Queue& queue, + const hsa::rocprofiler_packet& pkt, + rocprofiler_kernel_id_t kernel_id, + rocprofiler_dispatch_id_t dispatch_id, + rocprofiler_user_data_t* user_data, + const hsa::queue_info_session_t::external_corr_id_map_t& extern_corr_ids, + const context::correlation_id* correlation_id) { CHECK(info && ctx); @@ -139,11 +139,12 @@ queue_cb(const context::context* ctx, * Callback called by HSA interceptor when the kernel has completed processing. */ void -completed_cb(const context::context* ctx, - const std::shared_ptr& info, - std::shared_ptr& ptr_session, - inst_pkt_t& pkts, - kernel_dispatch::profiling_time dispatch_time) +completed_cb(const context::context* ctx, + const std::shared_ptr& info, + std::shared_ptr& ptr_session, + hsa::packet_data_t& packet, + inst_pkt_t& pkts, + kernel_dispatch::profiling_time dispatch_time) { CHECK(info && ctx); @@ -167,8 +168,8 @@ completed_cb(const context::context* ctx, // We have no profile config, nothing to output. if(!pkt || !prof_config) return; - completed_cb_params_t params{info, ptr_session, dispatch_time, prof_config, std::move(pkt)}; - process_callback_data(std::move(params)); + process_callback_data(completed_cb_params_t{ + info, ptr_session, &packet, dispatch_time, prof_config, std::move(pkt)}); } } // namespace counters diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/dispatch_handlers.hpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/dispatch_handlers.hpp index f03a424135c..f3bb841973c 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/dispatch_handlers.hpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/dispatch_handlers.hpp @@ -35,22 +35,23 @@ using inst_pkt_t = common::container:: small_vector, ClientID>, 4>; hsa::Queue::pkt_and_serialize_t -queue_cb(const context::context* ctx, - const std::shared_ptr& info, - const hsa::Queue& queue, - const hsa::rocprofiler_packet& pkt, - rocprofiler_kernel_id_t kernel_id, - rocprofiler_dispatch_id_t dispatch_id, - rocprofiler_user_data_t* user_data, - const hsa::Queue::queue_info_session_t::external_corr_id_map_t& extern_corr_ids, - const context::correlation_id* correlation_id); +queue_cb(const context::context* ctx, + const std::shared_ptr& info, + const hsa::Queue& queue, + const hsa::rocprofiler_packet& pkt, + rocprofiler_kernel_id_t kernel_id, + rocprofiler_dispatch_id_t dispatch_id, + rocprofiler_user_data_t* user_data, + const hsa::queue_info_session_t::external_corr_id_map_t& extern_corr_ids, + const context::correlation_id* correlation_id); void -completed_cb(const context::context* ctx, - const std::shared_ptr& info, - std::shared_ptr& session, - inst_pkt_t& pkts, - kernel_dispatch::profiling_time dispatch_time); +completed_cb(const context::context* ctx, + const std::shared_ptr& info, + std::shared_ptr& session, + hsa::packet_data_t& packet, + inst_pkt_t& pkts, + kernel_dispatch::profiling_time dispatch_time); } // namespace counters } // namespace rocprofiler diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/sample_processing.cpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/sample_processing.cpp index 29883ae0af9..cdd07cf1c4e 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/sample_processing.cpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/sample_processing.cpp @@ -53,11 +53,12 @@ get_buffer_mut() void proccess_completed_cb(completed_cb_params_t&& params) { - auto& info = params.info; - auto& session = *params.session; - auto& dispatch_time = params.dispatch_time; - auto& prof_config = params.prof_config; - auto& pkt = params.pkt; + auto& info = params.info; + auto& session = *params.session; + const auto& packet = *params.packet_data; + auto& dispatch_time = params.dispatch_time; + auto& prof_config = params.prof_config; + auto& pkt = params.pkt; ROCP_FATAL_IF(pkt == nullptr) << "AQL packet is a nullptr!"; @@ -82,13 +83,13 @@ proccess_completed_cb(completed_cb_params_t&& params) { _corr_id_v.internal = _corr_id->internal; if(const auto* external = rocprofiler::common::get_val( - session.tracing_data.external_correlation_ids, info->internal_context)) + packet.tracing_data.external_correlation_ids, info->internal_context)) { _corr_id_v.external = *external; } } - auto _dispatch_id = session.callback_record.dispatch_info.dispatch_id; + auto _dispatch_id = packet.callback_record.dispatch_info.dispatch_id; for(auto& ast : prof_config->asts) { std::vector>> cache; @@ -118,7 +119,7 @@ proccess_completed_cb(completed_cb_params_t&& params) _header.start_timestamp = dispatch_time.start; _header.end_timestamp = dispatch_time.end; } - _header.dispatch_info = session.callback_record.dispatch_info; + _header.dispatch_info = packet.callback_record.dispatch_info; auto _lk = std::unique_lock{get_buffer_mut()}; // Buffer records need to be in order @@ -137,7 +138,7 @@ proccess_completed_cb(completed_cb_params_t&& params) auto dispatch_data = common::init_public_api_struct(rocprofiler_dispatch_counting_service_data_t{}); - dispatch_data.dispatch_info = session.callback_record.dispatch_info; + dispatch_data.dispatch_info = packet.callback_record.dispatch_info; dispatch_data.correlation_id = _corr_id_v; if(dispatch_time.status == HSA_STATUS_SUCCESS) { @@ -148,7 +149,7 @@ proccess_completed_cb(completed_cb_params_t&& params) info->record_callback(dispatch_data, out.data(), out.size(), - session.user_data, + packet.user_data, info->record_callback_args); } } diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/sample_processing.hpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/sample_processing.hpp index 4aff58930dc..01c68ddd0d8 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/sample_processing.hpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/sample_processing.hpp @@ -24,6 +24,7 @@ #include "lib/rocprofiler-sdk/context/context.hpp" #include "lib/rocprofiler-sdk/hsa/aql_packet.hpp" +#include "lib/rocprofiler-sdk/hsa/queue_info_session.hpp" #include "lib/rocprofiler-sdk/kernel_dispatch/profiling_time.hpp" namespace rocprofiler @@ -32,11 +33,12 @@ namespace counters { struct completed_cb_params_t { - std::shared_ptr info; - std::shared_ptr session; - kernel_dispatch::profiling_time dispatch_time; - std::shared_ptr prof_config; - std::unique_ptr pkt; + std::shared_ptr info; + std::shared_ptr session; + const hsa::packet_data_t* packet_data = nullptr; // owned by session + kernel_dispatch::profiling_time dispatch_time; + std::shared_ptr prof_config; + std::unique_ptr pkt; }; void diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/tests/core.cpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/tests/core.cpp index e10daf63d72..342935446f1 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/tests/core.cpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/counters/tests/core.cpp @@ -432,7 +432,7 @@ TEST(core, check_callbacks) expected.queue_id = qid; expected.agent_id = fq.get_agent().get_rocp_agent()->id; - hsa::Queue::queue_info_session_t::external_corr_id_map_t extern_ids = {}; + hsa::queue_info_session_t::external_corr_id_map_t extern_ids = {}; auto user_data = rocprofiler_user_data_t{.value = corr_id.internal}; auto ret_pkt = counters::queue_cb(&ctx, @@ -463,15 +463,16 @@ TEST(core, check_callbacks) "Could not create buffer"); cb_info->buffer = opt_buff_id; - auto _sess = hsa::Queue::queue_info_session_t{.queue = fq}; + auto _sess = hsa::queue_info_session_t{.queue = fq}; _sess.correlation_id = &corr_id; - auto sess = std::make_shared(std::move(_sess)); + auto sess = std::make_shared(std::move(_sess)); + auto& packet_data = sess->packet_data.emplace_back(); counters::inst_pkt_t pkts; pkts.emplace_back( std::make_pair(std::move(ret_pkt.pkt), static_cast(0))); - completed_cb(&ctx, cb_info, sess, pkts, kernel_dispatch::profiling_time{}); + completed_cb(&ctx, cb_info, sess, packet_data, pkts, kernel_dispatch::profiling_time{}); rocprofiler_flush_buffer(opt_buff_id); rocprofiler_destroy_buffer(opt_buff_id); } @@ -774,7 +775,7 @@ rocprofiler-sdk: description: cycles properties: [] definitions: - - architectures: + - architectures: - gfx950 - gfx942 - gfx10 @@ -822,12 +823,12 @@ TEST(core, check_load_counter_def) const std::string test_yaml = R"( rocprofiler-sdk: counters-schema-version: 1 - counters: + counters: - name: GRBM_GUI_ACTIVE description: The GUI is Active properties: [] definitions: - - architectures: + - architectures: - gfx950 - gfx942 - gfx941 @@ -859,7 +860,7 @@ rocprofiler-sdk: description: cycles properties: [] definitions: - - architectures: + - architectures: - gfx950 - gfx942 - gfx10 diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/CMakeLists.txt b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/CMakeLists.txt index 13f21429dd3..3e24c5e6579 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/CMakeLists.txt +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/CMakeLists.txt @@ -28,6 +28,7 @@ set(ROCPROFILER_LIB_HSA_HEADERS queue_info_session.hpp rocprofiler_packet.hpp scratch_memory.hpp + signal.hpp utils.hpp) target_sources(rocprofiler-sdk-object-library PRIVATE ${ROCPROFILER_LIB_HSA_SOURCES} diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/async_copy.cpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/async_copy.cpp index cfc3fc93580..812ba3019f2 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/async_copy.cpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/async_copy.cpp @@ -65,7 +65,7 @@ namespace { using context_t = context::context; using context_array_t = common::container::small_vector; -using external_corr_id_map_t = std::unordered_map; +using external_corr_id_map_t = tracing::external_correlation_id_map_t; template struct async_copy_info; diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/memory_allocation.cpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/memory_allocation.cpp index 59a43b611ca..a1058a43c94 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/memory_allocation.cpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/memory_allocation.cpp @@ -63,7 +63,7 @@ namespace memory_allocation namespace { using context_t = context::context; -using external_corr_id_map_t = std::unordered_map; +using external_corr_id_map_t = tracing::external_correlation_id_map_t; using region_to_agent_map = std::unordered_map; using memory_pool_to_agent_map = std::unordered_map; using region_to_agent_pair = std::pair; diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/queue.cpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/queue.cpp index 0536b481404..10643ebaebd 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/queue.cpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/queue.cpp @@ -21,13 +21,17 @@ THE SOFTWARE. */ #include "lib/rocprofiler-sdk/hsa/queue.hpp" +#include "lib/common/container/pool.hpp" +#include "lib/common/logging.hpp" #include "lib/common/scope_destructor.hpp" +#include "lib/common/static_object.hpp" #include "lib/common/utility.hpp" #include "lib/rocprofiler-sdk/code_object/code_object.hpp" #include "lib/rocprofiler-sdk/context/context.hpp" #include "lib/rocprofiler-sdk/hsa/details/fmt.hpp" #include "lib/rocprofiler-sdk/hsa/hsa.hpp" #include "lib/rocprofiler-sdk/hsa/queue_controller.hpp" +#include "lib/rocprofiler-sdk/hsa/queue_info_session.hpp" #include "lib/rocprofiler-sdk/kernel_dispatch/profiling_time.hpp" #include "lib/rocprofiler-sdk/kernel_dispatch/tracing.hpp" #include "lib/rocprofiler-sdk/pc_sampling/hsa_adapter.hpp" @@ -46,6 +50,7 @@ #include #include +#include // static assert for rocprofiler_packet ABI compatibility static_assert(sizeof(hsa_ext_amd_aql_pm4_packet_t) == sizeof(hsa_kernel_dispatch_packet_t), @@ -70,6 +75,8 @@ namespace hsa { namespace { +constexpr auto null_hsa_signal = hsa_signal_t{.handle = 0}; + template inline bool context_filter(const context::context* ctx, DomainT domain, Args... args) @@ -96,101 +103,183 @@ context_filter(const context::context* ctx) context_filter(ctx, ROCPROFILER_CALLBACK_TRACING_KERNEL_DISPATCH)); } +signal_t& +construct_hsa_signal(signal_t& signal, + hsa_signal_value_t initial_value = 0, + uint32_t num_consumers = 0, + const hsa_agent_t* consumers = nullptr, + uint64_t attributes = 0) +{ + auto status = HSA_STATUS_SUCCESS; + if(!get_amd_ext_table() || !get_amd_ext_table()->hsa_amd_signal_create_fn) + status = HSA_STATUS_ERROR; + else + status = get_amd_ext_table()->hsa_amd_signal_create_fn( + initial_value, num_consumers, consumers, attributes, &signal.value); + + ROCP_FATAL_IF(status != HSA_STATUS_SUCCESS) + << fmt::format("Error: hsa_amd_signal_create failed with error code {} :: {}", + static_cast(status), + hsa::get_hsa_status_string(status)); + + return signal; +} + +auto* +get_signal_pool() +{ + constexpr size_t default_signal_pool_size = (1 << 12); // 4096 signals per pool batch + + // static auto pool = common::container::pool{ + // std::piecewise_construct, default_signal_pool_size, [](Queue::signal_impl& signal) { + // construct_hsa_signal(signal, 0, 0, nullptr, 0); + // }}; + // return &pool; + + static auto*& pool = common::static_object>::construct( + std::piecewise_construct, default_signal_pool_size, [](signal_t& signal) { + construct_hsa_signal(signal, 0, 0, nullptr, 0); + }); + + return pool; +} + bool AsyncSignalHandler(hsa_signal_value_t /*signal_v*/, void* data) { - if(!data) return true; + using session_info_t = std::shared_ptr; + + if(!data) + { + ROCP_FATAL << "AsyncSignalHandler called with null data pointer"; + return true; + } + + auto* _session_ptr = static_cast(data); // if we have fully finalized, delete the data and return if(registration::get_fini_status() > 0) { - auto* _session = static_cast(data); - delete _session; + _session_ptr->reset(); + delete _session_ptr; + return true; + } + + // ROCP_WARNING << fmt::format("Pooled signal from pool: hsa_signal_t{{.handle={}}}", + // _pooled_signal->get().value.handle); + + // cleanup the pooled signal data and release the signal back to the pool for reuse + auto _cleanup = common::scope_destructor{[&_session_ptr]() { + _session_ptr->reset(); + delete _session_ptr; + _session_ptr = nullptr; + }}; + + auto _session = *_session_ptr; // make a copy of the shared pointer to extend lifetime for the + // duration of this function + if(!_session.get()) + { + ROCP_FATAL << fmt::format("nullptr to session information"); return false; } - auto& shared_ptr_info = *static_cast*>(data); - auto& queue_info_session = *shared_ptr_info; + auto& queue_info_session = *_session; + + // ROCP_WARNING << fmt::format("AsyncSignalHandler called for {} packets", + // queue_info_session.packet_data.size()); - auto dispatch_time = kernel_dispatch::get_dispatch_time(queue_info_session); + for(auto& packet : queue_info_session.packet_data) + { + auto dispatch_time = kernel_dispatch::get_dispatch_time(queue_info_session, packet); + kernel_dispatch::dispatch_complete(queue_info_session, packet, dispatch_time); - kernel_dispatch::dispatch_complete(queue_info_session, dispatch_time); + // Calls our internal callbacks to callers who need to be notified post + // kernel execution. + queue_info_session.queue.signal_callback([&](const auto& map) { + for(const auto& [client_id, cb_pair] : map) + { + cb_pair.second(queue_info_session.queue, + packet.kernel_packet, + _session, + packet, + packet.instrumentation_packets, + dispatch_time); + } + }); - // Calls our internal callbacks to callers who need to be notified post - // kernel execution. - queue_info_session.queue.signal_callback([&](const auto& map) { - for(const auto& [client_id, cb_pair] : map) + if(packet.is_serialized) { - cb_pair.second(queue_info_session.queue, - queue_info_session.kernel_pkt, - shared_ptr_info, - queue_info_session.inst_pkt, - dispatch_time); + CHECK_NOTNULL(hsa::get_queue_controller()) + ->serializer(&queue_info_session.queue) + .wlock([&](auto& serializer) { + serializer.kernel_completion_signal(queue_info_session.queue); + }); } - }); - if(queue_info_session.is_serialized) - { - CHECK_NOTNULL(hsa::get_queue_controller()) - ->serializer(&queue_info_session.queue) - .wlock([&](auto& serializer) { - serializer.kernel_completion_signal(queue_info_session.queue); - }); - } - - // Delete signals and packets, signal we have completed. - if(queue_info_session.interrupt_signal.handle != 0u) - { + // Delete signals and packets, signal we have completed. + if(packet.interrupt_signal.handle != 0u) + { #if !defined(NDEBUG) - CHECK_NOTNULL(hsa::get_queue_controller())->_debug_signals.wlock([&](auto& signals) { - signals.erase(queue_info_session.interrupt_signal.handle); - }); + CHECK_NOTNULL(hsa::get_queue_controller())->_debug_signals.wlock([&](auto& signals) { + signals.erase(packet.interrupt_signal.handle); + }); #endif - hsa::get_core_table()->hsa_signal_store_screlease_fn(queue_info_session.interrupt_signal, - -1); - hsa::get_core_table()->hsa_signal_destroy_fn(queue_info_session.interrupt_signal); - } - if(queue_info_session.kernel_pkt.ext_amd_aql_pm4.completion_signal.handle != 0u) - { - hsa::get_core_table()->hsa_signal_destroy_fn( - queue_info_session.kernel_pkt.ext_amd_aql_pm4.completion_signal); - } + hsa::get_core_table()->hsa_signal_store_screlease_fn(packet.interrupt_signal, -1); + ROCP_FATAL << "Destroying interrupt signal"; + hsa::get_core_table()->hsa_signal_destroy_fn(packet.interrupt_signal); + } - // we need to decrement this reference count at the end of the functions - auto* _corr_id = queue_info_session.correlation_id; - if(_corr_id) - { - ROCP_FATAL_IF(_corr_id->get_ref_count() == 0) - << "reference counter for correlation id " << _corr_id->internal << " from thread " - << _corr_id->thread_idx << " has no reference count"; - _corr_id->sub_kern_count(); - _corr_id->sub_ref_count(); + // if(packet.kernel_packet.ext_amd_aql_pm4.completion_signal.handle != 0u && + // packet.kernel_packet.ext_amd_aql_pm4.completion_signal.handle != + // _pooled_signal->get().value.handle) + // { + // hsa::get_core_table()->hsa_signal_destroy_fn( + // packet.kernel_packet.ext_amd_aql_pm4.completion_signal); + // } + + // we need to decrement this reference count at the end of the functions + auto* _corr_id = queue_info_session.correlation_id; + if(_corr_id) + { + ROCP_FATAL_IF(_corr_id->get_ref_count() == 0) + << "reference counter for correlation id " << _corr_id->internal << " from thread " + << _corr_id->thread_idx << " has no reference count"; + _corr_id->sub_kern_count(); + _corr_id->sub_ref_count(); + } + + Queue::release_signal(packet.pooled_signal); } + // for(auto& packet : queue_info_session.packet_data) + // Queue::release_signal(packet.pooled_signal); + queue_info_session.queue.async_complete(); - delete &shared_ptr_info; return false; } -template -constexpr Integral -bit_mask(int first, int last) -{ - assert(last >= first && "Error: hsa_support::bit_mask -> invalid argument"); - size_t num_bits = last - first + 1; - return ((num_bits >= sizeof(Integral) * 8) ? ~Integral{0} - /* num_bits exceed the size of Integral */ - : ((Integral{1} << num_bits) - 1)) - << first; -} - /* Extract bits [last:first] from t. */ template constexpr Integral bit_extract(Integral x, int first, int last) { - return (x >> first) & bit_mask(0, last - first); + static_assert(std::is_integral::value, "Integral type required"); + + auto&& bit_mask = [](int _first, int _last) { + ROCP_FATAL_IF(!(_last >= _first)) << fmt::format( + "[queue::bit_extract::bit_mask] -> invalid argument. last (={}) is not >= first (={})", + _last, + _first); + + size_t num_bits = _last - _first + 1; + return ((num_bits >= sizeof(Integral) * 8) ? ~Integral{0} + /* num_bits exceed the size of Integral */ + : ((Integral{1} << num_bits) - 1)) + << _first; + }; + + return (x >> first) & bit_mask(0, last - first); } /** @@ -213,14 +302,17 @@ WriteInterceptor(const void* packets, return; } - using callback_record_t = Queue::queue_info_session_t::callback_record_t; + ROCP_INFO << fmt::format("WriteInterceptor called with pkt_count={}", pkt_count); + + using callback_record_t = packet_data_t::callback_record_t; + using packet_vector_t = common::container::small_vector; // unique sequence id for the dispatch static auto sequence_counter = std::atomic{0}; - auto&& CreateBarrierPacket = [](hsa_signal_t* dependency_signal, - hsa_signal_t* completion_signal, - std::vector& _packets) { + auto&& CreateBarrierPacket = [](hsa_signal_t* dependency_signal, + hsa_signal_t* completion_signal, + packet_vector_t& _packets) { hsa_barrier_and_packet_t barrier{}; barrier.header = HSA_PACKET_TYPE_BARRIER_AND << HSA_PACKET_HEADER_TYPE; barrier.header |= 1 << HSA_PACKET_HEADER_BARRIER; @@ -241,21 +333,91 @@ WriteInterceptor(const void* packets, return; } - auto tracing_data_v = tracing::tracing_data{}; - tracing::populate_contexts(ROCPROFILER_CALLBACK_TRACING_KERNEL_DISPATCH, - ROCPROFILER_BUFFER_TRACING_KERNEL_DISPATCH, - tracing_data_v); + const auto* packets_arr = static_cast(packets); + auto num_dispatch_packets = size_t{0}; + for(size_t i = 0; i < pkt_count; ++i) + { + const auto& original_packet = packets_arr[i].kernel_dispatch; + auto packet_type = bit_extract(original_packet.header, + HSA_PACKET_HEADER_TYPE, + HSA_PACKET_HEADER_TYPE + HSA_PACKET_HEADER_WIDTH_TYPE - 1); + if(packet_type == HSA_PACKET_TYPE_KERNEL_DISPATCH) + { + ++num_dispatch_packets; + } + } + + if(num_dispatch_packets == 0) + { + writer(packets, pkt_count); + return; + } + // these are for the services (dispatch counter collection, pc sampling, ATT) which use // the queue/queue_controller callback mechanism const auto queue_callback_context_filter = [](const context::context* ctx) { return (ctx->counter_collection || ctx->pc_sampler || ctx->dispatch_thread_trace); }; + auto tracing_data_v = tracing::tracing_data{}; + tracing::populate_contexts(ROCPROFILER_CALLBACK_TRACING_KERNEL_DISPATCH, + ROCPROFILER_BUFFER_TRACING_KERNEL_DISPATCH, + tracing_data_v); + for(const auto* itr : context::get_active_contexts(queue_callback_context_filter)) tracing_data_v.external_correlation_ids.emplace(itr, tracing::empty_user_data); - const auto* packets_arr = static_cast(packets); - auto transformed_packets = std::vector{}; + auto transformed_packets = packet_vector_t{}; + + // mark the queue as having at least one packet which will be assigned a callback to + // AsyncSignalHandler. This is used to determine whether we need to wait for the signal handler + // to complete during finalization. + queue.async_started(); + + // all packets should have the same correlation id so we can just look at the first one to get + // the correlation id for the entire batch of packets + auto* corr_id = context::get_latest_correlation_id(); + context::correlation_id* _corr_id_pop = nullptr; + + // Allocate a correlation id if we have at least one dispatch packet and we don't have a + // correlation id already. There will not be a correlation id if there is no API tracing but it + // was requested by tools to always provide one. + if(!corr_id) + { + constexpr auto ref_count = 1; + corr_id = context::correlation_tracing_service::construct(ref_count); + _corr_id_pop = corr_id; + } + + // During finalization, correlation tracing service will not construct a correlation id so just + // write packet through without tracing + if(!corr_id) + { + writer(packets, pkt_count); + return; + } + + // if we constructed a correlation id, this decrements the reference count after the + // underlying function returns + auto _corr_id_dtor = common::scope_destructor{[_corr_id_pop]() { + if(_corr_id_pop) + { + context::pop_latest_correlation_id(_corr_id_pop); + _corr_id_pop->sub_ref_count(); + } + }}; + + auto thr_id = (corr_id) ? corr_id->thread_idx : common::get_tid(); + auto internal_corr_id = (corr_id) ? corr_id->internal : 0; + auto ancestor_corr_id = (corr_id) ? corr_id->ancestor : 0; + + using packet_data_array_t = queue_info_session_t::packet_data_array_t; + + auto _info_session = queue_info_session_t{.queue = queue, + .tid = thr_id, + .enqueue_ts = common::timestamp_ns(), + .correlation_id = corr_id, + .packet_data = packet_data_array_t{}}; // Searching accross all the packets given during this write for(size_t i = 0; i < pkt_count; ++i) @@ -270,61 +432,36 @@ WriteInterceptor(const void* packets, continue; } - auto* corr_id = context::get_latest_correlation_id(); - context::correlation_id* _corr_id_pop = nullptr; - - if(!corr_id) - { - constexpr auto ref_count = 1; - corr_id = context::correlation_tracing_service::construct(ref_count); - _corr_id_pop = corr_id; - } - - if(!corr_id) - { - // During finalization - just write packet through without tracing - transformed_packets.emplace_back(packets_arr[i]); - continue; - } - // increase the reference count to denote that this correlation id is being used in a kernel corr_id->add_ref_count(); corr_id->add_kern_count(); - auto thr_id = (corr_id) ? corr_id->thread_idx : common::get_tid(); - auto user_data = rocprofiler_user_data_t{.value = 0}; - auto internal_corr_id = (corr_id) ? corr_id->internal : 0; - auto ancestor_corr_id = (corr_id) ? corr_id->ancestor : 0; + auto _packet_data = packet_data_t{}; - // if we constructed a correlation id, this decrements the reference count after the - // underlying function returns - auto _corr_id_dtor = common::scope_destructor{[_corr_id_pop]() { - if(_corr_id_pop) - { - context::pop_latest_correlation_id(_corr_id_pop); - _corr_id_pop->sub_ref_count(); - } - }}; + // make a copy of the tracing data + _packet_data.tracing_data = tracing_data_v; tracing::populate_external_correlation_ids( - tracing_data_v.external_correlation_ids, + _packet_data.tracing_data.external_correlation_ids, thr_id, ROCPROFILER_EXTERNAL_CORRELATION_REQUEST_KERNEL_DISPATCH, ROCPROFILER_KERNEL_DISPATCH_ENQUEUE, internal_corr_id); - queue.async_started(); - const auto original_completion_signal = original_packet.completion_signal; const bool existing_completion_signal = (original_completion_signal.handle != 0); const uint64_t kernel_id = code_object::get_kernel_id(original_packet.kernel_object); // Copy kernel pkt, copy is to allow for signal to be modified - rocprofiler_packet kernel_pkt = packets_arr[i]; + _packet_data.kernel_packet = packets_arr[i]; + // create a referencce for short hand access + auto& kernel_packet = _packet_data.kernel_packet; + // create our own signal that we can get a callback on. if there is an original completion // signal we will create a barrier packet, assign the original completion signal that that // barrier packet, and add it right after the kernel packet - queue.create_signal(0, &kernel_pkt.kernel_dispatch.completion_signal); + _packet_data.pooled_signal = + queue.create_signal(0, &kernel_packet.kernel_dispatch.completion_signal, true); // computes the "size" based on the offset of reserved_padding field constexpr auto kernel_dispatch_info_rt_size = @@ -333,8 +470,8 @@ WriteInterceptor(const void* packets, static_assert(kernel_dispatch_info_rt_size < sizeof(rocprofiler_kernel_dispatch_info_t), "failed to compute size field based on offset of reserved_padding field"); - auto dispatch_id = ++sequence_counter; - auto callback_record = callback_record_t{ + auto dispatch_id = ++sequence_counter; + _packet_data.callback_record = callback_record_t{ sizeof(callback_record_t), rocprofiler_timestamp_t{0}, rocprofiler_timestamp_t{0}, @@ -344,71 +481,75 @@ WriteInterceptor(const void* packets, .queue_id = queue.get_id(), .kernel_id = kernel_id, .dispatch_id = dispatch_id, - .private_segment_size = kernel_pkt.kernel_dispatch.private_segment_size, - .group_segment_size = kernel_pkt.kernel_dispatch.group_segment_size, - .workgroup_size = rocprofiler_dim3_t{kernel_pkt.kernel_dispatch.workgroup_size_x, - kernel_pkt.kernel_dispatch.workgroup_size_y, - kernel_pkt.kernel_dispatch.workgroup_size_z}, - .grid_size = rocprofiler_dim3_t{kernel_pkt.kernel_dispatch.grid_size_x, - kernel_pkt.kernel_dispatch.grid_size_y, - kernel_pkt.kernel_dispatch.grid_size_z}, + .private_segment_size = kernel_packet.kernel_dispatch.private_segment_size, + .group_segment_size = kernel_packet.kernel_dispatch.group_segment_size, + .workgroup_size = + rocprofiler_dim3_t{kernel_packet.kernel_dispatch.workgroup_size_x, + kernel_packet.kernel_dispatch.workgroup_size_y, + kernel_packet.kernel_dispatch.workgroup_size_z}, + .grid_size = rocprofiler_dim3_t{kernel_packet.kernel_dispatch.grid_size_x, + kernel_packet.kernel_dispatch.grid_size_y, + kernel_packet.kernel_dispatch.grid_size_z}, .reserved_padding = {0}}}; { - auto tracer_data = callback_record; - tracing::execute_phase_enter_callbacks(tracing_data_v.callback_contexts, - thr_id, - internal_corr_id, - tracing_data_v.external_correlation_ids, - ancestor_corr_id, - ROCPROFILER_CALLBACK_TRACING_KERNEL_DISPATCH, - ROCPROFILER_KERNEL_DISPATCH_ENQUEUE, - tracer_data); + auto tracer_data = _packet_data.callback_record; + tracing::execute_phase_enter_callbacks( + _packet_data.tracing_data.callback_contexts, + thr_id, + internal_corr_id, + _packet_data.tracing_data.external_correlation_ids, + ancestor_corr_id, + ROCPROFILER_CALLBACK_TRACING_KERNEL_DISPATCH, + ROCPROFILER_KERNEL_DISPATCH_ENQUEUE, + tracer_data); } // map all the external correlation ids (after enqueue enter phase) for all the contexts // captured by the info session tracing::update_external_correlation_ids( - tracing_data_v.external_correlation_ids, + _packet_data.tracing_data.external_correlation_ids, thr_id, ROCPROFILER_EXTERNAL_CORRELATION_REQUEST_KERNEL_DISPATCH); // Stores the instrumentation pkt (i.e. AQL packets for counter collection) // along with an ID of the client we got the packet from (this will be returned via // completed_cb_t) - auto inst_pkt = inst_pkt_t{}; - // True if any service (ATT,SPM,CC) requests this dispatch to be serialized - bool bRequest_Serialize = false; - - // Signal callbacks that a kernel_pkt is being enqueued + // Signal callbacks that a kernel_packet is being enqueued queue.signal_callback([&](const auto& map) { for(const auto& [client_id, cb_pair] : map) { - auto [packet, bSerial] = cb_pair.first(queue, - kernel_pkt, - kernel_id, - dispatch_id, - &user_data, - tracing_data_v.external_correlation_ids, - corr_id); - bRequest_Serialize |= bSerial; - if(packet) inst_pkt.push_back(std::make_pair(std::move(packet), client_id)); + // NOTE: if map.size() > 1, multiple callbacks will be sharing the same user data. + // This needs to be fixed. (bewelton) + auto [packet, bSerial] = + cb_pair.first(queue, + kernel_packet, + kernel_id, + dispatch_id, + &_packet_data.user_data, + _packet_data.tracing_data.external_correlation_ids, + corr_id); + _packet_data.is_serialized |= bSerial; + if(packet) + _packet_data.instrumentation_packets.push_back( + std::make_pair(std::move(packet), client_id)); } }); bool inserted_before = false; - if(bRequest_Serialize) + if(_packet_data.is_serialized) { inserted_before = true; CHECK_NOTNULL(hsa::get_queue_controller()) ->serializer(&queue) .rlock([&](const auto& serializer) { for(auto& s_pkt : serializer.kernel_dispatch(queue)) - transformed_packets.emplace_back(s_pkt.ext_amd_aql_pm4); + transformed_packets.emplace_back(s_pkt.kernel_dispatch); }); } - for(const auto& pkt_injection : inst_pkt) + + for(const auto& pkt_injection : _packet_data.instrumentation_packets) { for(const auto& pkt : pkt_injection.first->before_krn_pkt) { @@ -421,12 +562,13 @@ WriteInterceptor(const void* packets, if(pc_sampling::is_pc_sample_service_configured(queue.get_agent().get_rocp_agent()->id)) { transformed_packets.emplace_back(pc_sampling::hsa::generate_marker_packet_for_kernel( - corr_id, tracing_data_v.external_correlation_ids, dispatch_id)); + corr_id, _packet_data.tracing_data.external_correlation_ids, dispatch_id)); } #endif // emplace the kernel packet - transformed_packets.emplace_back(kernel_pkt); + transformed_packets.emplace_back(kernel_packet); + // If a profiling packet was inserted, wait for completion before executing the dispatch if(inserted_before) transformed_packets.back().kernel_dispatch.header |= 1 << HSA_PACKET_HEADER_BARRIER; @@ -442,7 +584,7 @@ WriteInterceptor(const void* packets, } bool injected_end_pkt = false; - for(const auto& pkt_injection : inst_pkt) + for(const auto& pkt_injection : _packet_data.instrumentation_packets) { for(const auto& pkt : pkt_injection.first->after_krn_pkt) { @@ -451,19 +593,19 @@ WriteInterceptor(const void* packets, } } - auto completion_signal = hsa_signal_t{.handle = 0}; - auto interrupt_signal = hsa_signal_t{.handle = 0}; + auto& completion_signal = _packet_data.completion_signal; + auto& interrupt_signal = _packet_data.interrupt_signal; if(injected_end_pkt) { // Adding a barrier packet with the original packet's completion signal. - queue.create_signal(0, &interrupt_signal); + queue.create_signal(0, &interrupt_signal, false); completion_signal = interrupt_signal; - transformed_packets.back().ext_amd_aql_pm4.completion_signal = interrupt_signal; + transformed_packets.back().kernel_dispatch.completion_signal = interrupt_signal; CreateBarrierPacket(&interrupt_signal, &interrupt_signal, transformed_packets); } else { - completion_signal = kernel_pkt.kernel_dispatch.completion_signal; + completion_signal = kernel_packet.kernel_dispatch.completion_signal; get_core_table()->hsa_signal_store_screlease_fn(completion_signal, 0); } @@ -474,32 +616,49 @@ WriteInterceptor(const void* packets, // signal completes. { - Queue::queue_info_session_t info_session{.queue = queue, - .inst_pkt = std::move(inst_pkt), - .interrupt_signal = interrupt_signal, - .tid = thr_id, - .enqueue_ts = common::timestamp_ns(), - .user_data = user_data, - .correlation_id = corr_id, - .kernel_pkt = kernel_pkt, - .callback_record = callback_record, - .tracing_data = tracing_data_v, - .is_serialized = bRequest_Serialize}; - - auto shared = std::make_shared(std::move(info_session)); - - queue.signal_async_handler(completion_signal, - new std::shared_ptr(shared)); - - auto tracer_data = callback_record; - tracing::execute_phase_exit_callbacks(tracing_data_v.callback_contexts, - tracing_data_v.external_correlation_ids, - ROCPROFILER_CALLBACK_TRACING_KERNEL_DISPATCH, - ROCPROFILER_KERNEL_DISPATCH_ENQUEUE, - tracer_data); + // auto info_session = info_session_t{.queue = queue, + // .inst_pkt = std::move(inst_pkt), + // .interrupt_signal = interrupt_signal, + // .tid = thr_id, + // .enqueue_ts = common::timestamp_ns(), + // .user_data = user_data, + // .correlation_id = corr_id, + // .kernel_packet = kernel_packet, + // .callback_record = callback_record, + // .tracing_data = tracing_data_v, + // .is_serialized = bRequest_Serialize}; + + _info_session.packet_data.emplace_back(std::move(_packet_data)); + + // auto shared = std::make_shared(std::move(info_session)); + + // queue.signal_async_handler( + // pooled_signal, completion_signal, new std::shared_ptr(shared)); + + auto tracer_data = _packet_data.callback_record; + tracing::execute_phase_exit_callbacks( + _packet_data.tracing_data.callback_contexts, + _packet_data.tracing_data.external_correlation_ids, + ROCPROFILER_CALLBACK_TRACING_KERNEL_DISPATCH, + ROCPROFILER_KERNEL_DISPATCH_ENQUEUE, + tracer_data); } } + using info_session_t = queue_info_session_t; + + if(!_info_session.packet_data.empty()) + { + auto* last_pooled_signal = _info_session.packet_data.back().pooled_signal; + auto last_completion_signal = _info_session.packet_data.back().completion_signal; + + auto shared = std::make_shared(std::move(_info_session)); + + queue.signal_async_handler(last_pooled_signal, + last_completion_signal, + new std::shared_ptr(shared)); + } + // Command is only executed if GLOG_v=2 or higher, otherwise it is a no-op ROCP_TRACE << fmt::format( "QueueID {}: {}", queue.get_id().handle, fmt::join(transformed_packets, fmt::format(" "))); @@ -556,7 +715,7 @@ Queue::Queue(const AgentCache& agent, aql::set_profiler_active_on_queue( _agent.cpu_pool(), _agent.get_hsa_agent(), [&](hsa::rocprofiler_packet pkt) { hsa_signal_t completion; - create_signal(0, &completion); + create_signal(0, &completion, false); pkt.ext_amd_aql_pm4.completion_signal = completion; counters::submitPacket(_intercept_queue, &pkt); constexpr auto timeout_hint = @@ -584,9 +743,9 @@ Queue::Queue(const AgentCache& agent, _ext_api.hsa_amd_queue_intercept_register_fn(_intercept_queue, WriteInterceptor, this)) << "Could not register interceptor"; - create_signal(0, &ready_signal); - create_signal(0, &block_signal); - create_signal(0, &_active_kernels); + create_signal(0, &ready_signal, false); + create_signal(0, &block_signal, false); + create_signal(0, &_active_kernels, false); _core_api.hsa_signal_store_screlease_fn(ready_signal, 0); _core_api.hsa_signal_store_screlease_fn(_active_kernels, 0); *queue = _intercept_queue; @@ -615,7 +774,7 @@ Queue::Queue( aql::set_profiler_active_on_queue( _agent.cpu_pool(), _agent.get_hsa_agent(), [&](hsa::rocprofiler_packet pkt) { hsa_signal_t completion; - create_signal(0, &completion); + create_signal(0, &completion, false); pkt.ext_amd_aql_pm4.completion_signal = completion; counters::submitPacket(_intercept_queue, &pkt); constexpr auto timeout_hint = @@ -634,9 +793,9 @@ Queue::Queue( set_write_interceptor(WriteInterceptor, this); - create_signal(0, &ready_signal); - create_signal(0, &block_signal); - create_signal(0, &_active_kernels); + create_signal(0, &ready_signal, false); + create_signal(0, &block_signal, false); + create_signal(0, &_active_kernels, false); _core_api.hsa_signal_store_screlease_fn(ready_signal, 0); _core_api.hsa_signal_store_screlease_fn(_active_kernels, 0); } @@ -648,27 +807,96 @@ Queue::~Queue() } void -Queue::signal_async_handler(const hsa_signal_t& signal, void* data) const +Queue::signal_async_handler(pooled_signal_t* signal, hsa_signal_t raw_signal, void* data) const { #if !defined(NDEBUG) CHECK_NOTNULL(hsa::get_queue_controller())->_debug_signals.wlock([&](auto& signals) { - signals[signal.handle] = signal; + signals[raw_signal.handle] = raw_signal; }); #endif - hsa_status_t status = _ext_api.hsa_amd_signal_async_handler_fn( - signal, HSA_SIGNAL_CONDITION_EQ, -1, AsyncSignalHandler, data); - ROCP_FATAL_IF(status != HSA_STATUS_SUCCESS && status != HSA_STATUS_INFO_BREAK) - << "Error: hsa_amd_signal_async_handler failed with error code " << status - << " :: " << hsa::get_hsa_status_string(status); + + if(signal) + { + ROCP_FATAL_IF(signal->get().value.handle != raw_signal.handle) + << fmt::format("signal handle does not match raw signal handle: {} vs {}", + signal->get().value.handle, + raw_signal.handle); + + ROCP_FATAL_IF(!signal->in_use()) + << fmt::format("pooled signal has not been acquired: hsa_signal_t(.handle={})", + signal->get().value.handle); + + // signal->get().data = data; + + // if(!signal->get().handler_is_set) + { + hsa_status_t status = _ext_api.hsa_amd_signal_async_handler_fn( + signal->get().value, HSA_SIGNAL_CONDITION_EQ, -1, AsyncSignalHandler, data); + ROCP_FATAL_IF(status != HSA_STATUS_SUCCESS && status != HSA_STATUS_INFO_BREAK) + << "Error: hsa_amd_signal_async_handler failed with error code " << status + << " :: " << hsa::get_hsa_status_string(status); + // signal->get().handler_is_set = true; + } + } + else + { + hsa_status_t status = _ext_api.hsa_amd_signal_async_handler_fn( + raw_signal, HSA_SIGNAL_CONDITION_EQ, -1, AsyncSignalHandler, data); + ROCP_FATAL_IF(status != HSA_STATUS_SUCCESS && status != HSA_STATUS_INFO_BREAK) + << "Error: hsa_amd_signal_async_handler failed with error code " << status + << " :: " << hsa::get_hsa_status_string(status); + } } -void -Queue::create_signal(uint32_t attribute, hsa_signal_t* signal) const +Queue::pooled_signal_t* +Queue::create_signal(uint32_t attribute, hsa_signal_t* signal, bool use_pool) { - hsa_status_t status = _ext_api.hsa_amd_signal_create_fn(1, 0, nullptr, attribute, signal); + if(auto* pool = get_signal_pool(); use_pool && pool && attribute == 0) + { + auto& _signal = pool->acquire(construct_hsa_signal, 0, 0, nullptr, attribute); + ROCP_FATAL_IF(!_signal.in_use()) << "Acquired signal from pool that is not in use"; + *signal = _signal.get().value; + // ROCP_INFO << fmt::format("acquired signal {} from pool: hsa_signal_t{{.handle={}}}", + // _signal.index(), + // _signal.get().value.handle); + get_core_table()->hsa_signal_store_screlease_fn(_signal.get().value, 1); + return &_signal; + } + + hsa_status_t status = + get_amd_ext_table()->hsa_amd_signal_create_fn(1, 0, nullptr, attribute, signal); ROCP_FATAL_IF(status != HSA_STATUS_SUCCESS && status != HSA_STATUS_INFO_BREAK) << "Error: hsa_amd_signal_create failed with error code " << status << " :: " << hsa::get_hsa_status_string(status); + + return nullptr; +} + +void +Queue::release_signal(pooled_signal_t* signal) +{ + if(signal && signal->in_use()) + { + // signal->get().data = nullptr; + ROCP_WARNING_IF(!signal->release()) + << fmt::format("Failed to release a pooled signal: hsa_signal_t{{.handle={}}}", + signal->get().value.handle); + ROCP_INFO << fmt::format("released signal {}: hsa_signal_t{{.handle={}}}", + signal->index(), + signal->get().value.handle); + } +} + +void +Queue::destroy_signal(pooled_signal_t* signal) +{ + release_signal(signal); + + if(signal && get_core_table() && get_core_table()->hsa_signal_destroy_fn) + { + get_core_table()->hsa_signal_destroy_fn(signal->get().value); + signal->get().value = null_hsa_signal; + } } void @@ -676,9 +904,16 @@ Queue::sync() const { if(_active_kernels.handle != 0u) { - _core_api.hsa_signal_wait_relaxed_fn( - _active_kernels, HSA_SIGNAL_CONDITION_EQ, 0, UINT64_MAX, HSA_WAIT_STATE_ACTIVE); + constexpr auto timeout_hint = + std::chrono::duration_cast(std::chrono::seconds{1}); + _core_api.hsa_signal_wait_relaxed_fn(_active_kernels, + HSA_SIGNAL_CONDITION_EQ, + 0, + timeout_hint.count(), + HSA_WAIT_STATE_BLOCKED); } + + if(get_signal_pool()) get_signal_pool()->report_reuse(); } void diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/queue.hpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/queue.hpp index 3c2230c6760..48def49c508 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/queue.hpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/queue.hpp @@ -22,11 +22,7 @@ #pragma once -#include -#include -#include -#include - +#include "lib/common/container/pool_object.hpp" #include "lib/common/container/small_vector.hpp" #include "lib/common/synchronized.hpp" #include "lib/rocprofiler-sdk/hsa/agent_cache.hpp" @@ -35,6 +31,11 @@ #include "lib/rocprofiler-sdk/hsa/rocprofiler_packet.hpp" #include "lib/rocprofiler-sdk/kernel_dispatch/profiling_time.hpp" +#include +#include +#include +#include + #include #include #include @@ -68,10 +69,9 @@ enum class queue_state class Queue { public: - using context_t = context::context; - using context_array_t = common::container::small_vector; - using callback_t = void (*)(hsa_status_t status, hsa_queue_t* source, void* data); - using queue_info_session_t = queue_info_session; + using context_t = context::context; + using context_array_t = common::container::small_vector; + using callback_t = void (*)(hsa_status_t status, hsa_queue_t* source, void* data); struct pkt_and_serialize_t { @@ -79,6 +79,8 @@ class Queue bool request_serialize{false}; }; + using pooled_signal_t = common::container::pool_object; + // Function prototype used to notify consumers that a kernel has been enqueued. // Pair first: An AQL packet can be returned that will be injected into the queue. // Pair second: Boolean flag indicating the dispatch needs to be serialized. @@ -93,7 +95,8 @@ class Queue // Signals the completion of the kernel packet. using completed_cb_t = std::function&, + std::shared_ptr&, + packet_data_t&, inst_pkt_t&, kernel_dispatch::profiling_time)>; using callback_map_t = std::unordered_map>; @@ -126,8 +129,10 @@ class Queue const hsa_queue_t* intercept_queue() const { return _intercept_queue; }; virtual const AgentCache& get_agent() const { return _agent; } - void create_signal(uint32_t attribute, hsa_signal_t* signal) const; - void signal_async_handler(const hsa_signal_t& signal, void* data) const; + void signal_async_handler(pooled_signal_t* _signal, hsa_signal_t raw_signal, void* data) const; + static pooled_signal_t* create_signal(uint32_t attribute, hsa_signal_t* _signal, bool use_pool); + static void release_signal(pooled_signal_t* signal); + static void destroy_signal(pooled_signal_t* signal); template void signal_callback(FuncT&& func) const; @@ -154,14 +159,15 @@ class Queue void register_callback(ClientID id, queue_cb_t enqueue_cb, completed_cb_t complete_cb); void remove_callback(ClientID id); - const CoreApiTable& core_api() const { return _core_api; } - const AmdExtTable& ext_api() const { return _ext_api; } - mutable std::mutex cv_mutex; - mutable std::condition_variable cv_ready_signal; - hsa_signal_t block_signal; - hsa_signal_t ready_signal; - queue_state get_state() const; - void set_state(queue_state state); + const CoreApiTable& core_api() const { return _core_api; } + const AmdExtTable& ext_api() const { return _ext_api; } + queue_state get_state() const; + void set_state(queue_state state); + + mutable std::mutex cv_mutex = {}; + mutable std::condition_variable cv_ready_signal = {}; + hsa_signal_t block_signal = {.handle = 0}; + hsa_signal_t ready_signal = {.handle = 0}; private: std::atomic _notifiers = {0}; @@ -172,8 +178,8 @@ class Queue common::Synchronized _callbacks = {}; hsa_queue_t* _intercept_queue = nullptr; queue_state _state = queue_state::normal; - std::mutex _lock_queue; - hsa_signal_t _active_kernels = {.handle = 0}; + std::mutex _lock_queue = {}; + hsa_signal_t _active_kernels = {.handle = 0}; }; inline rocprofiler_queue_id_t diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/queue_info_session.hpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/queue_info_session.hpp index 9cc02590687..d311af26a91 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/queue_info_session.hpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/queue_info_session.hpp @@ -22,14 +22,16 @@ #pragma once -#include -#include - +#include "lib/common/container/pool_object.hpp" #include "lib/common/container/small_vector.hpp" #include "lib/common/utility.hpp" #include "lib/rocprofiler-sdk/hsa/rocprofiler_packet.hpp" +#include "lib/rocprofiler-sdk/hsa/signal.hpp" #include "lib/rocprofiler-sdk/tracing/fwd.hpp" +#include +#include + #include namespace rocprofiler @@ -44,27 +46,38 @@ namespace hsa { class Queue; +// per-packet data +struct packet_data_t +{ + using callback_record_t = rocprofiler_callback_tracing_kernel_dispatch_data_t; + using pooled_signal_t = common::container::pool_object; + + tracing::tracing_data tracing_data = {}; + rocprofiler_packet kernel_packet = {}; + inst_pkt_t instrumentation_packets = {}; + hsa_signal_t completion_signal = {.handle = 0}; + hsa_signal_t interrupt_signal = {.handle = 0}; + callback_record_t callback_record = {}; + rocprofiler_user_data_t user_data = {.value = 0}; + pooled_signal_t* pooled_signal = nullptr; + bool is_serialized = false; +}; + // Internal session information that is used by write interceptor // to track state of the intercepted kernel. -struct queue_info_session +struct queue_info_session_t { using context_t = context::context; - using user_data_map_t = std::unordered_map; + using user_data_map_t = tracing::external_correlation_id_map_t; using external_corr_id_map_t = user_data_map_t; - using callback_record_t = rocprofiler_callback_tracing_kernel_dispatch_data_t; using context_array_t = common::container::small_vector; + using packet_data_array_t = common::container::small_vector; Queue& queue; - inst_pkt_t inst_pkt = {}; - hsa_signal_t interrupt_signal = {}; - rocprofiler_thread_id_t tid = common::get_tid(); - rocprofiler_timestamp_t enqueue_ts = 0; - rocprofiler_user_data_t user_data = {.value = 0}; - context::correlation_id* correlation_id = nullptr; - rocprofiler_packet kernel_pkt = {}; - callback_record_t callback_record = {}; - tracing::tracing_data tracing_data = {}; - bool is_serialized = false; + rocprofiler_thread_id_t tid = common::get_tid(); + rocprofiler_timestamp_t enqueue_ts = 0; + context::correlation_id* correlation_id = nullptr; + packet_data_array_t packet_data = {}; }; } // namespace hsa } // namespace rocprofiler diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/signal.hpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/signal.hpp new file mode 100644 index 00000000000..ec6004da95d --- /dev/null +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/hsa/signal.hpp @@ -0,0 +1,39 @@ +// MIT License +// +// Copyright (c) 2023-2025 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include + +namespace rocprofiler +{ +namespace hsa +{ +// pair of hsa signal and user data pointer for async handler +struct signal_t +{ + // bool handler_is_set = false; + hsa_signal_t value = {.handle = 0}; + // void* data = nullptr; +}; +} // namespace hsa +} // namespace rocprofiler diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/kernel_dispatch/tracing.cpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/kernel_dispatch/tracing.cpp index e05a51e7542..77dcac7610b 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/kernel_dispatch/tracing.cpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/kernel_dispatch/tracing.cpp @@ -41,19 +41,13 @@ namespace rocprofiler { namespace kernel_dispatch { -namespace -{ -using queue_info_session_t = hsa::queue_info_session; -using kernel_dispatch_record_t = rocprofiler_buffer_tracing_kernel_dispatch_record_t; -} // namespace - profiling_time -get_dispatch_time(const hsa::queue_info_session& session) +get_dispatch_time(const queue_info_session_t& session, packet_data_t& packet_data) { - const auto& callback_record = session.callback_record; + const auto& callback_record = packet_data.callback_record; const auto* _rocp_agent = agent::get_agent(callback_record.dispatch_info.agent_id); auto _hsa_agent = agent::get_hsa_agent(_rocp_agent); - auto _signal = session.kernel_pkt.kernel_dispatch.completion_signal; + auto _signal = packet_data.kernel_packet.kernel_dispatch.completion_signal; auto _kern_id = callback_record.dispatch_info.kernel_id; return (_hsa_agent) ? get_dispatch_time(*_hsa_agent, _signal, _kern_id, session.enqueue_ts) @@ -61,18 +55,22 @@ get_dispatch_time(const hsa::queue_info_session& session) } void -dispatch_complete(queue_info_session_t& session, profiling_time dispatch_time) +dispatch_complete(queue_info_session_t& session, + packet_data_t& packet_data, + profiling_time dispatch_time) { + using kernel_dispatch_record_t = rocprofiler_buffer_tracing_kernel_dispatch_record_t; + // get the contexts that were active when the signal was created - auto& tracing_data_v = session.tracing_data; + auto& tracing_data_v = packet_data.tracing_data; if(tracing_data_v.callback_contexts.empty() && tracing_data_v.buffered_contexts.empty()) return; // we need to decrement this reference count at the end of the functions auto* _corr_id = session.correlation_id; // only do the following work if there are contexts that require this info - auto& callback_record = session.callback_record; - const auto& _extern_corr_ids = session.tracing_data.external_correlation_ids; + auto& callback_record = packet_data.callback_record; + const auto& _extern_corr_ids = packet_data.tracing_data.external_correlation_ids; auto _tid = session.tid; auto _internal_corr_id = (_corr_id) ? _corr_id->internal : 0; auto _ancestor_corr_id = (_corr_id) ? _corr_id->ancestor : 0; diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/kernel_dispatch/tracing.hpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/kernel_dispatch/tracing.hpp index 27f83c97489..ff5592d5d3d 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/kernel_dispatch/tracing.hpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/kernel_dispatch/tracing.hpp @@ -42,15 +42,16 @@ struct context; namespace kernel_dispatch { using context_t = context::context; -using user_data_map_t = std::unordered_map; +using user_data_map_t = tracing::external_correlation_id_map_t; using external_corr_id_map_t = user_data_map_t; - -using profiling_time = tracing::profiling_time; +using queue_info_session_t = hsa::queue_info_session_t; +using packet_data_t = hsa::packet_data_t; +using profiling_time = tracing::profiling_time; profiling_time -get_dispatch_time(const hsa::queue_info_session& session); +get_dispatch_time(const queue_info_session_t& session, packet_data_t& packet_data); void -dispatch_complete(hsa::queue_info_session&, profiling_time); +dispatch_complete(queue_info_session_t& session, packet_data_t& packet_data, profiling_time); } // namespace kernel_dispatch } // namespace rocprofiler diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/pc_sampling/hsa_adapter.cpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/pc_sampling/hsa_adapter.cpp index f9da46d0cb9..b5be20e61b3 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/pc_sampling/hsa_adapter.cpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/pc_sampling/hsa_adapter.cpp @@ -125,7 +125,7 @@ amd_intercept_marker_handler_callback(const struct amd_aql_intercept_marker_s* p void kernel_completion_cb(const rocprofiler_agent_t* rocp_agent, rocprofiler::hsa::rocprofiler_packet& /*kernel_pkt*/, - const rocprofiler::hsa::Queue::queue_info_session_t& session) + const rocprofiler::hsa::queue_info_session_t& session) { // No internal correlation IDs, meaning there is no need to call CID manager. if(!session.correlation_id) return; @@ -225,8 +225,8 @@ generate_marker_packet_for_kernel( // Get an external correlation that corresponds to the context // enclosing PC sampling service. - auto external_corr = tracing::empty_user_data; - auto external_corr_it = external_correlation_ids.find(pcs_context); + auto external_corr = tracing::empty_user_data; + const auto* external_corr_it = external_correlation_ids.find(pcs_context); if(external_corr_it != external_correlation_ids.end()) external_corr = external_corr_it->second; marker_pkt.user_data[1] = external_corr.value; @@ -349,14 +349,15 @@ pc_sampling_service_finish_configuration(context::pc_sampling_service* service) rocprofiler_kernel_id_t /*kernel_id*/, rocprofiler_dispatch_id_t /*dispatch_id*/, rocprofiler_user_data_t*, - const rocprofiler::hsa::Queue::queue_info_session_t::external_corr_id_map_t&, + const rocprofiler::hsa::queue_info_session_t::external_corr_id_map_t&, const context::correlation_id*) { return rocprofiler::hsa::Queue::pkt_and_serialize_t{}; }, // Completion CB - [](const rocprofiler::hsa::Queue& q, - rocprofiler::hsa::rocprofiler_packet kern_pkt, - std::shared_ptr& session, + [](const rocprofiler::hsa::Queue& q, + rocprofiler::hsa::rocprofiler_packet kern_pkt, + std::shared_ptr& session, + rocprofiler::hsa::packet_data_t& /*packet*/, rocprofiler::hsa::inst_pkt_t&, kernel_dispatch::profiling_time) { kernel_completion_cb(q.get_agent().get_rocp_agent(), kern_pkt, *session); diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/pc_sampling/tests/pc_sampling_internals.hpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/pc_sampling/tests/pc_sampling_internals.hpp index a7f68709ce0..f91e4430190 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/pc_sampling/tests/pc_sampling_internals.hpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/pc_sampling/tests/pc_sampling_internals.hpp @@ -42,11 +42,11 @@ amd_intercept_marker_handler_callback(const struct amd_aql_intercept_marker_s* p extern void kernel_completion_cb(const std::shared_ptr& info, - const rocprofiler_agent_t* rocp_agent, - rocprofiler::hsa::ClientID client_id, - const rocprofiler::hsa::rocprofiler_packet& kernel_pkt, - const rocprofiler::hsa::Queue::queue_info_session_t& session, - std::unique_ptr pkt); + const rocprofiler_agent_t* rocp_agent, + rocprofiler::hsa::ClientID client_id, + const rocprofiler::hsa::rocprofiler_packet& kernel_pkt, + const rocprofiler::hsa::queue_info_session_t& session, + std::unique_ptr pkt); extern void data_ready_callback(void* client_callback_data, diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/thread_trace/core.cpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/thread_trace/core.cpp index d24d6a177cc..a16ab661baf 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/thread_trace/core.cpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/thread_trace/core.cpp @@ -360,8 +360,9 @@ DispatchThreadTracer::pre_kernel_call(const hsa::Queue& queue, } void -DispatchThreadTracer::post_kernel_call(DispatchThreadTracer::inst_pkt_t& aql, - const hsa::Queue::queue_info_session_t& session) +DispatchThreadTracer::post_kernel_call(DispatchThreadTracer::inst_pkt_t& aql, + const hsa::queue_info_session_t& /*session*/, + const hsa::packet_data_t& packet_data) { if(post_move_data.load() < 1) return; @@ -377,14 +378,14 @@ DispatchThreadTracer::post_kernel_call(DispatchThreadTracer::inst_pkt_t& a auto it = agents.find(pkt->GetAgent()); if(it != agents.end() && it->second != nullptr) - it->second->iterate_data(pkt->GetHandle(), session.user_data); + it->second->iterate_data(pkt->GetHandle(), packet_data.user_data); } } void DispatchThreadTracer::start_context() { - using corr_id_map_t = hsa::Queue::queue_info_session_t::external_corr_id_map_t; + using corr_id_map_t = hsa::queue_info_session_t::external_corr_id_map_t; CHECK_NOTNULL(hsa::get_queue_controller())->enable_serialization(); @@ -407,10 +408,11 @@ DispatchThreadTracer::start_context() }, [=](const hsa::Queue& /* q */, hsa::rocprofiler_packet /* kern_pkt */, - std::shared_ptr& session, - inst_pkt_t& aql, + std::shared_ptr& session, + hsa::packet_data_t& packet_data, + inst_pkt_t& aql, kernel_dispatch::profiling_time) { - this->post_kernel_call(aql, *session); + this->post_kernel_call(aql, *session, packet_data); }); }); } diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/thread_trace/core.hpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/thread_trace/core.hpp index f27a874cfc2..2b12ed2f5a2 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/thread_trace/core.hpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/thread_trace/core.hpp @@ -150,7 +150,9 @@ class DispatchThreadTracer rocprofiler_user_data_t* user_data, const context::correlation_id* corr_id); - void post_kernel_call(inst_pkt_t& aql, const hsa::queue_info_session& session); + void post_kernel_call(inst_pkt_t& aql, + const hsa::queue_info_session_t& session, + const hsa::packet_data_t& packet_data); const auto& get_agents() const { return agents; } private: diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/tracing/fwd.hpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/tracing/fwd.hpp index 1a08194aa6f..03edd3e3962 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/tracing/fwd.hpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/tracing/fwd.hpp @@ -40,11 +40,12 @@ struct correlation_tracing_service; namespace tracing { template -using small_vector_t = common::container::small_vector; -using correlation_service = context::correlation_tracing_service; -using context_t = context::context; -using context_array_t = common::container::small_vector; -using external_correlation_id_map_t = std::unordered_map; +using small_vector_t = common::container::small_vector; +using correlation_service = context::correlation_tracing_service; +using context_t = context::context; +using context_array_t = common::container::small_vector; +using external_correlation_id_map_t = + common::container::small_vector, 8>; constexpr auto context_data_vec_size = 2; constexpr auto empty_user_data = rocprofiler_user_data_t{.value = 0}; diff --git a/projects/rocprofiler-sdk/tests/bin/CMakeLists.txt b/projects/rocprofiler-sdk/tests/bin/CMakeLists.txt index 6f447605477..c8a0598e9d8 100644 --- a/projects/rocprofiler-sdk/tests/bin/CMakeLists.txt +++ b/projects/rocprofiler-sdk/tests/bin/CMakeLists.txt @@ -45,3 +45,4 @@ add_subdirectory(attachment-test) add_subdirectory(hip-host) add_subdirectory(module-loading-test) add_subdirectory(late-start-tracing) +add_subdirectory(hip-graph-bubbles) diff --git a/projects/rocprofiler-sdk/tests/bin/hip-graph-bubbles/CMakeLists.txt b/projects/rocprofiler-sdk/tests/bin/hip-graph-bubbles/CMakeLists.txt new file mode 100644 index 00000000000..0bfc8dc111d --- /dev/null +++ b/projects/rocprofiler-sdk/tests/bin/hip-graph-bubbles/CMakeLists.txt @@ -0,0 +1,46 @@ +# +# +# +cmake_minimum_required(VERSION 3.21.0 FATAL_ERROR) + +if(NOT CMAKE_HIP_COMPILER) + find_program( + amdclangpp_EXECUTABLE + NAMES amdclang++ + HINTS ${ROCM_PATH} ENV ROCM_PATH /opt/rocm + PATHS ${ROCM_PATH} ENV ROCM_PATH /opt/rocm + PATH_SUFFIXES bin llvm/bin NO_CACHE) + mark_as_advanced(amdclangpp_EXECUTABLE) + + if(amdclangpp_EXECUTABLE) + set(CMAKE_HIP_COMPILER "${amdclangpp_EXECUTABLE}") + endif() +endif() + +project(rocprofiler-sdk-tests-bin-hip-graph-bubbles LANGUAGES CXX HIP) + +foreach(_TYPE DEBUG MINSIZEREL RELEASE RELWITHDEBINFO) + if("${CMAKE_HIP_FLAGS_${_TYPE}}" STREQUAL "") + set(CMAKE_HIP_FLAGS_${_TYPE} "${CMAKE_CXX_FLAGS_${_TYPE}}") + endif() +endforeach() + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_EXTENSIONS OFF) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_HIP_STANDARD 17) +set(CMAKE_HIP_EXTENSIONS OFF) +set(CMAKE_HIP_STANDARD_REQUIRED ON) + +set_source_files_properties(hip-graph-bubbles.cpp PROPERTIES LANGUAGE HIP) +add_executable(hip-graph-bubbles) +target_sources(hip-graph-bubbles PRIVATE hip-graph-bubbles.cpp) +target_compile_options(hip-graph-bubbles PRIVATE -W -Wall -Wextra -Wpedantic -Wshadow + -Werror) + +find_package(Threads REQUIRED) +target_link_libraries(hip-graph-bubbles PRIVATE Threads::Threads) + +find_package(rocprofiler-sdk-roctx REQUIRED) +target_link_libraries(hip-graph-bubbles + PRIVATE rocprofiler-sdk-roctx::rocprofiler-sdk-roctx) diff --git a/projects/rocprofiler-sdk/tests/bin/hip-graph-bubbles/hip-graph-bubbles.cpp b/projects/rocprofiler-sdk/tests/bin/hip-graph-bubbles/hip-graph-bubbles.cpp new file mode 100644 index 00000000000..52c280cd772 --- /dev/null +++ b/projects/rocprofiler-sdk/tests/bin/hip-graph-bubbles/hip-graph-bubbles.cpp @@ -0,0 +1,143 @@ +/* +Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#include +#include + +#include +#include +#include +#include + +#define HIP_CHECK(cmd) \ + { \ + hipError_t error = cmd; \ + if(error != hipSuccess) \ + { \ + std::cerr << "HIP error: " << hipGetErrorString(error) << " at " << __FILE__ << ":" \ + << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +// Simple kernel that does minimal work +__global__ void +simpleKernel(int* data, int value) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + data[idx] = value + idx; +} + +int +main() +{ + const int NUM_KERNELS = 2000; + const int NUM_ITERATIONS = 200; + const int ARRAY_SIZE = 256; + + std::cout << "Creating HIP graph with " << NUM_KERNELS << " kernel launches" << std::endl; + std::cout << "Will execute graph " << NUM_ITERATIONS << " times" << std::endl; + + // Allocate device memory + int* d_data; + HIP_CHECK(hipMalloc(&d_data, ARRAY_SIZE * sizeof(int))); + + // Create graph + hipGraph_t graph; + HIP_CHECK(hipGraphCreate(&graph, 0)); + + // Create stream for graph capture + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + // Begin graph capture + HIP_CHECK(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal)); + + // Launch many kernels + dim3 blockSize(256); + dim3 gridSize(1); + + for(int i = 0; i < NUM_KERNELS; i++) + { + hipLaunchKernelGGL(simpleKernel, gridSize, blockSize, 0, stream, d_data, i); + } + + // End graph capture + HIP_CHECK(hipStreamEndCapture(stream, &graph)); + + // Create executable graph + hipGraphExec_t graphExec; + HIP_CHECK(hipGraphInstantiate(&graphExec, graph, nullptr, nullptr, 0)); + + std::cout << "Graph created and instantiated successfully" << std::endl; + std::cout << "Starting graph execution loop..." << std::endl; + + // Start timing + auto start = std::chrono::high_resolution_clock::now(); + + // Execute the graph multiple times + for(int iter = 0; iter < NUM_ITERATIONS; iter++) + { + roctxRangePush("graph_launch"); + HIP_CHECK(hipGraphLaunch(graphExec, stream)); + roctxRangePop(); + + if((iter + 1) % 50 == 0) + { + std::cout << "Completed " << (iter + 1) << " iterations" << std::endl; + } + + // Wait for completion + // if(iter % 5 == 4) + HIP_CHECK(hipStreamSynchronize(stream)); + + // std::cout << "Synchronized after iteration " << (iter + 1) << std::endl; + // std::this_thread::sleep_for(std::chrono::milliseconds(100)); // Sleep to simulate work + // between iterations + } + + // Wait for completion + HIP_CHECK(hipStreamSynchronize(stream)); + + // End timing + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration elapsed = end - start; + + std::cout << "All iterations completed successfully!" << std::endl; + std::cout << std::fixed << std::setprecision(4); + std::cout << "\n=== Timing Results ===" << std::endl; + std::cout << "Total execution time: " << elapsed.count() << " seconds" << std::endl; + std::cout << "Total kernel launches: " << (NUM_KERNELS * NUM_ITERATIONS) << std::endl; + std::cout << "Average time per iteration: " << (elapsed.count() / NUM_ITERATIONS) << " seconds" + << std::endl; + std::cout << "======================" << std::endl; + + // Cleanup + HIP_CHECK(hipGraphExecDestroy(graphExec)); + HIP_CHECK(hipGraphDestroy(graph)); + HIP_CHECK(hipStreamDestroy(stream)); + HIP_CHECK(hipFree(d_data)); + + std::cout << "Test completed successfully" << std::endl; + + return 0; +}