Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 22 additions & 25 deletions source/source_base/kernels/dsp/dsp_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
extern "C"
{
#define complex_double ignore_complex_double
#include <mt_hthread_blas.h> // MTBLAS_TRANSPOSE etc
#include <mt_hthread_blas.h> // include faster mtblas kernels
#undef complex_double
#include <mtblas_interface.h> // gemm
#include <mtblas_interface.h> // include normal mtblas kernels that automatically operate memory, but slower.
}
namespace mtfunc
{
Expand All @@ -22,45 +22,42 @@ void dspDestoryHandle(int id)
{
hthread_dev_close(id);
std::cout << " ** DSP closed on cluster " << id << " **" << std::endl;
} // Close dsp cluster at the end
} // Close dsp cluster at the end of the program

MTBLAS_TRANSPOSE convertBLASTranspose(const char* blasTrans)
// MTBlas secretly removed its MTBLAS_TRANSPOSE data type and used the original CBLAS_TRANSPOSE. So this function is modified.

CBLAS_TRANSPOSE convertBLASTranspose(const char* blasTrans)
{
switch (blasTrans[0])
{
case 'N':
case 'n':
return MtblasNoTrans;
return CblasNoTrans;
case 'T':
case 't':
return MtblasTrans;
return CblasTrans;
case 'C':
case 'c':
return MtblasConjTrans;
return CblasConjTrans;
default:
std::cout << "Invalid BLAS transpose parameter!! Use default instead." << std::endl;
return MtblasNoTrans;
return CblasNoTrans;
}
} // Used to convert normal transpost char to mtblas transpose flag
} // Used to convert normal transpost char to cblas transpose flag

void* malloc_ht(size_t bytes, int cluster_id)
{
// std::cout << "MALLOC " << cluster_id;
void* ptr = hthread_malloc((int)cluster_id, bytes, HT_MEM_RW);
// std::cout << ptr << " SUCCEED" << std::endl;;
return ptr;
}
} // Malloc on dsp. Used to replace original malloc


// Used to replace original malloc

void free_ht(void* ptr)
{
// std::cout << "FREE " << ptr;
hthread_free(ptr);
// std::cout << " FREE SUCCEED" << std::endl;
}
} // Free on dsp. Used to replace original free

// Used to replace original free

void sgemm_mt_(const char* transa,
const char* transb,
Expand All @@ -77,7 +74,7 @@ void sgemm_mt_(const char* transa,
const int* ldc,
int cluster_id)
{
mtblas_sgemm(MTBLAS_ORDER::MtblasColMajor,
mtblas_sgemm(CBLAS_ORDER::CblasColMajor,
convertBLASTranspose(transa),
convertBLASTranspose(transb),
*m,
Expand Down Expand Up @@ -109,7 +106,7 @@ void dgemm_mt_(const char* transa,
const int* ldc,
int cluster_id)
{
mtblas_dgemm(MTBLAS_ORDER::MtblasColMajor,
mtblas_dgemm(CBLAS_ORDER::CblasColMajor,
convertBLASTranspose(transa),
convertBLASTranspose(transb),
*m,
Expand Down Expand Up @@ -141,7 +138,7 @@ void zgemm_mt_(const char* transa,
const int* ldc,
int cluster_id)
{
mtblas_zgemm(MTBLAS_ORDER::MtblasColMajor,
mtblas_zgemm(CBLAS_ORDER::CblasColMajor,
convertBLASTranspose(transa),
convertBLASTranspose(transb),
*m,
Expand Down Expand Up @@ -173,7 +170,7 @@ void cgemm_mt_(const char* transa,
const int* ldc,
int cluster_id)
{
mtblas_cgemm(MTBLAS_ORDER::MtblasColMajor,
mtblas_cgemm(CBLAS_ORDER::CblasColMajor,
convertBLASTranspose(transa),
convertBLASTranspose(transb),
*m,
Expand Down Expand Up @@ -207,7 +204,7 @@ void sgemm_mth_(const char* transa,
const int* ldc,
int cluster_id)
{
mt_hthread_sgemm(MTBLAS_ORDER::MtblasColMajor,
mt_hthread_sgemm(CBLAS_ORDER::CblasColMajor,
convertBLASTranspose(transa),
convertBLASTranspose(transb),
*m,
Expand Down Expand Up @@ -239,7 +236,7 @@ void dgemm_mth_(const char* transa,
const int* ldc,
int cluster_id)
{
mt_hthread_dgemm(MTBLAS_ORDER::MtblasColMajor,
mt_hthread_dgemm(CBLAS_ORDER::CblasColMajor,
convertBLASTranspose(transa),
convertBLASTranspose(transb),
*m,
Expand Down Expand Up @@ -275,7 +272,7 @@ void zgemm_mth_(const char* transa,
*alp = *alpha;
std::complex<double>* bet = (std::complex<double>*)malloc_ht(sizeof(std::complex<double>), cluster_id);
*bet = *beta;
mt_hthread_zgemm(MTBLAS_ORDER::MtblasColMajor,
mt_hthread_zgemm(CBLAS_ORDER::CblasColMajor,
convertBLASTranspose(transa),
convertBLASTranspose(transb),
*m,
Expand Down Expand Up @@ -314,7 +311,7 @@ void cgemm_mth_(const char* transa,
std::complex<float>* bet = (std::complex<float>*)malloc_ht(sizeof(std::complex<float>), cluster_id);
*bet = *beta;

mt_hthread_cgemm(MTBLAS_ORDER::MtblasColMajor,
mt_hthread_cgemm(CBLAS_ORDER::CblasColMajor,
convertBLASTranspose(transa),
convertBLASTranspose(transb),
*m,
Expand Down
4 changes: 2 additions & 2 deletions source/source_base/module_fft/fft_dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ void FFT_DSP<double>::resource_handler(const int flag) const
template <>
void FFT_DSP<double>::fft3D_forward(std::complex<double>* in, std::complex<double>* out) const
{
hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_for);
hthread_group_exec(thread_id_for, "execute_mtfft_3d", 1, 1, args_for);
hthread_group_wait(thread_id_for);
}

template <>
void FFT_DSP<double>::fft3D_backward(std::complex<double>* in, std::complex<double>* out) const
{
hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_back);
hthread_group_exec(thread_id_for, "execute_mtfft_3d", 1, 1, args_back);
hthread_group_wait(thread_id_for);
}
template <>
Expand Down
Loading