From a207cef1afa3905416a70bc80853e171c48e889e Mon Sep 17 00:00:00 2001
From: critsium-xy
Date: Wed, 15 Jan 2025 13:59:04 +0800
Subject: [PATCH] Add memory functions using abacusDevice_t

---
Review notes (text between "---" and the diffstat is not part of the commit):
- synchronize_memory() and cast_memory() used to join the two device checks
  with "||". Every transfer touching the CPU then took the CPU->CPU branch,
  GPU->GPU took the (cpu_ctx, gpu_ctx) branch, and the last two branches were
  unreachable. The checks are joined with "&&" here.
- NOTE(review): resize_memory() forwards no size and takes arr by value;
  confirm this matches resize_memory_op::operator() before merging.
- NOTE(review): the new templates are defined in memory_op.cpp but only
  declared in memory_op.h, and this patch adds no explicit instantiations;
  callers in other translation units will presumably fail to link.

 .../module_base/module_device/memory_op.cpp   | 77 +++++++++++++++++++
 source/module_base/module_device/memory_op.h  | 14 ++++
 2 files changed, 91 insertions(+)

diff --git a/source/module_base/module_device/memory_op.cpp b/source/module_base/module_device/memory_op.cpp
index 68146c275a..3c807dfad7 100644
--- a/source/module_base/module_device/memory_op.cpp
+++ b/source/module_base/module_device/memory_op.cpp
@@ -400,5 +400,82 @@ template struct delete_memory_op_mt<std::complex<float>, base_device::DEVICE_CPU
 template struct delete_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>;
 #endif
 
+/// Forwards to resize_memory_op for the backend named by device_type.
+/// Only CpuDevice and GpuDevice are dispatched; any other value does nothing.
+template <typename FPTYPE>
+void resize_memory(FPTYPE* arr, base_device::AbacusDevice_t device_type)
+{
+    if (device_type == base_device::AbacusDevice_t::CpuDevice){
+        resize_memory_op<FPTYPE, base_device::DEVICE_CPU>()(cpu_ctx, arr);
+    }
+    else if (device_type == base_device::AbacusDevice_t::GpuDevice){
+        resize_memory_op<FPTYPE, base_device::DEVICE_GPU>()(gpu_ctx, arr);
+    }
+}
+
+/// Forwards (arr, var, size) to set_memory_op for the backend named by
+/// device_type; values other than CpuDevice/GpuDevice do nothing.
+template <typename FPTYPE>
+void set_memory(FPTYPE* arr, const int var, const size_t size, base_device::AbacusDevice_t device_type){
+    if (device_type == base_device::AbacusDevice_t::CpuDevice){
+        set_memory_op<FPTYPE, base_device::DEVICE_CPU>()(cpu_ctx, arr, var, size);
+    }
+    else if (device_type == base_device::AbacusDevice_t::GpuDevice){
+        set_memory_op<FPTYPE, base_device::DEVICE_GPU>()(gpu_ctx, arr, var, size);
+    }
+}
+
+/// Forwards (arr_out, arr_in, size) to synchronize_memory_op.
+/// Each branch tests the output AND the input device together, so every
+/// CPU/GPU pairing selects exactly one specialization; other values do nothing.
+template <typename FPTYPE>
+void synchronize_memory(FPTYPE* arr_out, const FPTYPE* arr_in, const size_t size, base_device::AbacusDevice_t device_type_out, base_device::AbacusDevice_t device_type_in){
+    if (device_type_out == base_device::AbacusDevice_t::CpuDevice && device_type_in == base_device::AbacusDevice_t::CpuDevice){
+        synchronize_memory_op<FPTYPE, base_device::DEVICE_CPU, base_device::DEVICE_CPU>()(cpu_ctx, cpu_ctx, arr_out, arr_in, size);
+    }
+    else if (device_type_out == base_device::AbacusDevice_t::CpuDevice && device_type_in == base_device::AbacusDevice_t::GpuDevice){
+        synchronize_memory_op<FPTYPE, base_device::DEVICE_CPU, base_device::DEVICE_GPU>()(cpu_ctx, gpu_ctx, arr_out, arr_in, size);
+    }
+    else if (device_type_out == base_device::AbacusDevice_t::GpuDevice && device_type_in == base_device::AbacusDevice_t::CpuDevice){
+        synchronize_memory_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVICE_CPU>()(gpu_ctx, cpu_ctx, arr_out, arr_in, size);
+    }
+    else if (device_type_out == base_device::AbacusDevice_t::GpuDevice && device_type_in == base_device::AbacusDevice_t::GpuDevice){
+        synchronize_memory_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVICE_GPU>()(gpu_ctx, gpu_ctx, arr_out, arr_in, size);
+    }
+}
+
+/// Forwards (arr_out, arr_in, size) to cast_memory_op.
+/// Each branch tests the output AND the input device together, so every
+/// CPU/GPU pairing selects exactly one specialization; other values do nothing.
+template <typename FPTYPE_out, typename FPTYPE_in>
+void cast_memory(FPTYPE_out* arr_out, const FPTYPE_in* arr_in, const size_t size, base_device::AbacusDevice_t device_type_out, base_device::AbacusDevice_t device_type_in)
+{
+    if (device_type_out == base_device::AbacusDevice_t::CpuDevice && device_type_in == base_device::AbacusDevice_t::CpuDevice){
+        cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_CPU, base_device::DEVICE_CPU>()(cpu_ctx, cpu_ctx, arr_out, arr_in, size);
+    }
+    else if (device_type_out == base_device::AbacusDevice_t::CpuDevice && device_type_in == base_device::AbacusDevice_t::GpuDevice){
+        cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_CPU, base_device::DEVICE_GPU>()(cpu_ctx, gpu_ctx, arr_out, arr_in, size);
+    }
+    else if (device_type_out == base_device::AbacusDevice_t::GpuDevice && device_type_in == base_device::AbacusDevice_t::CpuDevice){
+        cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_GPU, base_device::DEVICE_CPU>()(gpu_ctx, cpu_ctx, arr_out, arr_in, size);
+    }
+    else if (device_type_out == base_device::AbacusDevice_t::GpuDevice && device_type_in == base_device::AbacusDevice_t::GpuDevice){
+        cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_GPU, base_device::DEVICE_GPU>()(gpu_ctx, gpu_ctx, arr_out, arr_in, size);
+    }
+}
+
+/// Forwards to delete_memory_op for the backend named by device_type.
+/// Only CpuDevice and GpuDevice are dispatched; any other value does nothing.
+template <typename FPTYPE>
+void delete_memory(FPTYPE* arr, base_device::AbacusDevice_t device_type)
+{
+    if (device_type == base_device::AbacusDevice_t::CpuDevice){
+        delete_memory_op<FPTYPE, base_device::DEVICE_CPU>()(cpu_ctx, arr);
+    }
+    else if (device_type == base_device::AbacusDevice_t::GpuDevice){
+        delete_memory_op<FPTYPE, base_device::DEVICE_GPU>()(gpu_ctx, arr);
+    }
+}
+
 } // namespace memory
 } // namespace base_device
\ No newline at end of file
diff --git a/source/module_base/module_device/memory_op.h b/source/module_base/module_device/memory_op.h
index 49ca788d0a..14926caf9b 100644
--- a/source/module_base/module_device/memory_op.h
+++ b/source/module_base/module_device/memory_op.h
@@ -93,6 +93,20 @@ struct delete_memory_op
     void operator()(const Device* dev, FPTYPE* arr);
 };
 
+template <typename FPTYPE>
+void resize_memory(FPTYPE* arr, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
+
+template <typename FPTYPE>
+void set_memory(FPTYPE* arr, const int var, const size_t size, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
+
+template <typename FPTYPE>
+void synchronize_memory(FPTYPE* arr_out, const FPTYPE* arr_in, const size_t size, base_device::AbacusDevice_t device_type_out, base_device::AbacusDevice_t device_type_in);
+
+template <typename FPTYPE_out, typename FPTYPE_in>
+void cast_memory(FPTYPE_out* arr_out, const FPTYPE_in* arr_in, const size_t size, base_device::AbacusDevice_t device_type_out, base_device::AbacusDevice_t device_type_in);
+
+template <typename FPTYPE>
+void delete_memory(FPTYPE* arr, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
 #if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
 // Partially specialize operator for base_device::GpuDevice.