@@ -3383,17 +3383,58 @@ common::Status InferenceSession::GetInputOutputMemoryInfo(SessionInputOutputType
3383
3383
3384
3384
for (const auto * def : def_list) {
3385
3385
InlinedVector<SessionState::NodeInfo> node_info_vec;
3386
+ Status status;
3386
3387
if (type == SessionInputOutputType::kOutput ) {
3387
- ORT_RETURN_IF_ERROR ( session_state_->GetOutputNodeInfo (def->Name (), node_info_vec) );
3388
+ status = session_state_->GetOutputNodeInfo (def->Name (), node_info_vec);
3388
3389
} else {
3389
- ORT_RETURN_IF_ERROR ( session_state_->GetInputNodeInfo (def->Name (), node_info_vec) );
3390
+ status = session_state_->GetInputNodeInfo (def->Name (), node_info_vec);
3390
3391
}
3391
3392
3392
- // all entries are for the same OrtDevice so use the first one.
3393
- // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice
3394
- // from the session state and use its OrtMemoryInfo.
3395
- auto allocator = session_state_->GetAllocator (*node_info_vec.front ().device );
3396
- memory_info.push_back (&allocator->Info ());
3393
+ if (!status.IsOK ()) {
3394
+ if (type == SessionInputOutputType::kInput ) {
3395
+ return status;
3396
+ }
3397
+
3398
+ // Check first if this output is produced by an input that directly
3399
+ // propagates to output with the same name.
3400
+ status = session_state_->GetInputNodeInfo (def->Name (), node_info_vec);
3401
+ if (status.IsOK ()) {
3402
+ // all entries are for the same OrtDevice so use the first one.
3403
+ // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice
3404
+ // from the session state and use its OrtMemoryInfo.
3405
+ auto allocator = session_state_->GetAllocator (*node_info_vec.front ().device );
3406
+ memory_info.push_back (&allocator->Info ());
3407
+ } else {
3408
+ // Check if this output is produced by a constant initializer
3409
+ // Pick the MemoryInfo from the initializer's OrtValue
3410
+ const auto & ort_value_map = session_state_->GetOrtValueNameIdxMap ();
3411
+
3412
+ OrtValueIndex ort_value_index;
3413
+ status = ort_value_map.GetIdx (def->Name (), ort_value_index);
3414
+ if (!status.IsOK ()) {
3415
+ return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL,
3416
+ " Failed to find node output or a constant initializer producing output: " ,
3417
+ def->Name (), " ." );
3418
+ }
3419
+
3420
+ const auto & idx_to_ort_value = session_state_->GetInitializedTensors ();
3421
+ auto it = idx_to_ort_value.find (ort_value_index);
3422
+ if (it == idx_to_ort_value.end ()) {
3423
+ return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL,
3424
+ " Failed to find node output or a constant initializer producing output: " ,
3425
+ def->Name (), " ." );
3426
+ }
3427
+ const auto & tensor = it->second .Get <Tensor>();
3428
+ auto allocator = session_state_->GetAllocator (tensor.Location ());
3429
+ memory_info.push_back (&allocator->Info ());
3430
+ }
3431
+ } else {
3432
+ // all entries are for the same OrtDevice so use the first one.
3433
+ // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice
3434
+ // from the session state and use its OrtMemoryInfo.
3435
+ auto allocator = session_state_->GetAllocator (*node_info_vec.front ().device );
3436
+ memory_info.push_back (&allocator->Info ());
3437
+ }
3397
3438
}
3398
3439
3399
3440
return Status::OK ();
@@ -3422,15 +3463,19 @@ common::Status InferenceSession::GetEpDeviceForInputs(InlinedVector<const OrtEpD
3422
3463
for (const auto * def : def_list) {
3423
3464
InlinedVector<SessionState::NodeInfo> node_info_vec;
3424
3465
ORT_RETURN_IF_ERROR (session_state_->GetInputNodeInfo (def->Name (), node_info_vec));
3425
-
3426
- // if we have a lot of inputs or there are a lot of execution providers it may be worth creating a map
3427
- // instead of doing a linear search each time.
3428
- const auto & ep_name = node_info_vec.front ().p_node ->GetExecutionProviderType ();
3429
- auto it = std::find_if (available_eps.begin (), available_eps.end (), [&ep_name](const OrtEpDevice* entry) {
3430
- return entry->ep_name == ep_name;
3431
- });
3432
-
3433
- ep_devices.push_back (it != available_eps.end () ? *it : nullptr );
3466
+ assert (!node_info_vec.empty ());
3467
+ // If we have an input that is not consumed by any node,
3468
+ // including nodes in subgraphs, then we return nullptr.
3469
+ const auto * p_node = node_info_vec.front ().p_node ;
3470
+ if (p_node != nullptr ) {
3471
+ const auto ep_name = p_node->GetExecutionProviderType ();
3472
+ auto it = std::find_if (available_eps.begin (), available_eps.end (), [&ep_name](const OrtEpDevice* entry) {
3473
+ return entry->ep_name == ep_name;
3474
+ });
3475
+ ep_devices.push_back (it != available_eps.end () ? *it : nullptr );
3476
+ } else {
3477
+ ep_devices.push_back (nullptr );
3478
+ }
3434
3479
}
3435
3480
3436
3481
return Status::OK ();
0 commit comments