Skip to content

Commit 5dcbaa6

Browse files
committed
Modify op usage
1 parent a6668dd commit 5dcbaa6

File tree

3 files changed

+62
-13
lines changed

3 files changed

+62
-13
lines changed

source/module_hsolver/diago_dav_subspace.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,12 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
262262
}
263263
}
264264

265-
gemm_op<T, Device>()(this->ctx,
265+
#ifdef __DSP
266+
gemm_op_mt<T, Device>()
267+
#else
268+
gemm_op<T, Device>()
269+
#endif
270+
(this->ctx,
266271
'N',
267272
'N',
268273
this->dim,

source/module_hsolver/kernels/math_kernel_op.cpp

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -275,22 +275,36 @@ struct gemm_op<T, base_device::DEVICE_CPU>
275275
const int& ldb,
276276
const T* beta,
277277
T* c,
278-
const int& ldc,
279-
bool use_dsp)
278+
const int& ldc,)
280279
{
281-
#ifdef __DSP
282-
if (use_dsp){
283-
BlasConnector::gemm(transb, transa, n, m, k, *alpha, b, ldb, a, lda, *beta, c, ldc, base_device::AbacusDevice_t::DspDevice);
284-
}
285-
else{
286-
BlasConnector::gemm(transb, transa, n, m, k, *alpha, b, ldb, a, lda, *beta, c, ldc);
287-
}
288-
#else
289280
BlasConnector::gemm(transb, transa, n, m, k, *alpha, b, ldb, a, lda, *beta, c, ldc);
290-
#endif
291281
}
292282
};
293283

284+
#ifdef __DSP
285+
// CPU specialization of gemm_op_mt (only compiled when __DSP is defined).
// Forwards the GEMM request to BlasConnector::gemm, tagging the call with
// base_device::AbacusDevice_t::DspDevice. The ctx argument is unused.
template <typename T>
286+
struct gemm_op_mt<T, base_device::DEVICE_CPU>
287+
{
288+
void operator()(const base_device::DEVICE_CPU* /*ctx*/,
289+
const char& transa,
290+
const char& transb,
291+
const int& m,
292+
const int& n,
293+
const int& k,
294+
const T* alpha,
295+
const T* a,
296+
const int& lda,
297+
const T* b,
298+
const int& ldb,
299+
const T* beta,
300+
T* c,
301+
const int& ldc)
302+
{
303+
// The operands are passed in swapped order (transb/transa, n/m, b/ldb before a/lda).
// NOTE(review): presumably this maps the caller's layout onto BlasConnector's
// convention, as the plain gemm_op specialization does the same swap -- confirm.
BlasConnector::gemm(transb, transa, n, m, k, *alpha, b, ldb, a, lda, *beta, c, ldc, base_device::AbacusDevice_t::DspDevice);
304+
}
305+
};
306+
#endif
307+
294308
template <typename T>
295309
struct matrixTranspose_op<T, base_device::DEVICE_CPU>
296310
{

source/module_hsolver/kernels/math_kernel_op.h

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,39 @@ template <typename T, typename Device> struct gemm_op {
261261
void operator()(const Device *d, const char &transa, const char &transb,
262262
const int &m, const int &n, const int &k, const T *alpha,
263263
const T *a, const int &lda, const T *b, const int &ldb,
264-
const T *beta, T *c, const int &ldc, bool usd_dsp = false);
264+
const T *beta, T *c, const int &ldc);
265265
};
266266

267+
#ifdef __DSP
268+
// compute C = alpha * op(A) * op(B) + beta * C on DSP Hardware
269+
template <typename T, typename Device> struct gemm_op_mt {
270+
/// @brief C = alpha * op(A) * op(B) + beta * C
271+
///
272+
/// Input Parameters
273+
/// \param d : the type of computing device
274+
/// \param transa : whether to transpose matrix A
275+
/// \param transb : whether to transpose matrix B
276+
/// \param m : first dimension of matrix multiplication
277+
/// \param n : second dimension of matrix multiplication
278+
/// \param k : third dimension of matrix multiplication
279+
/// \param alpha : input constant alpha
280+
/// \param a : input matrix A
281+
/// \param lda : leading dimension of A
282+
/// \param b : input matrix B
283+
/// \param ldb : leading dimension of B
284+
/// \param beta : input constant beta
285+
/// \param c : input matrix C (scaled by beta and accumulated into)
286+
/// \param ldc : leading dimension of C
287+
///
288+
/// Output Parameters
289+
/// \param c : output matrix C
290+
void operator()(const Device *d, const char &transa, const char &transb,
291+
const int &m, const int &n, const int &k, const T *alpha,
292+
const T *a, const int &lda, const T *b, const int &ldb,
293+
const T *beta, T *c, const int &ldc);
294+
};
295+
#endif
296+
267297
template <typename T, typename Device> struct matrixTranspose_op {
268298
/// @brief transpose the input matrix
269299
///

0 commit comments

Comments
 (0)