@@ -13,32 +13,45 @@ bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_dev
1313 (dst_device.Type () == OrtDevice::CPU && src_device.Type () == OrtDevice::GPU);
1414}
1515
16- common::Status DataTransfer::CopyTensor (const Tensor& src, Tensor& dst) const {
17- size_t bytes = src.SizeInBytes ();
16+ common::Status DataTransfer::CopyTensorImpl (void const * src_data,
17+ bool src_is_gpu,
18+ void * dst_data,
19+ bool dst_is_gpu,
20+ size_t bytes) const {
1821 if (bytes > 0 ) {
19- void const * src_data = src.DataRaw ();
20- void * dst_data = dst.MutableDataRaw ();
21-
22- auto & src_device = src.Location ().device ;
23- auto & dst_device = dst.Location ().device ;
24-
25- if (dst_device.Type () == OrtDevice::GPU) {
26- if (src_device.Type () == OrtDevice::GPU) {
22+ if (dst_is_gpu) {
23+ if (src_is_gpu) {
2724 // copy from GPU to GPU
2825 buffer_manager_.MemCpy (static_cast <WGPUBuffer>(const_cast <void *>(src_data)),
29- static_cast <WGPUBuffer>(dst_data), bytes);
26+ static_cast <WGPUBuffer>(dst_data),
27+ bytes);
3028 } else {
3129 // copy from CPU to GPU
32- buffer_manager_.Upload (const_cast <void *>(src_data), static_cast <WGPUBuffer>(dst_data), bytes);
30+ buffer_manager_.Upload (const_cast <void *>(src_data),
31+ static_cast <WGPUBuffer>(dst_data),
32+ bytes);
3333 }
34- } else /* if (src_device.Type() == OrtDevice::GPU) */ {
34+ } else {
3535 // copy from GPU to CPU
36- buffer_manager_.Download (static_cast <WGPUBuffer>(const_cast <void *>(src_data)), dst_data, bytes);
36+ buffer_manager_.Download (static_cast <WGPUBuffer>(const_cast <void *>(src_data)),
37+ dst_data,
38+ bytes);
3739 }
3840 }
3941
4042 return Status::OK ();
4143}
4244
45+ common::Status DataTransfer::CopyTensor (const Tensor& src, Tensor& dst) const {
46+ void const * src_data = src.DataRaw ();
47+ void * dst_data = dst.MutableDataRaw ();
48+
49+ return CopyTensorImpl (src_data,
50+ src.Location ().device .Type () == OrtDevice::GPU,
51+ dst_data,
52+ dst.Location ().device .Type () == OrtDevice::GPU,
53+ src.SizeInBytes ());
54+ }
55+
4356} // namespace webgpu
4457} // namespace onnxruntime
0 commit comments