Skip to content

Commit 478ed3f

Browse files
committed
Optimize memory op usage
1 parent 348cfb8 commit 478ed3f

File tree

3 files changed

+37
-248
lines changed

3 files changed

+37
-248
lines changed

source/module_base/module_device/memory_op.cpp

Lines changed: 12 additions & 198 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ template struct delete_memory_op<std::complex<double>, base_device::DEVICE_GPU>;
349349
#ifdef __DSP
350350

351351
template <typename FPTYPE>
352-
struct resize_memory_op<FPTYPE, base_device::DEVICE_DSP>
352+
struct resize_memory_op_mt<FPTYPE, base_device::DEVICE_CPU>
353353
{
354354
void operator()(const base_device::DEVICE_CPU* dev, FPTYPE*& arr, const size_t size, const char* record_in)
355355
{
@@ -376,116 +376,7 @@ struct resize_memory_op<FPTYPE, base_device::DEVICE_DSP>
376376
};
377377

378378
template <typename FPTYPE>
379-
struct synchronize_memory_op<FPTYPE, base_device::DEVICE_DSP, base_device::DEVICE_CPU>
380-
{
381-
void operator()(const base_device::DEVICE_CPU* dev_out,
382-
const base_device::DEVICE_CPU* dev_in,
383-
FPTYPE* arr_out,
384-
const FPTYPE* arr_in,
385-
const size_t size)
386-
{
387-
ModuleBase::OMP_PARALLEL([&](int num_thread, int thread_id) {
388-
int beg = 0, len = 0;
389-
ModuleBase::BLOCK_TASK_DIST_1D(num_thread, thread_id, size, (size_t)4096 / sizeof(FPTYPE), beg, len);
390-
memcpy(arr_out + beg, arr_in + beg, sizeof(FPTYPE) * len);
391-
});
392-
}
393-
};
394-
395-
template <typename FPTYPE>
396-
struct synchronize_memory_op<FPTYPE, base_device::DEVICE_CPU, base_device::DEVICE_DSP>
397-
{
398-
void operator()(const base_device::DEVICE_CPU* dev_out,
399-
const base_device::DEVICE_CPU* dev_in,
400-
FPTYPE* arr_out,
401-
const FPTYPE* arr_in,
402-
const size_t size)
403-
{
404-
ModuleBase::OMP_PARALLEL([&](int num_thread, int thread_id) {
405-
int beg = 0, len = 0;
406-
ModuleBase::BLOCK_TASK_DIST_1D(num_thread, thread_id, size, (size_t)4096 / sizeof(FPTYPE), beg, len);
407-
memcpy(arr_out + beg, arr_in + beg, sizeof(FPTYPE) * len);
408-
});
409-
}
410-
};
411-
412-
template <typename FPTYPE>
413-
struct synchronize_memory_op<FPTYPE, base_device::DEVICE_DSP, base_device::DEVICE_DSP>
414-
{
415-
void operator()(const base_device::DEVICE_CPU* dev_out,
416-
const base_device::DEVICE_CPU* dev_in,
417-
FPTYPE* arr_out,
418-
const FPTYPE* arr_in,
419-
const size_t size)
420-
{
421-
ModuleBase::OMP_PARALLEL([&](int num_thread, int thread_id) {
422-
int beg = 0, len = 0;
423-
ModuleBase::BLOCK_TASK_DIST_1D(num_thread, thread_id, size, (size_t)4096 / sizeof(FPTYPE), beg, len);
424-
memcpy(arr_out + beg, arr_in + beg, sizeof(FPTYPE) * len);
425-
});
426-
}
427-
};
428-
429-
430-
template <typename FPTYPE_out, typename FPTYPE_in>
431-
struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_CPU, base_device::DEVICE_DSP>
432-
{
433-
void operator()(const base_device::DEVICE_CPU* dev_out,
434-
const base_device::DEVICE_CPU* dev_in,
435-
FPTYPE_out* arr_out,
436-
const FPTYPE_in* arr_in,
437-
const size_t size)
438-
{
439-
#ifdef _OPENMP
440-
#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE_out))
441-
#endif
442-
for (int ii = 0; ii < size; ii++)
443-
{
444-
arr_out[ii] = static_cast<FPTYPE_out>(arr_in[ii]);
445-
}
446-
}
447-
};
448-
449-
template <typename FPTYPE_out, typename FPTYPE_in>
450-
struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_DSP, base_device::DEVICE_DSP>
451-
{
452-
void operator()(const base_device::DEVICE_CPU* dev_out,
453-
const base_device::DEVICE_CPU* dev_in,
454-
FPTYPE_out* arr_out,
455-
const FPTYPE_in* arr_in,
456-
const size_t size)
457-
{
458-
#ifdef _OPENMP
459-
#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE_out))
460-
#endif
461-
for (int ii = 0; ii < size; ii++)
462-
{
463-
arr_out[ii] = static_cast<FPTYPE_out>(arr_in[ii]);
464-
}
465-
}
466-
};
467-
468-
template <typename FPTYPE_out, typename FPTYPE_in>
469-
struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_DSP, base_device::DEVICE_CPU>
470-
{
471-
void operator()(const base_device::DEVICE_CPU* dev_out,
472-
const base_device::DEVICE_CPU* dev_in,
473-
FPTYPE_out* arr_out,
474-
const FPTYPE_in* arr_in,
475-
const size_t size)
476-
{
477-
#ifdef _OPENMP
478-
#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE_out))
479-
#endif
480-
for (int ii = 0; ii < size; ii++)
481-
{
482-
arr_out[ii] = static_cast<FPTYPE_out>(arr_in[ii]);
483-
}
484-
}
485-
};
486-
487-
template <typename FPTYPE>
488-
struct delete_memory_op<FPTYPE, base_device::DEVICE_DSP>
379+
struct delete_memory_op_mt<FPTYPE, base_device::DEVICE_CPU>
489380
{
490381
void operator()(const base_device::DEVICE_CPU* dev, FPTYPE* arr)
491382
{
@@ -494,94 +385,17 @@ struct delete_memory_op<FPTYPE, base_device::DEVICE_DSP>
494385
};
495386

496387

497-
template struct resize_memory_op<int, base_device::DEVICE_DSP>;
498-
template struct resize_memory_op<float, base_device::DEVICE_DSP>;
499-
template struct resize_memory_op<double, base_device::DEVICE_DSP>;
500-
template struct resize_memory_op<std::complex<float>, base_device::DEVICE_DSP>;
501-
template struct resize_memory_op<std::complex<double>, base_device::DEVICE_DSP>;
502-
503-
template struct synchronize_memory_op<int, base_device::DEVICE_CPU, base_device::DEVICE_DSP>;
504-
template struct synchronize_memory_op<int, base_device::DEVICE_DSP, base_device::DEVICE_CPU>;
505-
template struct synchronize_memory_op<int, base_device::DEVICE_DSP, base_device::DEVICE_DSP>;
506-
template struct synchronize_memory_op<float, base_device::DEVICE_CPU, base_device::DEVICE_DSP>;
507-
template struct synchronize_memory_op<float, base_device::DEVICE_DSP, base_device::DEVICE_CPU>;
508-
template struct synchronize_memory_op<float, base_device::DEVICE_DSP, base_device::DEVICE_DSP>;
509-
template struct synchronize_memory_op<double, base_device::DEVICE_CPU, base_device::DEVICE_DSP>;
510-
template struct synchronize_memory_op<double, base_device::DEVICE_DSP, base_device::DEVICE_CPU>;
511-
template struct synchronize_memory_op<double, base_device::DEVICE_DSP, base_device::DEVICE_DSP>;
512-
template struct synchronize_memory_op<std::complex<float>, base_device::DEVICE_CPU, base_device::DEVICE_DSP>;
513-
template struct synchronize_memory_op<std::complex<float>, base_device::DEVICE_DSP, base_device::DEVICE_CPU>;
514-
template struct synchronize_memory_op<std::complex<float>, base_device::DEVICE_DSP, base_device::DEVICE_DSP>;
515-
template struct synchronize_memory_op<std::complex<double>, base_device::DEVICE_CPU, base_device::DEVICE_DSP>;
516-
template struct synchronize_memory_op<std::complex<double>, base_device::DEVICE_DSP, base_device::DEVICE_CPU>;
517-
template struct synchronize_memory_op<std::complex<double>, base_device::DEVICE_DSP, base_device::DEVICE_DSP>;
518-
519-
template struct cast_memory_op<float, float, base_device::DEVICE_DSP, base_device::DEVICE_DSP>;
520-
template struct cast_memory_op<double, double, base_device::DEVICE_DSP, base_device::DEVICE_DSP>;
521-
template struct cast_memory_op<float, double, base_device::DEVICE_DSP, base_device::DEVICE_DSP>;
522-
template struct cast_memory_op<double, float, base_device::DEVICE_DSP, base_device::DEVICE_DSP>;
523-
template struct cast_memory_op<std::complex<float>,
524-
std::complex<float>,
525-
base_device::DEVICE_DSP,
526-
base_device::DEVICE_DSP>;
527-
template struct cast_memory_op<std::complex<double>,
528-
std::complex<double>,
529-
base_device::DEVICE_DSP,
530-
base_device::DEVICE_DSP>;
531-
template struct cast_memory_op<std::complex<float>,
532-
std::complex<double>,
533-
base_device::DEVICE_DSP,
534-
base_device::DEVICE_DSP>;
535-
template struct cast_memory_op<std::complex<double>,
536-
std::complex<float>,
537-
base_device::DEVICE_DSP,
538-
base_device::DEVICE_DSP>;
539-
template struct cast_memory_op<float, float, base_device::DEVICE_DSP, base_device::DEVICE_CPU>;
540-
template struct cast_memory_op<double, double, base_device::DEVICE_DSP, base_device::DEVICE_CPU>;
541-
template struct cast_memory_op<float, double, base_device::DEVICE_DSP, base_device::DEVICE_CPU>;
542-
template struct cast_memory_op<double, float, base_device::DEVICE_DSP, base_device::DEVICE_CPU>;
543-
template struct cast_memory_op<std::complex<float>,
544-
std::complex<float>,
545-
base_device::DEVICE_DSP,
546-
base_device::DEVICE_CPU>;
547-
template struct cast_memory_op<std::complex<double>,
548-
std::complex<double>,
549-
base_device::DEVICE_DSP,
550-
base_device::DEVICE_CPU>;
551-
template struct cast_memory_op<std::complex<float>,
552-
std::complex<double>,
553-
base_device::DEVICE_DSP,
554-
base_device::DEVICE_CPU>;
555-
template struct cast_memory_op<std::complex<double>,
556-
std::complex<float>,
557-
base_device::DEVICE_DSP,
558-
base_device::DEVICE_CPU>;
559-
template struct cast_memory_op<float, float, base_device::DEVICE_CPU, base_device::DEVICE_DSP>;
560-
template struct cast_memory_op<double, double, base_device::DEVICE_CPU, base_device::DEVICE_DSP>;
561-
template struct cast_memory_op<float, double, base_device::DEVICE_CPU, base_device::DEVICE_DSP>;
562-
template struct cast_memory_op<double, float, base_device::DEVICE_CPU, base_device::DEVICE_DSP>;
563-
template struct cast_memory_op<std::complex<float>,
564-
std::complex<float>,
565-
base_device::DEVICE_CPU,
566-
base_device::DEVICE_DSP>;
567-
template struct cast_memory_op<std::complex<double>,
568-
std::complex<double>,
569-
base_device::DEVICE_CPU,
570-
base_device::DEVICE_DSP>;
571-
template struct cast_memory_op<std::complex<float>,
572-
std::complex<double>,
573-
base_device::DEVICE_CPU,
574-
base_device::DEVICE_DSP>;
575-
template struct cast_memory_op<std::complex<double>,
576-
std::complex<float>,
577-
base_device::DEVICE_CPU,
578-
base_device::DEVICE_DSP>;
388+
template struct resize_memory_op_mt<int, base_device::DEVICE_CPU>;
389+
template struct resize_memory_op_mt<float, base_device::DEVICE_CPU>;
390+
template struct resize_memory_op_mt<double, base_device::DEVICE_CPU>;
391+
template struct resize_memory_op_mt<std::complex<float>, base_device::DEVICE_CPU>;
392+
template struct resize_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>;
579393

580-
template struct delete_memory_op<int, base_device::DEVICE_DSP>;
581-
template struct delete_memory_op<float, base_device::DEVICE_DSP>;
582-
template struct delete_memory_op<double, base_device::DEVICE_DSP>;
583-
template struct delete_memory_op<std::complex<float>, base_device::DEVICE_DSP>;
584-
template struct delete_memory_op<std::complex<double>, base_device::DEVICE_DSP>;
394+
template struct delete_memory_op_mt<int, base_device::DEVICE_CPU>;
395+
template struct delete_memory_op_mt<float, base_device::DEVICE_CPU>;
396+
template struct delete_memory_op_mt<double, base_device::DEVICE_CPU>;
397+
template struct delete_memory_op_mt<std::complex<float>, base_device::DEVICE_CPU>;
398+
template struct delete_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>;
585399
#endif
586400

587401
} // namespace memory

source/module_base/module_device/memory_op.h

Lines changed: 20 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -147,55 +147,33 @@ struct delete_memory_op<FPTYPE, base_device::DEVICE_GPU>
147147
#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
148148

149149
#ifdef __DSP
150-
// Partially specialize operator for base_device::GpuDevice.
151-
template <typename FPTYPE>
152-
struct resize_memory_op<FPTYPE, base_device::DEVICE_DSP>
153-
{
154-
void operator()(const base_device::DEVICE_CPU* dev,
155-
FPTYPE*& arr,
156-
const size_t size,
157-
const char* record_in = nullptr);
158-
};
159150

160-
template <typename FPTYPE>
161-
struct set_memory_op<FPTYPE, base_device::DEVICE_DSP>
151+
template <typename FPTYPE, typename Device>
152+
struct resize_memory_op_mt
162153
{
163-
void operator()(const base_device::DEVICE_GPU* dev, FPTYPE* arr, const int var, const size_t size);
154+
/// @brief Allocate memory for a given pointer. Note this op will free the pointer first.
155+
///
156+
/// Input Parameters
157+
/// \param dev : the type of computing device
158+
/// \param size : array size
159+
/// \param record_in : label for memory record
160+
///
161+
/// Output Parameters
162+
/// \param arr : allocated array
163+
void operator()(const Device* dev, FPTYPE*& arr, const size_t size, const char* record_in = nullptr);
164164
};
165165

166-
template <typename FPTYPE>
167-
struct synchronize_memory_op<FPTYPE, base_device::DEVICE_CPU, base_device::DEVICE_DSP>
168-
{
169-
void operator()(const base_device::DEVICE_CPU* dev_out,
170-
const base_device::DEVICE_CPU* dev_in,
171-
FPTYPE* arr_out,
172-
const FPTYPE* arr_in,
173-
const size_t size);
174-
};
175-
template <typename FPTYPE>
176-
struct synchronize_memory_op<FPTYPE, base_device::DEVICE_DSP, base_device::DEVICE_CPU>
177-
{
178-
void operator()(const base_device::DEVICE_CPU* dev_out,
179-
const base_device::DEVICE_CPU* dev_in,
180-
FPTYPE* arr_out,
181-
const FPTYPE* arr_in,
182-
const size_t size);
183-
};
184-
template <typename FPTYPE>
185-
struct synchronize_memory_op<FPTYPE, base_device::DEVICE_DSP, base_device::DEVICE_DSP>
166+
template <typename FPTYPE, typename Device>
167+
struct delete_memory_op_mt
186168
{
187-
void operator()(const base_device::DEVICE_CPU* dev_out,
188-
const base_device::DEVICE_CPU* dev_in,
189-
FPTYPE* arr_out,
190-
const FPTYPE* arr_in,
191-
const size_t size);
169+
/// @brief Free memory for multi-device
170+
///
171+
/// Input Parameters
172+
/// \param dev : the type of computing device
173+
/// \param arr : the input array
174+
void operator()(const Device* dev, FPTYPE* arr);
192175
};
193176

194-
template <typename FPTYPE>
195-
struct delete_memory_op<FPTYPE, base_device::DEVICE_DSP>
196-
{
197-
void operator()(const base_device::DEVICE_CPU* dev, FPTYPE* arr);
198-
};
199177
#endif // __DSP
200178

201179
} // end of namespace memory
@@ -285,6 +263,4 @@ using castmem_z2c_d2h_op = base_device::memory::
285263

286264
static base_device::DEVICE_CPU* cpu_ctx = {};
287265
static base_device::DEVICE_GPU* gpu_ctx = {};
288-
static base_device::DEVICE_DSP* gpu_ctx = {};
289-
290266
#endif // MODULE_DEVICE_MEMORY_H_

source/module_psi/psi.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,16 +144,15 @@ class Psi
144144
bool allocate_inside = true; ///<whether allocate psi inside Psi class
145145

146146
#ifdef __DSP
147-
using set_memory_op = base_device::memory::set_memory_op<T, base_device::DEVICE_DSP>;
148-
using delete_memory_op = base_device::memory::delete_memory_op<T, base_device::DEVICE_DSP>;
149-
using resize_memory_op = base_device::memory::resize_memory_op<T, base_device::DEVICE_DSP>;
150-
using synchronize_memory_op = base_device::memory::synchronize_memory_op<T, base_device::DEVICE_DSP, base_device::DEVICE_DSP>;
147+
using delete_memory_op = base_device::memory::delete_memory_op_mt<T, Device>;
148+
using resize_memory_op = base_device::memory::resize_memory_op_mt<T, Device>;
151149
#else
152-
using set_memory_op = base_device::memory::set_memory_op<T, Device>;
153150
using delete_memory_op = base_device::memory::delete_memory_op<T, Device>;
154151
using resize_memory_op = base_device::memory::resize_memory_op<T, Device>;
155-
using synchronize_memory_op = base_device::memory::synchronize_memory_op<T, Device, Device>;
156152
#endif
153+
using set_memory_op = base_device::memory::set_memory_op<T, Device>;
154+
using synchronize_memory_op = base_device::memory::synchronize_memory_op<T, Device, Device>;
155+
157156
};
158157

159158
} // end of namespace psi

0 commit comments

Comments
 (0)