@@ -165,7 +165,10 @@ namespace ReSolve
165165 * @param[in] memspace String containg memspace (cpu or cuda or hip)
166166 *
167167 */
168- void VectorHandler::axpy (const real_type alpha, /* const */ vector::Vector* x, vector::Vector* y, memory::MemorySpace memspace)
168+ void VectorHandler::axpy (const real_type alpha,
169+ /* const */ vector::Vector* x,
170+ vector::Vector* y,
171+ memory::MemorySpace memspace)
169172 {
170173 // AXPY: y = alpha * x + y
171174 using namespace ReSolve ::memory;
@@ -182,8 +185,16 @@ namespace ReSolve
182185 }
183186
184187 /* *
185- * @brief gemv computes matrix-vector product where both matrix and vectors are dense.
186- * i.e., x = beta*x + alpha*V*y
188+ * @brief gemv computes dense matrix-vector product.
189+ *
190+ * In Re::Solve applications, gemv is used to compute dot products of
191+ * multivectors.
192+ *
193+ * If `transpose = N` (no), `x := beta*x + alpha*V*y`,
194+ * where `x` is `[n x 1]`, `V` is `[n x k]` and `y` is `[k x 1]`.
195+ * If `transpose = T` (yes), `x := beta*x + alpha*V^T*y`,
196+ * where `x` is `[k x 1]`, `V` is `[n x k]` and `y` is `[n x 1]`.
197+ *
187198 *
188199 * @param[in] Transpose - yes (T) or no (N)
189200 * @param[in] n - Number of rows in (non-transposed) matrix
@@ -198,26 +209,24 @@ namespace ReSolve
198209 * @pre V is stored colum-wise, _n_ > 0, _k_ > 0
199210 *
200211 */
201- void VectorHandler::gemv (char transpose,
202- index_type n,
203- index_type k,
204- const real_type alpha,
205- const real_type beta,
206- vector::Vector* V,
207- vector::Vector* y,
208- vector::Vector* x,
212+ void VectorHandler::gemv (char transpose,
213+ index_type n,
214+ index_type k,
215+ const real_type alpha,
216+ const real_type beta,
217+ vector::Vector* V,
218+ vector::Vector* y,
219+ vector::Vector* x,
209220 memory::MemorySpace memspace)
210221 {
211222 using namespace ReSolve ::memory;
212223
213224 // TODO: Remove n as the argument becuase it must always be n = V->getSize()
214225 assert (n == V->getSize () && " gemv: n does not match the number of rows in V" );
215226
216- assert ((transpose == ' T' && V->getSize () == y->getSize ()) &&
217- " gemv: size mismatch! size of V^T does not match size of y" );
218-
219- assert ((transpose == ' N' && V->getSize () == x->getSize ()) &&
220- " gemv: size mismatch! size of V does not match size of x" );
227+ // TODO: These assertions should hold but they do not. Investigate more.
228+ // assert((transpose == 'T' && V->getSize() == y->getSize()) && "gemv: size mismatch! size of V^T does not match size of y");
229+ // assert(((transpose == 'N') && (V->getSize() == x->getSize())) && "gemv: size mismatch! size of V does not match size of x");
221230
222231 switch (memspace)
223232 {
@@ -232,19 +241,27 @@ namespace ReSolve
232241 }
233242
234243 /* *
235- * @brief mass (bulk) axpy i.e, y = y - x*alpha where alpha is a vector
244+ * @brief Multivector axpy: y: = y + \sum_i alpha_i x_i
236245 *
237246 * @param[in] size number of elements in y
238247 * @param[in] alpha vector size k x 1
239- * @param[in] x (multi)vector size size x k
248+ * @param[in] x (multi)vector [ size x k]
240249 * @param[in,out] y vector size size x 1 (this is where the result is stored)
241250 * @param[in] memspace string containg memspace (cpu or cuda or hip)
242251 *
243252 * @pre _k_ > 0, _size_ > 0, _size_ = x->getSize()
244253 *
245254 */
246- void VectorHandler::axpyMulti (index_type size, vector::Vector* alpha, index_type k, vector::Vector* x, vector::Vector* y, memory::MemorySpace memspace)
255+ void VectorHandler::axpyMulti (index_type size,
256+ vector::Vector* alpha,
257+ index_type k,
258+ vector::Vector* x,
259+ vector::Vector* y,
260+ memory::MemorySpace memspace)
247261 {
262+ assert (y->getSize () == x->getSize () && " Sizes of x and y must match!\n " );
263+ assert (alpha->getSize () == k && " Size of alpha must match k!\n " );
264+
248265 using namespace ReSolve ::memory;
249266 switch (memspace)
250267 {
@@ -259,21 +276,32 @@ namespace ReSolve
259276 }
260277
261278 /* *
262- * @brief mass (bulk) dot product i.e, V^T x, where V is n x k dense multivector (a dense multivector consisting of k vectors size n)
263- * and x is k x 2 dense multivector (a multivector consisiting of two vectors size n each)
279+ * @brief Multivector dot product, i.e V^T x
280+ *
281+ * Computes V^T x with k vectors from multivector V. Result is storred
282+ * in `res`.
264283 *
265284 * @param[in] size - Number of elements in a single vector in V
266- * @param[in] V - Multivector; k vectors size n x 1 each
267- * @param[in] k - Number of vectors in V
268- * @param[in] x - Multivector; 2 vectors size n x 1 each
269- * @param[out] res - Multivector; 2 vectors size k x 1 each (result is returned in res)
270- * @param[in] memspace - String containg memspace (cpu or cuda or hip)
285+ * @param[in] V - Multivector; k vectors of size n x 1 each
286+ * @param[in] k - Number of vectors in V to use
287+ * @param[in] x - Multivector; 2 vectors of size n x 1 each
288+ * @param[out] res - Multivector; 2 vectors size k x 1 each
289+ * @param[in] memspace - String containg memspace (cpu, cuda or hip)
271290 *
272- * @pre _size_ > 0, _k_ > 0, size = x->getSize(), _res_ needs to be allocated
291+ * @pre _size_ > 0, _k_ > 0, size = x->getSize().
292+ * @pre _res_ needs to be allocated to k x 2 size.
273293 *
274294 */
275- void VectorHandler::dot2Multi (index_type size, vector::Vector* V, index_type k, vector::Vector* x, vector::Vector* res, memory::MemorySpace memspace)
295+ void VectorHandler::dot2Multi (index_type size,
296+ vector::Vector* V,
297+ index_type k,
298+ vector::Vector* x,
299+ vector::Vector* res,
300+ memory::MemorySpace memspace)
276301 {
302+ assert (x->getSize () == V->getSize () && " Sizes of V and x do not match!\n " );
303+ assert (res->getSize () == k && " Size of `res` must match k!\n " );
304+
277305 using namespace ReSolve ::memory;
278306 switch (memspace)
279307 {
@@ -304,6 +332,7 @@ namespace ReSolve
304332 assert (diag->getSize () == vec->getSize () && " Diagonal vector must be of the same size as the vector." );
305333 assert (diag->getData (memspace) != nullptr && " Diagonal vector data is null!\n " );
306334 assert (vec->getData (memspace) != nullptr && " Vector data is null!\n " );
335+
307336 using namespace ReSolve ::memory;
308337 switch (memspace)
309338 {
0 commit comments