Skip to content

Commit fe37372

Browse files
yuslepukhinCopilot
andauthored
[Lora] Adjust device dispatch according to the new OrtDevice defs (#26551)
### Description Check the device type and vendor to obtain data transfer. ### Motivation and Context Lora obtains DataTransfer based on the OrtMemoryInfo name which is now arbitrary. We should now rely on memory type and vendor definitions. --------- Co-authored-by: Copilot <[email protected]>
1 parent 1936d64 commit fe37372

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

onnxruntime/core/session/lora_adapters.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,11 @@ void LoraAdapter::MemoryMap(const std::filesystem::path& file_path) {
5353
static std::unique_ptr<IDataTransfer> GetDataTransfer(const OrtMemoryInfo& mem_info) {
5454
std::unique_ptr<IDataTransfer> data_transfer;
5555

56-
if (mem_info.name == onnxruntime::CPU) {
56+
if (mem_info.device.Type() == OrtDevice::CPU) {
5757
return data_transfer;
5858
}
5959

60-
if (mem_info.name == onnxruntime::CUDA) {
60+
if (mem_info.device.Type() == OrtDevice::GPU && mem_info.device.Vendor() == OrtDevice::VendorIds::NVIDIA) {
6161
#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE)
6262
auto* cuda_provider_info = TryGetProviderInfo_CUDA();
6363
if (cuda_provider_info != nullptr) {

onnxruntime/test/lora/lora_test.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,9 @@ TEST(LoraAdapterTest, VerifyDeviceCopy) {
216216
for (; begin != end; ++begin) {
217217
const auto& [_, param] = *begin;
218218
const auto& tensor_device = param.GetDeviceOrMapped().Get<Tensor>();
219-
ASSERT_EQ(0, strcmp(tensor_device.Location().name.c_str(), onnxruntime::CUDA));
219+
const auto& mem_info = tensor_device.Location();
220+
ASSERT_EQ(mem_info.device.Type(), OrtDevice::GPU);
221+
ASSERT_EQ(mem_info.device.Vendor(), OrtDevice::VendorIds::NVIDIA);
220222

221223
const auto& tensor_cpu = param.GetMapped().Get<Tensor>();
222224
ASSERT_EQ(tensor_cpu.Shape().Size(), tensor_device.Shape().Size());

0 commit comments

Comments
 (0)