Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions source/module_base/module_device/memory_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,5 +400,70 @@ template struct delete_memory_op_mt<std::complex<float>, base_device::DEVICE_CPU
template struct delete_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>;
#endif

template <typename FPTYPE>
void resize_memory(FPTYPE* arr, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice){
resize_memory_op<FPTYPE, base_device::DEVICE_CPU>()(cpu_ctx, arr);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
resize_memory_op<FPTYPE, base_device::DEVICE_GPU>()(gpu_ctx, arr);
}
}

template <typename FPTYPE>
void set_memory(FPTYPE* arr, const int var, const size_t size, base_device::AbacusDevice_t device_type){
if (device_type == base_device::AbacusDevice_t::CpuDevice){
set_memory_op<FPTYPE, base_device::DEVICE_CPU>()(cpu_ctx, arr, var, size);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
set_memory_op<FPTYPE, base_device::DEVICE_GPU>()(gpu_ctx, arr, var, size);
}
}

template <typename FPTYPE>
void synchronize_memory(FPTYPE* arr_out, const FPTYPE* arr_in, const size_t size, base_device::AbacusDevice_t device_type_out, base_device::AbacusDevice_t device_type_in){
if (device_type_out == base_device::AbacusDevice_t::CpuDevice || device_type_in == base_device::AbacusDevice_t::CpuDevice){
synchronize_memory_op<FPTYPE, DEVICE_CPU, DEVICE_CPU>()(cpu_ctx, cpu_ctx, arr_out, arr_in, size);
}
else if (device_type_out == base_device::AbacusDevice_t::CpuDevice || device_type_in == base_device::AbacusDevice_t::GpuDevice){
synchronize_memory_op<FPTYPE, DEVICE_CPU, DEVICE_GPU>()(cpu_ctx, gpu_ctx, arr_out, arr_in, size);
}
else if (device_type_out == base_device::AbacusDevice_t::GpuDevice || device_type_in == base_device::AbacusDevice_t::CpuDevice){
synchronize_memory_op<FPTYPE, DEVICE_GPU, DEVICE_CPU>()(gpu_ctx, cpu_ctx, arr_out, arr_in, size);
}
else if (device_type_out == base_device::AbacusDevice_t::GpuDevice || device_type_in == base_device::AbacusDevice_t::GpuDevice){
synchronize_memory_op<FPTYPE, DEVICE_GPU, DEVICE_GPU>()(gpu_ctx, gpu_ctx, arr_out, arr_in, size);
}
}

template <typename FPTYPE_out, typename FPTYPE_in>
void cast_memory(FPTYPE_out* arr_out, const FPTYPE_in* arr_in, const size_t size, base_device::AbacusDevice_t device_type_out, base_device::AbacusDevice_t device_type_in)
{
if (device_type_out == base_device::AbacusDevice_t::CpuDevice || device_type_in == base_device::AbacusDevice_t::CpuDevice){
cast_memory_op<FPTYPE_out, FPTYPE_in, DEVICE_CPU, DEVICE_CPU>()(cpu_ctx, cpu_ctx, arr_out, arr_in, size);
}
else if (device_type_out == base_device::AbacusDevice_t::CpuDevice || device_type_in == base_device::AbacusDevice_t::GpuDevice){
cast_memory_op<FPTYPE_out, FPTYPE_in, DEVICE_CPU, DEVICE_GPU>()(cpu_ctx, gpu_ctx, arr_out, arr_in, size);
}
else if (device_type_out == base_device::AbacusDevice_t::GpuDevice || device_type_in == base_device::AbacusDevice_t::CpuDevice){
cast_memory_op<FPTYPE_out, FPTYPE_in, DEVICE_GPU, DEVICE_CPU>()(gpu_ctx, cpu_ctx, arr_out, arr_in, size);
}
else if (device_type_out == base_device::AbacusDevice_t::GpuDevice || device_type_in == base_device::AbacusDevice_t::GpuDevice){
cast_memory_op<FPTYPE_out, FPTYPE_in, DEVICE_GPU, DEVICE_GPU>()(gpu_ctx, gpu_ctx, arr_out, arr_in, size);
}
}

template <typename FPTYPE>
void delete_memory(FPTYPE* arr, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice){
delete_memory_op<FPTYPE, DEVICE_CPU>()(cpu_ctx, arr);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
delete_memory_op<FPTYPE, DEVICE_GPU>()(gpu_ctx, arr);
}
}

} // namespace memory
} // namespace base_device
14 changes: 14 additions & 0 deletions source/module_base/module_device/memory_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,20 @@ struct delete_memory_op
void operator()(const Device* dev, FPTYPE* arr);
};

template <typename FPTYPE>
void resize_memory(FPTYPE* arr, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

template <typename FPTYPE>
void set_memory(FPTYPE* arr, const int var, const size_t size, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

template <typename FPTYPE>
void synchronize_memory(FPTYPE* arr_out, const FPTYPE* arr_in, const size_t size, base_device::AbacusDevice_t device_type_out, base_device::AbacusDevice_t device_type_in);

template <typename FPTYPE_out, typename FPTYPE_in>
void cast_memory(FPTYPE_out* arr_out, const FPTYPE_in* arr_in, const size_t size, base_device::AbacusDevice_t device_type_out, base_device::AbacusDevice_t device_type_in);

template <typename FPTYPE>
void delete_memory(FPTYPE* arr, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
// Partially specialize operator for base_device::GpuDevice.
Expand Down
Loading