Skip to content

Conversation

@lhutton1
Copy link
Contributor

This commit allows zero-points used by a number of tosa operations to be unranked. This allows the shape inference pass to propagate shape information.

@llvmbot
Copy link
Member

llvmbot commented Jun 11, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

Changes

This commit allows zero-points used by a number of tosa operations to be unranked. This allows the shape inference pass to propagate shape information.


Full diff: https://github.com/llvm/llvm-project/pull/143770.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+1-1)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+13)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 536551c8f8437..1cfe6eee576b3 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -152,7 +152,7 @@ def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
 
 def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
 def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
-def Tosa_ScalarIntOrFloatTensor : TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>;
+def Tosa_ScalarIntOrFloatTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>]>;
 
 // We include unranked tensors as a supported type for all possible tosa
 // Tensors as unranked does not guarantee invalid. If unranked tensors exist
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index a4617fc6fba8b..9982e1a7fe197 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1007,7 +1007,7 @@ func.func @test_pad_rank0_pad_const(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<1
 func.func @test_conv2d_rank0_zp(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi32> {
   %input_zp = "tosa.const"() <{values = dense<0> : tensor<i8>}> : () -> tensor<i8>
   %weight_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
-  // expected-error@+1 {{'tosa.conv2d' op operand #3 must be tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values, but got 'tensor<i8>'}}
+  // expected-error@+1 {{'tosa.conv2d' op operand #3 must be tosa-conformant unranked tensor of unsigned integer or signless integer or floating-point values or tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values, but got 'tensor<i8>'}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
            : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<i8>, tensor<1xi8>) -> tensor<1x27x27x16xi32>
   return %0 : tensor<1x27x27x16xi32>
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 591a3f0acf65d..f5ab4a2241358 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -333,6 +333,19 @@ func.func @test_dynamic_mixed_matmul(%arg0 : tensor<?x3x?xi32>, %arg1 : tensor<?
 
 // -----
 
+// CHECK-LABEL: @test_unranked_zero_points_matmul
+func.func @test_unranked_zero_points_matmul(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x3x4xf32>) -> tensor<1x2x4xf32> {
+    %a_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+    %b_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+    %a_zp_cast = "tosa.cast"(%a_zp) : (tensor<1xi8>) -> tensor<*xf32>
+    %b_zp_cast = "tosa.cast"(%b_zp) : (tensor<1xi8>) -> tensor<*xf32>
+    // CHECK: tosa.matmul %arg0, %arg1, %2, %3 : (tensor<1x2x3xf32>, tensor<1x3x4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x2x4xf32>
+    %0 = tosa.matmul %arg0, %arg1, %a_zp_cast, %b_zp_cast : (tensor<1x2x3xf32>, tensor<1x3x4xf32>, tensor<*xf32>, tensor<*xf32>)  -> tensor<1x2x4xf32>
+    return %0 : tensor<1x2x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_table_static
 func.func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () {
   // CHECK:tosa.table %arg0, %arg1 : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16>

Copy link
Contributor

@psunn psunn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks! Just one minor nit.

This commit allows zero-points used by a number of
tosa operations to be unranked. This allows the shape
inference pass to propagate shape information.

Change-Id: I20c1a04eb2d306f181ffd5c1574f69fd9e410102
@lhutton1 lhutton1 force-pushed the unranked-zero-point branch from 015ab39 to 5f9f404 Compare June 16, 2025 14:02
@lhutton1 lhutton1 merged commit 98a6fed into llvm:main Jun 23, 2025
7 checks passed
@lhutton1 lhutton1 deleted the unranked-zero-point branch June 23, 2025 08:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants