Skip to content

Commit 4d451dc

Browse files
committed
fix data transfer
1 parent 2a0c827 commit 4d451dc

File tree

3 files changed

+57
-17
lines changed

3 files changed

+57
-17
lines changed

onnxruntime/core/providers/webgpu/data_transfer.cc

Lines changed: 27 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -13,32 +13,45 @@ bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_dev
1313
(dst_device.Type() == OrtDevice::CPU && src_device.Type() == OrtDevice::GPU);
1414
}
1515

16-
common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {
17-
size_t bytes = src.SizeInBytes();
16+
common::Status DataTransfer::CopyTensorImpl(void const* src_data,
17+
bool src_is_gpu,
18+
void* dst_data,
19+
bool dst_is_gpu,
20+
size_t bytes) const {
1821
if (bytes > 0) {
19-
void const* src_data = src.DataRaw();
20-
void* dst_data = dst.MutableDataRaw();
21-
22-
auto& src_device = src.Location().device;
23-
auto& dst_device = dst.Location().device;
24-
25-
if (dst_device.Type() == OrtDevice::GPU) {
26-
if (src_device.Type() == OrtDevice::GPU) {
22+
if (dst_is_gpu) {
23+
if (src_is_gpu) {
2724
// copy from GPU to GPU
2825
buffer_manager_.MemCpy(static_cast<WGPUBuffer>(const_cast<void*>(src_data)),
29-
static_cast<WGPUBuffer>(dst_data), bytes);
26+
static_cast<WGPUBuffer>(dst_data),
27+
bytes);
3028
} else {
3129
// copy from CPU to GPU
32-
buffer_manager_.Upload(const_cast<void*>(src_data), static_cast<WGPUBuffer>(dst_data), bytes);
30+
buffer_manager_.Upload(const_cast<void*>(src_data),
31+
static_cast<WGPUBuffer>(dst_data),
32+
bytes);
3333
}
34-
} else /* if (src_device.Type() == OrtDevice::GPU) */ {
34+
} else {
3535
// copy from GPU to CPU
36-
buffer_manager_.Download(static_cast<WGPUBuffer>(const_cast<void*>(src_data)), dst_data, bytes);
36+
buffer_manager_.Download(static_cast<WGPUBuffer>(const_cast<void*>(src_data)),
37+
dst_data,
38+
bytes);
3739
}
3840
}
3941

4042
return Status::OK();
4143
}
4244

45+
common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {
46+
void const* src_data = src.DataRaw();
47+
void* dst_data = dst.MutableDataRaw();
48+
49+
return CopyTensorImpl(src_data,
50+
src.Location().device.Type() == OrtDevice::GPU,
51+
dst_data,
52+
dst.Location().device.Type() == OrtDevice::GPU,
53+
src.SizeInBytes());
54+
}
55+
4356
} // namespace webgpu
4457
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/data_transfer.h

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,12 @@ class DataTransfer : public IDataTransfer {
2020

2121
common::Status CopyTensor(const Tensor& src, Tensor& dst) const override;
2222

23+
common::Status CopyTensorImpl(void const* src_data,
24+
bool src_is_gpu,
25+
void* dst_data,
26+
bool dst_is_gpu,
27+
size_t bytes) const;
28+
2329
private:
2430
const BufferManager& buffer_manager_;
2531
};

onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc

Lines changed: 24 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -346,9 +346,30 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl {
346346

347347
// Now perform the actual tensor copy
348348
for (size_t idx = 0; idx < num_tensors; ++idx) {
349-
const OrtValue* src_tensor = src_tensors[idx];
350-
OrtValue* dst_tensor = dst_tensors[idx];
351-
auto status = impl.data_transfer_->CopyTensor(src_tensor->Get<Tensor>(), *dst_tensor->GetMutable<Tensor>());
349+
#if defined(BUILD_WEBGPU_EP_STATIC_LIB)
350+
const Tensor& src_tensor = src_tensors[idx]->Get<Tensor>();
351+
const void* src_data = src_tensor.DataRaw();
352+
size_t size = src_tensor.SizeInBytes();
353+
bool src_is_gpu = src_tensor.Location().device.Type() == OrtDevice::GPU;
354+
355+
Tensor& dst_tensor = *dst_tensors[idx]->GetMutable<Tensor>();
356+
void* dst_data = dst_tensor.MutableDataRaw();
357+
bool dst_is_gpu = dst_tensor.Location().device.Type() == OrtDevice::GPU;
358+
#else
359+
Ort::ConstValue src_value{src_tensors[idx]};
360+
const void* src_data = src_value.GetTensorRawData();
361+
size_t size = src_value.GetTensorSizeInBytes();
362+
bool src_is_gpu = src_value.GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU;
363+
364+
Ort::UnownedValue dst_value{dst_tensors[idx]};
365+
void* dst_data = dst_value.GetTensorMutableRawData();
366+
bool dst_is_gpu = dst_value.GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU;
367+
#endif
368+
auto status = impl.data_transfer_->CopyTensorImpl(src_data,
369+
src_is_gpu,
370+
dst_data,
371+
dst_is_gpu,
372+
size);
352373
if (!status.IsOK()) {
353374
return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, status.ErrorMessage().c_str());
354375
}

0 commit comments

Comments
 (0)