
Commit d251f3a

Address GetMemInfo edge cases (#26021)
### Description

This fixes somewhat contrived edge cases that are present in our tests:

- an input propagates directly to an output
- an output is produced by a constant initializer

### Motivation and Context

The upcoming Python API PR does not pass its tests without this change.
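
For context, a minimal sketch of the first edge case as seen from the C++ API. The model file name here, passthrough.onnx, is hypothetical: a model whose graph input is also declared directly as a graph output. Before this fix, resolving output memory info for such a model failed because no node produces the output; with it, the output resolves through the matching input or constant initializer:

#include <iostream>

#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env;
  // Hypothetical model: graph input "x" is also listed as a graph output,
  // so the output has no producing node.
  Ort::Session session(env, ORT_TSTR("passthrough.onnx"), Ort::SessionOptions{});

  // With this change, each such output resolves to the device of its matching
  // input (or constant initializer) instead of returning an error.
  auto output_infos = session.GetMemoryInfoForOutputs();
  for (const auto& mi : output_infos) {
    std::cout << mi.GetAllocatorName() << ", device id " << mi.GetDeviceId() << "\n";
  }
  return 0;
}
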
1 parent 57aec2a commit d251f3a

File tree

6 files changed: +202 -30 lines changed

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

Lines changed: 18 additions & 14 deletions
@@ -1582,11 +1582,13 @@ inline std::vector<ConstMemoryInfo> ConstSessionImpl<T>::GetMemoryInfoForInputs(

   auto num_inputs = GetInputCount();
   std::vector<ConstMemoryInfo> mem_infos;
-  mem_infos.resize(num_inputs);
+  if (num_inputs > 0) {
+    mem_infos.resize(num_inputs);

-  ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_,
-                                                      reinterpret_cast<const OrtMemoryInfo**>(mem_infos.data()),
-                                                      num_inputs));
+    ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_,
+                                                        reinterpret_cast<const OrtMemoryInfo**>(mem_infos.data()),
+                                                        num_inputs));
+  }

   return mem_infos;
 }

@@ -1598,11 +1600,13 @@ inline std::vector<ConstMemoryInfo> ConstSessionImpl<T>::GetMemoryInfoForOutputs

   auto num_outputs = GetOutputCount();
   std::vector<ConstMemoryInfo> mem_infos;
-  mem_infos.resize(num_outputs);
+  if (num_outputs > 0) {
+    mem_infos.resize(num_outputs);

-  ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_,
-                                                       reinterpret_cast<const OrtMemoryInfo**>(mem_infos.data()),
-                                                       num_outputs));
+    ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_,
+                                                         reinterpret_cast<const OrtMemoryInfo**>(mem_infos.data()),
+                                                         num_outputs));
+  }
   return mem_infos;
 }

@@ -1631,12 +1635,12 @@ template <typename T>
 inline std::vector<ConstEpDevice> ConstSessionImpl<T>::GetEpDeviceForInputs() const {
   auto num_inputs = GetInputCount();
   std::vector<ConstEpDevice> input_devices;
-  input_devices.resize(num_inputs);
-
-  ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_,
-                                                    reinterpret_cast<const OrtEpDevice**>(input_devices.data()),
-                                                    num_inputs));
-
+  if (num_inputs > 0) {
+    input_devices.resize(num_inputs);
+    ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_,
+                                                      reinterpret_cast<const OrtEpDevice**>(input_devices.data()),
+                                                      num_inputs));
+  }
   return input_devices;
 }

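A note on the num_inputs > 0 / num_outputs > 0 guards above: for an empty std::vector, data() may legitimately return nullptr, so the guard skips the C API call entirely rather than passing a possibly-null array. A standalone sketch of the vector behavior being defended against:

#include <cassert>
#include <vector>

int main() {
  std::vector<int> v;  // models the zero-input / zero-output case
  v.resize(0);         // resize(0) leaves the vector empty

  // data() on an empty vector is allowed to be nullptr; handing it to an API
  // that expects a valid array is the hazard the new guards avoid by not
  // calling into the C API at all when the count is zero.
  int* p = v.data();
  assert(v.empty());
  (void)p;
  return 0;
}
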
onnxruntime/core/session/inference_session.cc

Lines changed: 61 additions & 16 deletions
@@ -3383,17 +3383,58 @@ common::Status InferenceSession::GetInputOutputMemoryInfo(SessionInputOutputType

   for (const auto* def : def_list) {
     InlinedVector<SessionState::NodeInfo> node_info_vec;
+    Status status;
     if (type == SessionInputOutputType::kOutput) {
-      ORT_RETURN_IF_ERROR(session_state_->GetOutputNodeInfo(def->Name(), node_info_vec));
+      status = session_state_->GetOutputNodeInfo(def->Name(), node_info_vec);
     } else {
-      ORT_RETURN_IF_ERROR(session_state_->GetInputNodeInfo(def->Name(), node_info_vec));
+      status = session_state_->GetInputNodeInfo(def->Name(), node_info_vec);
     }

-    // all entries are for the same OrtDevice so use the first one.
-    // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice
-    // from the session state and use its OrtMemoryInfo.
-    auto allocator = session_state_->GetAllocator(*node_info_vec.front().device);
-    memory_info.push_back(&allocator->Info());
+    if (!status.IsOK()) {
+      if (type == SessionInputOutputType::kInput) {
+        return status;
+      }
+
+      // Check first if this output is produced by an input that directly
+      // propagates to output with the same name.
+      status = session_state_->GetInputNodeInfo(def->Name(), node_info_vec);
+      if (status.IsOK()) {
+        // all entries are for the same OrtDevice so use the first one.
+        // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice
+        // from the session state and use its OrtMemoryInfo.
+        auto allocator = session_state_->GetAllocator(*node_info_vec.front().device);
+        memory_info.push_back(&allocator->Info());
+      } else {
+        // Check if this output is produced by a constant initializer.
+        // Pick the MemoryInfo from the initializer's OrtValue.
+        const auto& ort_value_map = session_state_->GetOrtValueNameIdxMap();
+
+        OrtValueIndex ort_value_index;
+        status = ort_value_map.GetIdx(def->Name(), ort_value_index);
+        if (!status.IsOK()) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+                                 "Failed to find node output or a constant initializer producing output: ",
+                                 def->Name(), ".");
+        }
+
+        const auto& idx_to_ort_value = session_state_->GetInitializedTensors();
+        auto it = idx_to_ort_value.find(ort_value_index);
+        if (it == idx_to_ort_value.end()) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+                                 "Failed to find node output or a constant initializer producing output: ",
+                                 def->Name(), ".");
+        }
+        const auto& tensor = it->second.Get<Tensor>();
+        auto allocator = session_state_->GetAllocator(tensor.Location());
+        memory_info.push_back(&allocator->Info());
+      }
+    } else {
+      // all entries are for the same OrtDevice so use the first one.
+      // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice
+      // from the session state and use its OrtMemoryInfo.
+      auto allocator = session_state_->GetAllocator(*node_info_vec.front().device);
+      memory_info.push_back(&allocator->Info());
+    }
   }

   return Status::OK();

@@ -3422,15 +3463,19 @@ common::Status InferenceSession::GetEpDeviceForInputs(InlinedVector<const OrtEpD
   for (const auto* def : def_list) {
     InlinedVector<SessionState::NodeInfo> node_info_vec;
     ORT_RETURN_IF_ERROR(session_state_->GetInputNodeInfo(def->Name(), node_info_vec));
-
-    // if we have a lot of inputs or there are a lot of execution providers it may be worth creating a map
-    // instead of doing a linear search each time.
-    const auto& ep_name = node_info_vec.front().p_node->GetExecutionProviderType();
-    auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) {
-      return entry->ep_name == ep_name;
-    });
-
-    ep_devices.push_back(it != available_eps.end() ? *it : nullptr);
+    assert(!node_info_vec.empty());
+    // If we have an input that is not consumed by any node,
+    // including nodes in subgraphs, then we return nullptr.
+    const auto* p_node = node_info_vec.front().p_node;
+    if (p_node != nullptr) {
+      const auto ep_name = p_node->GetExecutionProviderType();
+      auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) {
+        return entry->ep_name == ep_name;
+      });
+      ep_devices.push_back(it != available_eps.end() ? *it : nullptr);
+    } else {
+      ep_devices.push_back(nullptr);
+    }
   }

   return Status::OK();
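
Restated outside of ORT internals, the new output handling in GetInputOutputMemoryInfo is a three-step fallback. A minimal, hypothetical sketch of that chain (simplified names; not the actual implementation):

#include <iostream>
#include <stdexcept>

// Hypothetical restatement of the fallback chain above.
enum class OutputSource { kProducingNode, kPassThroughInput, kConstantInitializer };

OutputSource ResolveOutputSource(bool has_producing_node,
                                 bool matches_graph_input,
                                 bool backed_by_initializer) {
  if (has_producing_node) return OutputSource::kProducingNode;      // the common case
  if (matches_graph_input) return OutputSource::kPassThroughInput;  // input propagated to output
  if (backed_by_initializer) return OutputSource::kConstantInitializer;
  // mirrors the ORT_MAKE_STATUS(..., FAIL, ...) error path in the hunk above
  throw std::runtime_error("no node output or constant initializer produces this output");
}

int main() {
  // An output fed directly by a graph input resolves as a pass-through.
  const bool ok = ResolveOutputSource(false, true, false) == OutputSource::kPassThroughInput;
  std::cout << (ok ? "pass-through input" : "unexpected") << "\n";
  return 0;
}
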

onnxruntime/test/shared_lib/test_inference.cc

Lines changed: 37 additions & 0 deletions
@@ -494,6 +494,35 @@ INSTANTIATE_TEST_SUITE_P(CApiTestWithProviders,
                          CApiTestWithProvider,
                          ::testing::Values(0, 1, 2, 3, 4));

+TEST(CApiTest, TestInputPassThroughToOutput) {
+  const ORTCHAR_T* model_uri = TSTR("testdata/input_propagated_to_output.onnx");
+  Ort::Session session(*ort_env, model_uri, Ort::SessionOptions{});
+  auto inputs_meminfos = session.GetMemoryInfoForInputs();
+  ASSERT_EQ(1U, inputs_meminfos.size());
+  auto inputs_epdevices = session.GetEpDeviceForInputs();
+  ASSERT_EQ(1U, inputs_epdevices.size());
+  auto outputs_meminfos = session.GetMemoryInfoForOutputs();
+  ASSERT_EQ(7U, outputs_meminfos.size());
+}
+
+TEST(CApiTest, TestDanglingInput) {
+  // Here we test segment_ids, an input that is not consumed by any node.
+  // This kind of model is unlikely in practice, but we want to make sure it works.
+  const ORTCHAR_T* model_uri = TSTR("testdata/test_dangling_input_segment_ids.onnx");
+  Ort::Session session(*ort_env, model_uri, Ort::SessionOptions{});
+  auto inputs_meminfos = session.GetMemoryInfoForInputs();
+  ASSERT_EQ(2U, inputs_meminfos.size());
+  auto outputs_meminfos = session.GetMemoryInfoForOutputs();
+  ASSERT_EQ(2U, outputs_meminfos.size());
+  auto inputs_epdevices = session.GetEpDeviceForInputs();
+  ASSERT_EQ(2U, inputs_epdevices.size());
+  // One of the returned devices is null: since the input is not consumed,
+  // there is no device for it.
+  const bool null_present = std::any_of(inputs_epdevices.begin(), inputs_epdevices.end(),
+                                        [](const auto& device) { return device == nullptr; });
+  ASSERT_TRUE(null_present);
+}
+
 #if !defined(DISABLE_SPARSE_TENSORS)
 TEST(CApiTest, SparseOutputModel) {
   std::vector<int64_t> dense_shape{3, 3};

@@ -505,7 +534,15 @@ TEST(CApiTest, SparseOutputModel) {
   std::vector<Ort::Value> ort_inputs;
   std::vector<const char*> input_names;
   const char* const output_names[] = {"values"};
+  // This model produces a sparse output from a constant sparse initializer.
   Ort::Session session(*ort_env, SPARSE_OUTPUT_MODEL_URI, Ort::SessionOptions{});
+  auto inputs_meminfos = session.GetMemoryInfoForInputs();
+  ASSERT_TRUE(inputs_meminfos.empty());
+  auto outputs_meminfos = session.GetMemoryInfoForOutputs();
+  ASSERT_EQ(1U, outputs_meminfos.size());
+  auto inputs_epdevices = session.GetEpDeviceForInputs();
+  ASSERT_TRUE(inputs_epdevices.empty());
+
   auto ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
                                  output_names, 1);
   ASSERT_EQ(ort_outputs.size(), 1U);
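
From the caller's perspective, the new tests pin down that GetEpDeviceForInputs() entries can be null for unconsumed inputs and should be checked before use. A hedged usage sketch, assuming the EpName() accessor on the returned Ort::ConstEpDevice and the test model path from above:

#include <cstddef>
#include <iostream>

#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env;
  // Model path as used by the new TestDanglingInput test; adjust as needed.
  Ort::Session session(env, ORT_TSTR("testdata/test_dangling_input_segment_ids.onnx"),
                       Ort::SessionOptions{});

  // Entries are null for inputs no node consumes, so check before use.
  auto devices = session.GetEpDeviceForInputs();
  for (size_t i = 0; i < devices.size(); ++i) {
    if (devices[i] == nullptr) {
      std::cout << "input " << i << ": not consumed by any node, no EP device\n";
    } else {
      std::cout << "input " << i << ": " << devices[i].EpName() << "\n";
    }
  }
  return 0;
}
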
Two binary test models added (854 Bytes and 1.15 KB): the input_propagated_to_output.onnx and test_dangling_input_segment_ids.onnx files loaded by the new tests. Binary files not shown.
test_dangling_input_segment_ids.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
"""
Run this script to recreate the original onnx model.
Example usage:
python test_dangling_input_segment_ids.py out_model_path.onnx
"""

import os
import sys

import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

DATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test_dangling_input_segment_ids")


def order_repeated_field(repeated_proto, key_name, order):
    order = list(order)
    repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name)))


def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs):
    node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs)
    if doc_string == "":
        node.doc_string = ""
    order_repeated_field(node.attribute, "name", kwargs.keys())
    return node


def make_graph(*args, doc_string=None, **kwargs):
    graph = helper.make_graph(*args, doc_string=doc_string, **kwargs)
    if doc_string == "":
        graph.doc_string = ""
    return graph


model = helper.make_model(
    opset_imports=[helper.make_operatorsetid("", 14), helper.make_operatorsetid("com.microsoft", 1)],
    ir_version=7,
    graph=make_graph(
        name="embed_layernorm_graph",
        inputs=[
            helper.make_tensor_value_info("input_ids", TensorProto.INT32, shape=[1, 4]),
            helper.make_tensor_value_info("segment_ids", TensorProto.INT32, shape=[1, 4]),
        ],
        outputs=[
            helper.make_tensor_value_info("layernorm_out", TensorProto.FLOAT, shape=[1, 4, 4]),
            helper.make_tensor_value_info("mask_index_out", TensorProto.INT32, shape=[1]),
        ],
        initializer=[
            numpy_helper.from_array(
                np.load(os.path.join(DATA_DIR, "const0_word_embed.npy")).astype("float32").reshape([32, 4]),
                name="word_embed",
            ),
            numpy_helper.from_array(
                np.load(os.path.join(DATA_DIR, "const1_pos_embed.npy")).astype("float32").reshape([16, 4]),
                name="pos_embed",
            ),
            numpy_helper.from_array(
                np.array(
                    [0.6185135841369629, 0.010364261455833912, 0.5386272668838501, 0.0030179566238075495],
                    dtype="float32",
                ),
                name="gamma",
            ),
            numpy_helper.from_array(
                np.array(
                    [0.9511938095092773, 0.9054020047187805, 0.7959669232368469, 0.9152743220329285], dtype="float32"
                ),
                name="beta",
            ),
        ],
        nodes=[
            make_node(
                "EmbedLayerNormalization",
                inputs=["input_ids", "", "word_embed", "pos_embed", "", "gamma", "beta"],
                outputs=["layernorm_out", "mask_index_out"],
                domain="com.microsoft",
            )
        ],
    ),
)

if __name__ == "__main__" and len(sys.argv) == 2:
    _, out_path = sys.argv
    onnx.save(model, out_path)
