@@ -139,14 +139,19 @@ public:
139139 const ucp_perf_cuda_params &get_params () const { return m_params; }
140140
141141private:
142+ static bool has_counter (const ucx_perf_context_t &perf)
143+ {
144+ return (perf.params .command != UCX_PERF_CMD_PUT_SINGLE);
145+ }
146+
142147 void init_mem_list (const ucx_perf_context_t &perf)
143148 {
144- /* +1 for the counter */
145- size_t count = perf. params . msg_size_cnt + 1 ;
146- size_t offset = 0 ;
149+ size_t data_count = perf. params . msg_size_cnt ;
150+ size_t count = data_count + ( has_counter (perf) ? 1 : 0 ) ;
151+ size_t offset = 0 ;
147152 ucp_device_mem_list_elem_t elems[count];
148153
149- for (size_t i = 0 ; i < count ; ++i) {
154+ for (size_t i = 0 ; i < data_count ; ++i) {
150155 elems[i].field_mask = UCP_DEVICE_MEM_LIST_ELEM_FIELD_MEMH |
151156 UCP_DEVICE_MEM_LIST_ELEM_FIELD_RKEY |
152157 UCP_DEVICE_MEM_LIST_ELEM_FIELD_LOCAL_ADDR |
@@ -156,11 +161,19 @@ private:
156161 elems[i].rkey = perf.ucp .rkey ;
157162 elems[i].local_addr = UCS_PTR_BYTE_OFFSET (perf.send_buffer , offset);
158163 elems[i].remote_addr = perf.ucp .remote_addr + offset;
159- elems[i].length = (i == count - 1 ) ? ONESIDED_SIGNAL_SIZE :
160- perf.params .msg_size_list [i];
164+ elems[i].length = perf.params .msg_size_list [i];
161165 offset += elems[i].length ;
162166 }
163167
168+ if (has_counter (perf)) {
169+ elems[data_count].field_mask = UCP_DEVICE_MEM_LIST_ELEM_FIELD_RKEY |
170+ UCP_DEVICE_MEM_LIST_ELEM_FIELD_REMOTE_ADDR |
171+ UCP_DEVICE_MEM_LIST_ELEM_FIELD_LENGTH;
172+ elems[data_count].rkey = perf.ucp .rkey ;
173+ elems[data_count].remote_addr = perf.ucp .remote_addr + offset;
174+ elems[data_count].length = ONESIDED_SIGNAL_SIZE;
175+ }
176+
164177 ucp_device_mem_list_params_t params;
165178 params.field_mask = UCP_DEVICE_MEM_LIST_PARAMS_FIELD_ELEMENTS |
166179 UCP_DEVICE_MEM_LIST_PARAMS_FIELD_ELEMENT_SIZE |
@@ -178,18 +191,22 @@ private:
178191
179192 void init_elements (const ucx_perf_context_t &perf)
180193 {
181- /* +1 for the counter */
182- size_t count = perf. params . msg_size_cnt + 1 ;
194+ size_t data_count = perf. params . msg_size_cnt ;
195+ size_t count = data_count + ( has_counter (perf) ? 1 : 0 ) ;
183196
184197 std::vector<unsigned > indices (count);
185198 std::vector<size_t > local_offsets (count, 0 );
186199 std::vector<size_t > remote_offsets (count, 0 );
187200 std::vector<size_t > lengths (count);
188201
189- for (unsigned i = 0 ; i < count ; ++i) {
202+ for (unsigned i = 0 ; i < data_count ; ++i) {
190203 indices[i] = i;
191- lengths[i] = (i == count - 1 ) ? ONESIDED_SIGNAL_SIZE :
192- perf.params .msg_size_list [i];
204+ lengths[i] = perf.params .msg_size_list [i];
205+ }
206+
207+ if (has_counter (perf)) {
208+ indices[data_count] = data_count;
209+ lengths[data_count] = ONESIDED_SIGNAL_SIZE;
193210 }
194211
195212 m_params.indices = device_vector (indices);
0 commit comments