@@ -43,6 +43,7 @@ public:
4343 m_ctx (ctx),
4444 m_max_iters (ctx.max_iters),
4545 m_next_report_iter (1 ),
46+ m_last_completed (0 ),
4647 m_last_report_time (ucx_perf_cuda_get_time_ns()),
4748 m_report_interval_ns (ctx.report_interval_ns / UPDATES_PER_INTERVAL)
4849 {
@@ -51,14 +52,15 @@ public:
5152 __device__ inline void
5253 update_report (ucx_perf_counter_t completed)
5354 {
54- if ((threadIdx .x == 0 ) && ucs_unlikely (completed >= m_next_report_iter)) {
55- assert (completed - m_ctx. completed_iters > 0 );
55+ if ((blockIdx . x == 0 ) && ( threadIdx .x == 0 ) && ucs_unlikely (completed >= m_next_report_iter)) {
56+ assert (completed - m_last_completed > 0 );
5657 ucx_perf_cuda_time_t cur_time = ucx_perf_cuda_get_time_ns ();
5758 ucx_perf_cuda_time_t iter_time = (cur_time - m_last_report_time) /
58- (completed - m_ctx. completed_iters );
59+ (completed - m_last_completed );
5960 assert (iter_time > 0 );
61+ m_last_completed = completed;
6062 m_last_report_time = cur_time;
61- m_ctx.completed_iters = completed;
63+ m_ctx.completed_iters = completed * gridDim . x ;
6264 __threadfence_system ();
6365
6466 m_next_report_iter = ucs_min (completed + (m_report_interval_ns / iter_time),
@@ -70,6 +72,7 @@ private:
7072 ucx_perf_cuda_context &m_ctx;
7173 ucx_perf_counter_t m_max_iters;
7274 ucx_perf_counter_t m_next_report_iter;
75+ ucx_perf_counter_t m_last_completed;
7376 ucx_perf_cuda_time_t m_last_report_time;
7477 ucx_perf_cuda_time_t m_report_interval_ns;
7578};
@@ -179,6 +182,7 @@ public:
179182 ucx_perf_counter_t last_completed = 0 ;
180183 ucx_perf_counter_t completed = m_cpu_ctx->completed_iters ;
181184 unsigned thread_count = m_perf.params .device_thread_count ;
185+ unsigned block_count = m_perf.params .device_block_count ;
182186 ucs_device_level_t level = m_perf.params .device_level ;
183187 unsigned msgs_per_iter;
184188 UCX_PERF_SWITCH_LEVEL (level, UCX_PERF_THREAD_INDEX_SET, thread_count,
@@ -189,7 +193,7 @@ public:
189193 if (delta > 0 ) {
190194 // TODO: calculate latency percentile on kernel
191195 ucx_perf_update (&m_perf, delta, delta * msgs_per_iter, msg_length);
192- } else if (completed >= m_perf.max_iter ) {
196+ } else if (completed >= ( m_perf.max_iter * block_count) ) {
193197 break ;
194198 }
195199 last_completed = completed;
0 commit comments