Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion onnxruntime/core/graph/ep_api_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,8 @@ static Status GetInputIndices(const EpNode& consumer_node,
[&found, &value_info_name, &indices](gsl::span<const EpValueInfo* const> input_value_infos,
bool is_implicit) -> void {
for (size_t i = 0; i < input_value_infos.size(); i++) {
if (input_value_infos[i] == nullptr) { // input_value_info == nullptr means the input is optional
if (input_value_infos[i] == nullptr) {
// This is a missing optional input. Skip it.
continue;
}
if (input_value_infos[i]->GetName() == value_info_name) {
Expand All @@ -351,6 +352,11 @@ static Status GetOutputIndex(const EpNode& producer_node,
gsl::span<const EpValueInfo* const> outputs = producer_node.GetOutputsSpan();

for (size_t i = 0; i < outputs.size(); i++) {
if (outputs[i] == nullptr) {
// This is a missing optional output. Skip it.
continue;
}

if (outputs[i]->GetName() == value_info_name) {
index = i;
found = true;
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/test/ep_graph/test_ep_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,13 @@ TEST(EpGraphTest, CheckModelExternalInitializers) {
CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph());
}

TEST(EpGraphTest, CheckModelOptionalIntermediateNodeOutputs) {
auto test_graph = TestGraph::Load(ORT_TSTR("testdata/skip_simplified_layer_normalization.onnx"));
ASSERT_NE(test_graph, nullptr) << "Failed to load test model";

CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph());
}

static void RunConvQDQExtIni(const ORTCHAR_T* model_path, std::vector<float>& output_data) {
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
Ort::SessionOptions sess_options;
Expand Down
Binary file not shown.
65 changes: 65 additions & 0 deletions onnxruntime/test/testdata/skip_simplified_layer_normalization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from onnx import TensorProto, checker, helper, save, shape_inference

Check warning

Code scanning / lintrunner

RUFF-FORMAT/format Warning test

Run lintrunner -a to apply this patch.

batch_size = 1
seq_len = 64
hidden_size = 896

input_vi = helper.make_tensor_value_info(
name="input",
elem_type=TensorProto.FLOAT,
shape=[batch_size, seq_len, hidden_size],
)

skip_vi = helper.make_tensor_value_info(
name="skip",
elem_type=TensorProto.FLOAT,
shape=[batch_size, seq_len, hidden_size],
)

output_vi = helper.make_tensor_value_info(
name="output",
elem_type=TensorProto.FLOAT,
shape=[batch_size, seq_len, hidden_size],
)

input_skip_bias_sum_vi = helper.make_tensor_value_info(
name="input_skip_bias_sum",
elem_type=TensorProto.FLOAT,
shape=[batch_size, seq_len, hidden_size],
)

gamma_init = helper.make_tensor(
name="gamma",
data_type=TensorProto.FLOAT,
dims=[hidden_size],
vals=[1] * hidden_size
)

node = helper.make_node(
op_type="SkipSimplifiedLayerNormalization",
inputs=["input", "skip", "gamma"],
outputs=["output", "", "", "input_skip_bias_sum"],
domain="com.microsoft",
epsilon=1e-6,
name="SkipLayerNorm",
)

graph = helper.make_graph(
nodes=[node],
name="SkipSimplifiedLayerNormGraph",
inputs=[input_vi, skip_vi],
outputs=[output_vi, input_skip_bias_sum_vi],
initializer=[gamma_init],
)

model = helper.make_model(
graph,
opset_imports=[
helper.make_operatorsetid("", 17),
helper.make_operatorsetid("com.microsoft", 1),
],
)

model = shape_inference.infer_shapes(model)
checker.check_model(model, True)
save(model, "skip_simplified_layer_normalization.onnx")