Skip to content

Commit 86d29ea

Browse files
authored
Merge pull request #442 from Xilinx/planzase.AIESW-13111.gather_check_for_static_shape
[AIESW-13111] Check Gather input shape is static
2 parents 8d72440 + 6d54700 commit 86d29ea

File tree

3 files changed

+21
-2
lines changed

3 files changed

+21
-2
lines changed

src/Conversion/ONNXToTOSA/Tensor/Gather.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"
1818
#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
1919
#include "src/Dialect/ONNX/ONNXOps.hpp"
20+
#include "src/Support/TypeUtilities.hpp"
2021
#include "llvm/ADT/SmallVector.h"
2122

2223
using namespace mlir;
@@ -46,7 +47,8 @@ class ONNXGatherLoweringToTOSA : public OpConversionPattern<ONNXGatherOp> {
4647
if (!onnx_mlir::isRankedShapedType(inputType))
4748
return rewriter.notifyMatchFailure(op, "input is not a ranked tensor");
4849

49-
if (!hasStaticShape(result.getType()))
50+
if (!hasStaticShape(inputType) || !hasStaticShape(indices.getType()) ||
51+
!hasStaticShape(result.getType()))
5052
return rewriter.notifyMatchFailure(op, "dynamic shapes not supported");
5153

5254
auto resultTy = dyn_cast<TensorType>(op.getType());

src/Dialect/ONNX/Transforms/SimplifyShapeRelatedOps.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ Now, it's straightforward to update the output shape of Reshape from
6565
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
6666
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
6767
#include "src/Pass/Passes.hpp"
68-
#include "src/Support/TypeUtilities.hpp"
6968

7069
#define DEBUG_TYPE "simplify_shape_related_ops"
7170

test/mlir/conversion/onnx_to_tosa/Tensor/Gather.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,21 @@ func.func @test_gather_dynamic_shape_indices_i32(%arg0 : tensor<?x4xf32>, %indic
175175
// CHECK-LABEL: test_gather_dynamic_shape_indices_i32
176176
// CHECK: onnx.Gather
177177
}
178+
179+
// -----
180+
181+
func.func @test_gather_dynamic_input_static_output(%arg0 : tensor<?x2xf32>, %indices: tensor<?xi64>) -> tensor<1x2xf32> {
182+
%0 = "onnx.Gather"(%arg0, %indices) {axis = 0 : si64, onnx_node_name = "/Gather_16"} : (tensor<?x2xf32>, tensor<?xi64>) -> tensor<1x2xf32>
183+
"func.return"(%0) : (tensor<1x2xf32>) -> ()
184+
// CHECK-LABEL: test_gather_dynamic_input_static_output
185+
// CHECK: onnx.Gather
186+
}
187+
188+
// -----
189+
190+
func.func @test_gather_dynamic_indices(%arg0 : tensor<1x2xf32>, %indices: tensor<?xi64>) -> tensor<1x2xf32> {
191+
%0 = "onnx.Gather"(%arg0, %indices) {axis = 0 : si64, onnx_node_name = "/Gather_16"} : (tensor<1x2xf32>, tensor<?xi64>) -> tensor<1x2xf32>
192+
"func.return"(%0) : (tensor<1x2xf32>) -> ()
193+
// CHECK-LABEL: test_gather_dynamic_indices
194+
// CHECK: onnx.Gather
195+
}

0 commit comments

Comments
 (0)