diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc
index d30c7cd74a76a..c469873ba5c06 100644
--- a/onnxruntime/core/graph/ep_api_types.cc
+++ b/onnxruntime/core/graph/ep_api_types.cc
@@ -327,7 +327,8 @@ static Status GetInputIndices(const EpNode& consumer_node,
       [&found, &value_info_name, &indices](gsl::span input_value_infos, bool is_implicit) -> void {
         for (size_t i = 0; i < input_value_infos.size(); i++) {
-          if (input_value_infos[i] == nullptr) {  // input_value_info == nullptr means the input is optional
+          if (input_value_infos[i] == nullptr) {
+            // This is a missing optional input. Skip it.
             continue;
           }

           if (input_value_infos[i]->GetName() == value_info_name) {
@@ -351,6 +352,11 @@ static Status GetOutputIndex(const EpNode& producer_node,
   gsl::span outputs = producer_node.GetOutputsSpan();

   for (size_t i = 0; i < outputs.size(); i++) {
+    if (outputs[i] == nullptr) {
+      // This is a missing optional output. Skip it.
+      continue;
+    }
+
     if (outputs[i]->GetName() == value_info_name) {
       index = i;
       found = true;
diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc
index 055b2551328d9..a15f36014e232 100644
--- a/onnxruntime/test/ep_graph/test_ep_graph.cc
+++ b/onnxruntime/test/ep_graph/test_ep_graph.cc
@@ -180,6 +180,13 @@ TEST(EpGraphTest, CheckModelExternalInitializers) {
   CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph());
 }

+TEST(EpGraphTest, CheckModelOptionalIntermediateNodeOutputs) {
+  auto test_graph = TestGraph::Load(ORT_TSTR("testdata/skip_simplified_layer_normalization.onnx"));
+  ASSERT_NE(test_graph, nullptr) << "Failed to load test model";
+
+  CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph());
+}
+
 static void RunConvQDQExtIni(const ORTCHAR_T* model_path, std::vector& output_data) {
   auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
   Ort::SessionOptions sess_options;
diff --git a/onnxruntime/test/testdata/skip_simplified_layer_normalization.onnx b/onnxruntime/test/testdata/skip_simplified_layer_normalization.onnx
new file mode 100644
index 0000000000000..a9adf07ab4a69
Binary files /dev/null and b/onnxruntime/test/testdata/skip_simplified_layer_normalization.onnx differ
diff --git a/onnxruntime/test/testdata/skip_simplified_layer_normalization.py b/onnxruntime/test/testdata/skip_simplified_layer_normalization.py
new file mode 100644
index 0000000000000..5d040231a7c36
--- /dev/null
+++ b/onnxruntime/test/testdata/skip_simplified_layer_normalization.py
@@ -0,0 +1,65 @@
+from onnx import TensorProto, checker, helper, save, shape_inference
+
+batch_size = 1
+seq_len = 64
+hidden_size = 896
+
+input_vi = helper.make_tensor_value_info(
+    name="input",
+    elem_type=TensorProto.FLOAT,
+    shape=[batch_size, seq_len, hidden_size],
+)
+
+skip_vi = helper.make_tensor_value_info(
+    name="skip",
+    elem_type=TensorProto.FLOAT,
+    shape=[batch_size, seq_len, hidden_size],
+)
+
+output_vi = helper.make_tensor_value_info(
+    name="output",
+    elem_type=TensorProto.FLOAT,
+    shape=[batch_size, seq_len, hidden_size],
+)
+
+input_skip_bias_sum_vi = helper.make_tensor_value_info(
+    name="input_skip_bias_sum",
+    elem_type=TensorProto.FLOAT,
+    shape=[batch_size, seq_len, hidden_size],
+)
+
+gamma_init = helper.make_tensor(
+    name="gamma",
+    data_type=TensorProto.FLOAT,
+    dims=[hidden_size],
+    vals=[1] * hidden_size
+)
+
+node = helper.make_node(
+    op_type="SkipSimplifiedLayerNormalization",
+    inputs=["input", "skip", "gamma"],
+    outputs=["output", "", "", "input_skip_bias_sum"],
+    domain="com.microsoft",
+    epsilon=1e-6,
+    name="SkipLayerNorm",
+)
+
+graph = helper.make_graph(
+    nodes=[node],
+    name="SkipSimplifiedLayerNormGraph",
+    inputs=[input_vi, skip_vi],
+    outputs=[output_vi, input_skip_bias_sum_vi],
+    initializer=[gamma_init],
+)
+
+model = helper.make_model(
+    graph,
+    opset_imports=[
+        helper.make_operatorsetid("", 17),
+        helper.make_operatorsetid("com.microsoft", 1),
+    ],
+)
+
+model = shape_inference.infer_shapes(model)
+checker.check_model(model, True)
+save(model, "skip_simplified_layer_normalization.onnx")