@@ -109,28 +109,41 @@ public:
109109 const ucp_perf_cuda_params &get_params () const { return m_params; }
110110
111111private:
112+ static bool has_counter (const ucx_perf_context_t &perf)
113+ {
114+ return (perf.params .command != UCX_PERF_CMD_PUT_SINGLE);
115+ }
116+
112117 void init_mem_list (const ucx_perf_context_t &perf)
113118 {
114- /* +1 for the counter */
115- size_t count = perf. params . msg_size_cnt + 1 ;
116- size_t offset = 0 ;
119+ size_t data_count = perf. params . msg_size_cnt ;
120+ size_t count = data_count + ( has_counter (perf) ? 1 : 0 ) ;
121+ size_t offset = 0 ;
117122 ucp_device_mem_list_elem_t elems[count];
118123
119- for (size_t i = 0 ; i < count ; ++i) {
120- elems[i].field_mask = UCP_DEVICE_MEM_LIST_ELEM_FIELD_MEMH |
121- UCP_DEVICE_MEM_LIST_ELEM_FIELD_RKEY |
122- UCP_DEVICE_MEM_LIST_ELEM_FIELD_LOCAL_ADDR |
123- UCP_DEVICE_MEM_LIST_ELEM_FIELD_REMOTE_ADDR |
124- UCP_DEVICE_MEM_LIST_ELEM_FIELD_LENGTH;
125- elems[i].memh = perf.ucp .send_memh ;
126- elems[i].rkey = perf.ucp .rkey ;
127- elems[i].local_addr = UCS_PTR_BYTE_OFFSET (perf.send_buffer , offset);
124+ for (size_t i = 0 ; i < data_count ; ++i) {
125+ elems[i].field_mask = UCP_DEVICE_MEM_LIST_ELEM_FIELD_MEMH |
126+ UCP_DEVICE_MEM_LIST_ELEM_FIELD_RKEY |
127+ UCP_DEVICE_MEM_LIST_ELEM_FIELD_LOCAL_ADDR |
128+ UCP_DEVICE_MEM_LIST_ELEM_FIELD_REMOTE_ADDR |
129+ UCP_DEVICE_MEM_LIST_ELEM_FIELD_LENGTH;
130+ elems[i].memh = perf.ucp .send_memh ;
131+ elems[i].rkey = perf.ucp .rkey ;
132+ elems[i].local_addr = UCS_PTR_BYTE_OFFSET (perf.send_buffer , offset);
128133 elems[i].remote_addr = perf.ucp .remote_addr + offset;
129- elems[i].length = (i == count - 1 ) ? ONESIDED_SIGNAL_SIZE :
130- perf.params .msg_size_list [i];
134+ elems[i].length = perf.params .msg_size_list [i];
131135 offset += elems[i].length ;
132136 }
133137
138+ if (has_counter (perf)) {
139+ elems[data_count].field_mask = UCP_DEVICE_MEM_LIST_ELEM_FIELD_RKEY |
140+ UCP_DEVICE_MEM_LIST_ELEM_FIELD_REMOTE_ADDR |
141+ UCP_DEVICE_MEM_LIST_ELEM_FIELD_LENGTH;
142+ elems[data_count].rkey = perf.ucp .rkey ;
143+ elems[data_count].remote_addr = perf.ucp .remote_addr + offset;
144+ elems[data_count].length = ONESIDED_SIGNAL_SIZE;
145+ }
146+
134147 ucp_device_mem_list_params_t params;
135148 params.field_mask = UCP_DEVICE_MEM_LIST_PARAMS_FIELD_ELEMENTS |
136149 UCP_DEVICE_MEM_LIST_PARAMS_FIELD_ELEMENT_SIZE |
@@ -148,20 +161,22 @@ private:
148161
149162 void init_elements (const ucx_perf_context_t &perf)
150163 {
151- /* +1 for the counter */
152- size_t count = perf.params .msg_size_cnt + 1 ;
153- size_t offset = 0 ;
164+ size_t data_count = perf.params .msg_size_cnt ;
165+ size_t count = data_count + (has_counter (perf) ? 1 : 0 );
154166
155167 std::vector<unsigned > indices (count);
156168 std::vector<size_t > local_offsets (count, 0 );
157169 std::vector<size_t > remote_offsets (count, 0 );
158170 std::vector<size_t > lengths (count);
159171
160- for (unsigned i = 0 ; i < count ; ++i) {
172+ for (unsigned i = 0 ; i < data_count ; ++i) {
161173 indices[i] = i;
162- lengths[i] = (i == count - 1 ) ? ONESIDED_SIGNAL_SIZE :
163- perf.params .msg_size_list [i];
164- offset += lengths[i];
174+ lengths[i] = perf.params .msg_size_list [i];
175+ }
176+
177+ if (has_counter (perf)) {
178+ indices[data_count] = data_count;
179+ lengths[data_count] = ONESIDED_SIGNAL_SIZE;
165180 }
166181
167182 device_clone (&m_params.indices , indices.data (), count);
0 commit comments