Skip to content

Commit 7ee95c4

Browse files
committed
Update comments in VectorHandler class.
1 parent 667e365 commit 7ee95c4

File tree

2 files changed

+94
-65
lines changed

2 files changed

+94
-65
lines changed

resolve/vector/VectorHandler.cpp

Lines changed: 57 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,10 @@ namespace ReSolve
165165
* @param[in] memspace String containg memspace (cpu or cuda or hip)
166166
*
167167
*/
168-
void VectorHandler::axpy(const real_type alpha, /* const */ vector::Vector* x, vector::Vector* y, memory::MemorySpace memspace)
168+
void VectorHandler::axpy(const real_type alpha,
169+
/* const */ vector::Vector* x,
170+
vector::Vector* y,
171+
memory::MemorySpace memspace)
169172
{
170173
// AXPY: y = alpha * x + y
171174
using namespace ReSolve::memory;
@@ -182,8 +185,16 @@ namespace ReSolve
182185
}
183186

184187
/**
185-
* @brief gemv computes matrix-vector product where both matrix and vectors are dense.
186-
* i.e., x = beta*x + alpha*V*y
188+
* @brief gemv computes dense matrix-vector product.
189+
*
190+
* In Re::Solve applications, gemv is used to compute dot products of
191+
* multivectors.
192+
*
193+
* If `transpose = N` (no), `x := beta*x + alpha*V*y`,
194+
* where `x` is `[n x 1]`, `V` is `[n x k]` and `y` is `[k x 1]`.
195+
* If `transpose = T` (yes), `x := beta*x + alpha*V^T*y`,
196+
* where `x` is `[k x 1]`, `V` is `[n x k]` and `y` is `[n x 1]`.
197+
*
187198
*
188199
* @param[in] Transpose - yes (T) or no (N)
189200
* @param[in] n - Number of rows in (non-transposed) matrix
@@ -198,26 +209,24 @@ namespace ReSolve
198209
* @pre V is stored colum-wise, _n_ > 0, _k_ > 0
199210
*
200211
*/
201-
void VectorHandler::gemv(char transpose,
202-
index_type n,
203-
index_type k,
204-
const real_type alpha,
205-
const real_type beta,
206-
vector::Vector* V,
207-
vector::Vector* y,
208-
vector::Vector* x,
212+
void VectorHandler::gemv(char transpose,
213+
index_type n,
214+
index_type k,
215+
const real_type alpha,
216+
const real_type beta,
217+
vector::Vector* V,
218+
vector::Vector* y,
219+
vector::Vector* x,
209220
memory::MemorySpace memspace)
210221
{
211222
using namespace ReSolve::memory;
212223

213224
// TODO: Remove n as the argument becuase it must always be n = V->getSize()
214225
assert(n == V->getSize() && "gemv: n does not match the number of rows in V");
215226

216-
assert((transpose == 'T' && V->getSize() == y->getSize()) &&
217-
"gemv: size mismatch! size of V^T does not match size of y");
218-
219-
assert((transpose == 'N' && V->getSize() == x->getSize()) &&
220-
"gemv: size mismatch! size of V does not match size of x");
227+
// TODO: These assertions should hold but they do not. Investigate more.
228+
// assert((transpose == 'T' && V->getSize() == y->getSize()) && "gemv: size mismatch! size of V^T does not match size of y");
229+
// assert(((transpose == 'N') && (V->getSize() == x->getSize())) && "gemv: size mismatch! size of V does not match size of x");
221230

222231
switch (memspace)
223232
{
@@ -232,19 +241,27 @@ namespace ReSolve
232241
}
233242

234243
/**
235-
* @brief mass (bulk) axpy i.e, y = y - x*alpha where alpha is a vector
244+
* @brief Multivector axpy: y: = y + \sum_i alpha_i x_i
236245
*
237246
* @param[in] size number of elements in y
238247
* @param[in] alpha vector size k x 1
239-
* @param[in] x (multi)vector size size x k
248+
* @param[in] x (multi)vector [size x k]
240249
* @param[in,out] y vector size size x 1 (this is where the result is stored)
241250
* @param[in] memspace string containg memspace (cpu or cuda or hip)
242251
*
243252
* @pre _k_ > 0, _size_ > 0, _size_ = x->getSize()
244253
*
245254
*/
246-
void VectorHandler::axpyMulti(index_type size, vector::Vector* alpha, index_type k, vector::Vector* x, vector::Vector* y, memory::MemorySpace memspace)
255+
void VectorHandler::axpyMulti(index_type size,
256+
vector::Vector* alpha,
257+
index_type k,
258+
vector::Vector* x,
259+
vector::Vector* y,
260+
memory::MemorySpace memspace)
247261
{
262+
assert(y->getSize() == x->getSize() && "Sizes of x and y must match!\n");
263+
assert(alpha->getSize() == k && "Size of alpha must match k!\n");
264+
248265
using namespace ReSolve::memory;
249266
switch (memspace)
250267
{
@@ -259,21 +276,32 @@ namespace ReSolve
259276
}
260277

261278
/**
262-
* @brief mass (bulk) dot product i.e, V^T x, where V is n x k dense multivector (a dense multivector consisting of k vectors size n)
263-
* and x is k x 2 dense multivector (a multivector consisiting of two vectors size n each)
279+
* @brief Multivector dot product, i.e V^T x
280+
*
281+
* Computes V^T x with k vectors from multivector V. Result is storred
282+
* in `res`.
264283
*
265284
* @param[in] size - Number of elements in a single vector in V
266-
* @param[in] V - Multivector; k vectors size n x 1 each
267-
* @param[in] k - Number of vectors in V
268-
* @param[in] x - Multivector; 2 vectors size n x 1 each
269-
* @param[out] res - Multivector; 2 vectors size k x 1 each (result is returned in res)
270-
* @param[in] memspace - String containg memspace (cpu or cuda or hip)
285+
* @param[in] V - Multivector; k vectors of size n x 1 each
286+
* @param[in] k - Number of vectors in V to use
287+
* @param[in] x - Multivector; 2 vectors of size n x 1 each
288+
* @param[out] res - Multivector; 2 vectors size k x 1 each
289+
* @param[in] memspace - String containg memspace (cpu, cuda or hip)
271290
*
272-
* @pre _size_ > 0, _k_ > 0, size = x->getSize(), _res_ needs to be allocated
291+
* @pre _size_ > 0, _k_ > 0, size = x->getSize().
292+
* @pre _res_ needs to be allocated to k x 2 size.
273293
*
274294
*/
275-
void VectorHandler::dot2Multi(index_type size, vector::Vector* V, index_type k, vector::Vector* x, vector::Vector* res, memory::MemorySpace memspace)
295+
void VectorHandler::dot2Multi(index_type size,
296+
vector::Vector* V,
297+
index_type k,
298+
vector::Vector* x,
299+
vector::Vector* res,
300+
memory::MemorySpace memspace)
276301
{
302+
assert(x->getSize() == V->getSize() && "Sizes of V and x do not match!\n");
303+
assert(res->getSize() == k && "Size of `res` must match k!\n");
304+
277305
using namespace ReSolve::memory;
278306
switch (memspace)
279307
{
@@ -304,6 +332,7 @@ namespace ReSolve
304332
assert(diag->getSize() == vec->getSize() && "Diagonal vector must be of the same size as the vector.");
305333
assert(diag->getData(memspace) != nullptr && "Diagonal vector data is null!\n");
306334
assert(vec->getData(memspace) != nullptr && "Vector data is null!\n");
335+
307336
using namespace ReSolve::memory;
308337
switch (memspace)
309338
{

resolve/vector/VectorHandler.hpp

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ namespace ReSolve
1616
} // namespace ReSolve
1717

1818
namespace ReSolve
19-
{ // namespace vector {
20-
19+
{
2120
class VectorHandler
2221
{
2322
public:
@@ -27,48 +26,49 @@ namespace ReSolve
2726
VectorHandler(LinAlgWorkspaceHIP* new_workspace);
2827
~VectorHandler();
2928

30-
// y = alpha x + y
31-
void axpy(const real_type alpha, /* const */ vector::Vector* x, vector::Vector* y, memory::MemorySpace memspace);
29+
// y := alpha x + y
30+
void axpy(const real_type alpha,
31+
/* const */ vector::Vector* x,
32+
vector::Vector* y,
33+
memory::MemorySpace memspace);
3234

33-
// dot: x \cdot y
35+
// Dot product of two vectors
3436
real_type dot(vector::Vector* x, vector::Vector* y, memory::MemorySpace memspace);
3537

36-
// scal = alpha * x
38+
// Scale vector by scalar
3739
void scal(const real_type alpha, vector::Vector* x, memory::MemorySpace memspace);
3840

39-
// mass axpy: x*alpha + y where x is [n x k] and alpha is [k x 1]; x is stored columnwise
40-
void axpyMulti(index_type size, vector::Vector* alpha, index_type k, vector::Vector* x, vector::Vector* y, memory::MemorySpace memspace);
41-
42-
// mass dot: V^T x, where V is [n x k] and x is [k x 2], everything is stored and returned columnwise
43-
// Size = n
44-
void dot2Multi(index_type size, vector::Vector* V, index_type k, vector::Vector* x, vector::Vector* res, memory::MemorySpace memspace);
41+
// Scale vector by diagonal matrix represented as a vector (i.e., vec = diag*vec)
42+
int scal(vector::Vector* diag, vector::Vector* vec, memory::MemorySpace memspace);
4543

46-
/**
47-
* @brief Dense matrix-vector product.
48-
*
49-
* In Re::Solve applications, gemv is used to compute dot products of
50-
* multivectors.
51-
*
52-
* if `transpose = N` (no), `x := beta*x + alpha*V*y`,
53-
* where `x` is `[n x 1]`, `V` is `[n x k]` and `y` is `[k x 1]`.
54-
* if `transpose = T` (yes), `x := beta*x + alpha*V^T*y`,
55-
* where `x` is `[k x 1]`, `V` is `[n x k]` and `y` is `[n x 1]`.
56-
*/
57-
void gemv(char transpose,
58-
index_type n,
59-
index_type k, // how many vectors from multivector V to use
60-
const real_type alpha,
61-
const real_type beta,
62-
vector::Vector* V,
63-
vector::Vector* y,
64-
vector::Vector* x,
44+
// axpy for multivectors
45+
void axpyMulti(index_type size,
46+
vector::Vector* alpha,
47+
index_type k,
48+
vector::Vector* x,
49+
vector::Vector* y,
50+
memory::MemorySpace memspace);
51+
52+
// Dot product of two vectors with a multivector V
53+
void dot2Multi(index_type size,
54+
vector::Vector* V,
55+
index_type k,
56+
vector::Vector* x,
57+
vector::Vector* res,
58+
memory::MemorySpace memspace);
59+
60+
// Dense matrix-vector product.
61+
void gemv(char transpose,
62+
index_type n,
63+
index_type k, // number of vectors from multivector V to use
64+
const real_type alpha,
65+
const real_type beta,
66+
vector::Vector* V,
67+
vector::Vector* y,
68+
vector::Vector* x,
6569
memory::MemorySpace memspace);
6670

67-
int scal(vector::Vector* diag, vector::Vector* vec, memory::MemorySpace memspace);
68-
69-
/** iamax:
70-
* Returns infinity norm of a vector (i.e., entry with max abs value)
71-
*/
71+
// Vector infinity norm
7272
real_type iamax(vector::Vector* x, memory::MemorySpace memspace);
7373

7474
bool getIsCudaEnabled() const;
@@ -81,6 +81,6 @@ namespace ReSolve
8181
bool isCpuEnabled_{false};
8282
bool isCudaEnabled_{false};
8383
bool isHipEnabled_{false};
84-
};
84+
}; // class VectorHandler
8585

8686
} // namespace ReSolve

0 commit comments

Comments
 (0)