Skip to content

Commit 76d023b

Browse files
authored
Add memory functions using abacusDevice_t (#5861)
1 parent 0a0e19a commit 76d023b

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed

source/module_base/module_device/memory_op.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,5 +400,70 @@ template struct delete_memory_op_mt<std::complex<float>, base_device::DEVICE_CPU
400400
template struct delete_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>;
401401
#endif
402402

403+
template <typename FPTYPE>
404+
void resize_memory(FPTYPE* arr, base_device::AbacusDevice_t device_type)
405+
{
406+
if (device_type == base_device::AbacusDevice_t::CpuDevice){
407+
resize_memory_op<FPTYPE, base_device::DEVICE_CPU>()(cpu_ctx, arr);
408+
}
409+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
410+
resize_memory_op<FPTYPE, base_device::DEVICE_GPU>()(gpu_ctx, arr);
411+
}
412+
}
413+
414+
template <typename FPTYPE>
415+
void set_memory(FPTYPE* arr, const int var, const size_t size, base_device::AbacusDevice_t device_type){
416+
if (device_type == base_device::AbacusDevice_t::CpuDevice){
417+
set_memory_op<FPTYPE, base_device::DEVICE_CPU>()(cpu_ctx, arr, var, size);
418+
}
419+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
420+
set_memory_op<FPTYPE, base_device::DEVICE_GPU>()(gpu_ctx, arr, var, size);
421+
}
422+
}
423+
424+
template <typename FPTYPE>
425+
void synchronize_memory(FPTYPE* arr_out, const FPTYPE* arr_in, const size_t size, base_device::AbacusDevice_t device_type_out, base_device::AbacusDevice_t device_type_in){
426+
if (device_type_out == base_device::AbacusDevice_t::CpuDevice || device_type_in == base_device::AbacusDevice_t::CpuDevice){
427+
synchronize_memory_op<FPTYPE, DEVICE_CPU, DEVICE_CPU>()(cpu_ctx, cpu_ctx, arr_out, arr_in, size);
428+
}
429+
else if (device_type_out == base_device::AbacusDevice_t::CpuDevice || device_type_in == base_device::AbacusDevice_t::GpuDevice){
430+
synchronize_memory_op<FPTYPE, DEVICE_CPU, DEVICE_GPU>()(cpu_ctx, gpu_ctx, arr_out, arr_in, size);
431+
}
432+
else if (device_type_out == base_device::AbacusDevice_t::GpuDevice || device_type_in == base_device::AbacusDevice_t::CpuDevice){
433+
synchronize_memory_op<FPTYPE, DEVICE_GPU, DEVICE_CPU>()(gpu_ctx, cpu_ctx, arr_out, arr_in, size);
434+
}
435+
else if (device_type_out == base_device::AbacusDevice_t::GpuDevice || device_type_in == base_device::AbacusDevice_t::GpuDevice){
436+
synchronize_memory_op<FPTYPE, DEVICE_GPU, DEVICE_GPU>()(gpu_ctx, gpu_ctx, arr_out, arr_in, size);
437+
}
438+
}
439+
440+
template <typename FPTYPE_out, typename FPTYPE_in>
441+
void cast_memory(FPTYPE_out* arr_out, const FPTYPE_in* arr_in, const size_t size, base_device::AbacusDevice_t device_type_out, base_device::AbacusDevice_t device_type_in)
442+
{
443+
if (device_type_out == base_device::AbacusDevice_t::CpuDevice || device_type_in == base_device::AbacusDevice_t::CpuDevice){
444+
cast_memory_op<FPTYPE_out, FPTYPE_in, DEVICE_CPU, DEVICE_CPU>()(cpu_ctx, cpu_ctx, arr_out, arr_in, size);
445+
}
446+
else if (device_type_out == base_device::AbacusDevice_t::CpuDevice || device_type_in == base_device::AbacusDevice_t::GpuDevice){
447+
cast_memory_op<FPTYPE_out, FPTYPE_in, DEVICE_CPU, DEVICE_GPU>()(cpu_ctx, gpu_ctx, arr_out, arr_in, size);
448+
}
449+
else if (device_type_out == base_device::AbacusDevice_t::GpuDevice || device_type_in == base_device::AbacusDevice_t::CpuDevice){
450+
cast_memory_op<FPTYPE_out, FPTYPE_in, DEVICE_GPU, DEVICE_CPU>()(gpu_ctx, cpu_ctx, arr_out, arr_in, size);
451+
}
452+
else if (device_type_out == base_device::AbacusDevice_t::GpuDevice || device_type_in == base_device::AbacusDevice_t::GpuDevice){
453+
cast_memory_op<FPTYPE_out, FPTYPE_in, DEVICE_GPU, DEVICE_GPU>()(gpu_ctx, gpu_ctx, arr_out, arr_in, size);
454+
}
455+
}
456+
457+
template <typename FPTYPE>
458+
void delete_memory(FPTYPE* arr, base_device::AbacusDevice_t device_type)
459+
{
460+
if (device_type == base_device::AbacusDevice_t::CpuDevice){
461+
delete_memory_op<FPTYPE, DEVICE_CPU>()(cpu_ctx, arr);
462+
}
463+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
464+
delete_memory_op<FPTYPE, DEVICE_GPU>()(gpu_ctx, arr);
465+
}
466+
}
467+
403468
} // namespace memory
404469
} // namespace base_device

source/module_base/module_device/memory_op.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,20 @@ struct delete_memory_op
9393
void operator()(const Device* dev, FPTYPE* arr);
9494
};
9595

96+
// Runtime-dispatched wrapper around resize_memory_op: forwards `arr` to the
// CPU or GPU functor according to `device_type` (CpuDevice when omitted).
template <typename FPTYPE>
97+
void resize_memory(FPTYPE* arr, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
98+
99+
// Runtime-dispatched wrapper around set_memory_op: forwards `arr`, `var` and
// `size` to the CPU or GPU functor according to `device_type` (CpuDevice when
// omitted).
template <typename FPTYPE>
100+
void set_memory(FPTYPE* arr, const int var, const size_t size, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
101+
102+
// Runtime-dispatched wrapper around synchronize_memory_op: `device_type_out`
// and `device_type_in` are the device tags of `arr_out` and `arr_in` and are
// used to pick the functor specialisation; both must be given explicitly.
template <typename FPTYPE>
103+
void synchronize_memory(FPTYPE* arr_out, const FPTYPE* arr_in, const size_t size, base_device::AbacusDevice_t device_type_out, base_device::AbacusDevice_t device_type_in);
104+
105+
// Runtime-dispatched wrapper around cast_memory_op: `device_type_out` and
// `device_type_in` are the device tags of `arr_out` (FPTYPE_out) and `arr_in`
// (FPTYPE_in) and are used to pick the functor specialisation; both must be
// given explicitly.
template <typename FPTYPE_out, typename FPTYPE_in>
106+
void cast_memory(FPTYPE_out* arr_out, const FPTYPE_in* arr_in, const size_t size, base_device::AbacusDevice_t device_type_out, base_device::AbacusDevice_t device_type_in);
107+
108+
// Runtime-dispatched wrapper around delete_memory_op: forwards `arr` to the
// CPU or GPU functor according to `device_type` (CpuDevice when omitted).
template <typename FPTYPE>
109+
void delete_memory(FPTYPE* arr, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
96110

97111
#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
98112
// Partially specialize operator for base_device::GpuDevice.

0 commit comments

Comments
 (0)