Skip to content

Commit 5bd1ff0

Browse files
authored
Fix transpose of empty initializers called from constant folding (microsoft#25922)
### Description

ConstantFolding crashes when a constant empty initializer is transposed.

### Motivation and Context

This rare case happened when converting an LLM to ONNX with static shapes.
1 parent 9d650a4 commit 5bd1ff0

File tree

3 files changed

+21
-1
lines changed

3 files changed

+21
-1
lines changed

onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,9 @@ void ApiGraph::TransposeInitializer(std::string_view name, const std::vector<int
575575
TensorShape new_tensor_shape(new_tensor_shape_dims);
576576
Tensor out_tensor(tensor_dtype, new_tensor_shape, cpu_allocator_);
577577

578-
ORT_THROW_IF_ERROR(Transpose::DoTranspose(permutations, in_tensor, out_tensor));
578+
if (new_tensor_shape.Size() > 0) {
579+
ORT_THROW_IF_ERROR(Transpose::DoTranspose(permutations, in_tensor, out_tensor));
580+
}
579581

580582
auto& node_arg = *graph_.GetNodeArg(name_str);
581583
TensorShapeProto new_shape;

onnxruntime/test/optimizer/graph_transform_test.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,24 @@ TEST_F(GraphTransformationTests, ConstantFoldingNodesOnDifferentEP) {
631631
}
632632
}
633633

634+
TEST_F(GraphTransformationTests, ConstantFoldingTransposeEmptyInitializer) {
635+
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "constant_folding_transpose_empty_initializer.onnx";
636+
std::shared_ptr<Model> model;
637+
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
638+
Graph& graph = model->MainGraph();
639+
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
640+
ASSERT_TRUE(op_to_count["Transpose"] == 1);
641+
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
642+
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
643+
const ConfigOptions empty_config_options;
644+
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
645+
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
646+
TransformerLevel::Level1));
647+
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
648+
op_to_count = CountOpsInGraph(graph);
649+
ASSERT_TRUE(op_to_count["Transpose"] == 0);
650+
}
651+
634652
TEST_F(GraphTransformationTests, ConstantFoldingUnsupportedFloat16) {
635653
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "constant_float16_mul.onnx";
636654
std::shared_ptr<Model> model;
Binary file not shown.

0 commit comments

Comments
 (0)