1+ // Modifications (c) Copyright 2023-2025 Advanced Micro Devices, Inc. or its
2+ // affiliates
13// RUN: mlir-opt --split-input-file -verify-diagnostics --tosa-layerwise-constant-fold %s | FileCheck %s
24
35// Float multiplications
@@ -15,7 +17,7 @@ func.func @mul_fold_float() -> tensor<4xf16> {
1517 dense <[-132.7 , -3.0 , -0.0 , 5.0 ]> :
1618 tensor <4 xf16 >
1719 } : () -> tensor <4 xf16 >
18- %2 = " tosa.mul" (%0 , %1 ) { shift = 0 : i8 } : (tensor <4 xf16 >, tensor <4 xf16 >) -> tensor <4 xf16 >
20+ %2 = " tosa.mul" (%0 , %1 ) : (tensor <4 xf16 >, tensor <4 xf16 >) -> tensor <4 xf16 >
1921 return %2 : tensor <4 xf16 >
2022}
2123
@@ -32,7 +34,7 @@ func.func @mul_fold_float_infinity_nan() -> tensor<7xf32> {
3234 dense <[3.0 , -3.0 , -3.0 , 3.0 , 1.0 , 0xFF800000 , 0.0 ]> :
3335 tensor <7 xf32 >
3436 } : () -> tensor <7 xf32 >
35- %2 = " tosa.mul" (%0 , %1 ) { shift = 0 : i8 } : (tensor <7 xf32 >, tensor <7 xf32 >) -> tensor <7 xf32 >
37+ %2 = " tosa.mul" (%0 , %1 ) : (tensor <7 xf32 >, tensor <7 xf32 >) -> tensor <7 xf32 >
3638 return %2 : tensor <7 xf32 >
3739}
3840
@@ -49,7 +51,7 @@ func.func @add_fold_float_overflow() -> tensor<2xf32> {
4951 dense <[2.1e+38 , 1.1e+38 ]> :
5052 tensor <2 xf32 >
5153 } : () -> tensor <2 xf32 >
52- %2 = " tosa.mul" (%0 , %1 ) { shift = 0 : i8 } : (tensor <2 xf32 >, tensor <2 xf32 >) -> tensor <2 xf32 >
54+ %2 = " tosa.mul" (%0 , %1 ) : (tensor <2 xf32 >, tensor <2 xf32 >) -> tensor <2 xf32 >
5355 return %2 : tensor <2 xf32 >
5456}
5557
@@ -69,7 +71,8 @@ func.func @mul_fold_int() -> tensor<4xi32> {
6971 dense <[-132 , -3 , 0 , 5 ]> :
7072 tensor <4 xi32 >
7173 } : () -> tensor <4 xi32 >
72- %2 = " tosa.mul" (%0 , %1 ) {shift = 0 : i8 } : (tensor <4 xi32 >, tensor <4 xi32 >) -> tensor <4 xi32 >
74+ %shift = " tosa.const" () <{value = dense <0 > : tensor <1 xi8 >}> : () -> tensor <1 xi8 >
75+ %2 = " tosa.mul" (%0 , %1 , %shift ) : (tensor <4 xi32 >, tensor <4 xi32 >, tensor <1 xi8 >) -> tensor <4 xi32 >
7376 return %2 : tensor <4 xi32 >
7477}
7578
@@ -87,10 +90,12 @@ func.func @mul_fold_i8() -> tensor<4xi32> {
8790 tensor <4 xi8 >
8891 } : () -> tensor <4 xi8 >
8992 // TODO: This is wrongly rejected as illegal, see https://reviews.llvm.org/D150472#4484478
90- // %2 = "tosa.mul"(%0, %1) {shift = 0 : i8} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi32>
93+ // %zero_shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
94+ // %2 = "tosa.mul"(%0, %1, %zero_shift) : (tensor<4xi8>, tensor<4xi8>, tensor<1xi8>) -> tensor<4xi32>
9195 %a = " tosa.cast" (%0 ) : (tensor <4 xi8 >) -> tensor <4 xi32 >
9296 %b = " tosa.cast" (%1 ) : (tensor <4 xi8 >) -> tensor <4 xi32 >
93- %2 = " tosa.mul" (%a , %b ) {shift = 0 : i8 } : (tensor <4 xi32 >, tensor <4 xi32 >) -> tensor <4 xi32 >
97+ %shift = " tosa.const" () <{value = dense <0 > : tensor <1 xi8 >}> : () -> tensor <1 xi8 >
98+ %2 = " tosa.mul" (%a , %b , %shift ) : (tensor <4 xi32 >, tensor <4 xi32 >, tensor <1 xi8 >) -> tensor <4 xi32 >
9499
95100 return %2 : tensor <4 xi32 >
96101}
@@ -110,8 +115,9 @@ func.func @mul_fold_int_overflow() -> tensor<4xi32> {
110115 dense <[1 , 10 , 1 , 30 ]> :
111116 tensor <4 xi32 >
112117 } : () -> tensor <4 xi32 >
118+ %shift = " tosa.const" () <{value = dense <0 > : tensor <1 xi8 >}> : () -> tensor <1 xi8 >
113119 // expected-warning@below {{Multiplication did overflow. The results are unspecified.}}
114- %2 = " tosa.mul" (%0 , %1 ) { shift = 0 : i8 } : (tensor <4 xi32 >, tensor <4 xi32 >) -> tensor <4 xi32 >
120+ %2 = " tosa.mul" (%0 , %1 , % shift) : (tensor <4 xi32 >, tensor <4 xi32 >, tensor < 1 x i8 >) -> tensor <4 xi32 >
115121 return %2 : tensor <4 xi32 >
116122}
117123
@@ -127,7 +133,8 @@ func.func @mul_fold_equal_args() -> tensor<3xi32> {
127133 dense <[-17 , 4 , 0 ]> :
128134 tensor <3 xi32 >
129135 } : () -> tensor <3 xi32 >
130- %2 = " tosa.mul" (%0 , %0 ) {shift = 0 : i8 } : (tensor <3 xi32 >, tensor <3 xi32 >) -> tensor <3 xi32 >
136+ %shift = " tosa.const" () <{value = dense <0 > : tensor <1 xi8 >}> : () -> tensor <1 xi8 >
137+ %2 = " tosa.mul" (%0 , %0 , %shift ) : (tensor <3 xi32 >, tensor <3 xi32 >, tensor <1 xi8 >) -> tensor <3 xi32 >
131138 return %2 : tensor <3 xi32 >
132139}
133140
@@ -147,7 +154,8 @@ func.func @mul_fold_int_broadcast_simple() -> tensor<3xi32> {
147154 dense <-12 > :
148155 tensor <1 xi32 >
149156 } : () -> tensor <1 xi32 >
150- %2 = " tosa.mul" (%0 , %1 ) {shift = 0 : i8 } : (tensor <3 xi32 >, tensor <1 xi32 >) -> tensor <3 xi32 >
157+ %shift = " tosa.const" () <{value = dense <0 > : tensor <1 xi8 >}> : () -> tensor <1 xi8 >
158+ %2 = " tosa.mul" (%0 , %1 , %shift ) : (tensor <3 xi32 >, tensor <1 xi32 >, tensor <1 xi8 >) -> tensor <3 xi32 >
151159 return %2 : tensor <3 xi32 >
152160}
153161
@@ -167,15 +175,17 @@ func.func @mul_fold_int_broadcast_complex() -> tensor<3x3xi32> {
167175 dense <[[-12 , 7 , 4 ]]> :
168176 tensor <1 x3 xi32 >
169177 } : () -> tensor <1 x3 xi32 >
170- %2 = " tosa.mul" (%0 , %1 ) {shift = 0 : i8 } : (tensor <3 x1 xi32 >, tensor <1 x3 xi32 >) -> tensor <3 x3 xi32 >
178+ %shift = " tosa.const" () <{value = dense <0 > : tensor <1 xi8 >}> : () -> tensor <1 xi8 >
179+ %2 = " tosa.mul" (%0 , %1 , %shift ) : (tensor <3 x1 xi32 >, tensor <1 x3 xi32 >, tensor <1 xi8 >) -> tensor <3 x3 xi32 >
171180 return %2 : tensor <3 x3 xi32 >
172181}
173182
174183// CHECK-LABEL: @mul_fold_int_non_zero_shift
175184func.func @mul_fold_int_non_zero_shift () -> tensor <4 xi32 > {
176- // CHECK: [[FIRST:]] ={{.*}}tosa.const
177- // CHECK-NEXT: [[SECOND:]] ={{.*}}tosa.const
178- // CHECK-NEXT: [[MUL:]] ={{.*}}tosa.mul{{.*}}[[FIRST]], [[SECOND]]
185+ // CHECK: [[FIRST:%.*]] ={{.*}}tosa.const
186+ // CHECK-NEXT: [[SECOND:%.*]] ={{.*}}tosa.const
187+ // CHECK-NEXT: [[SHIFT:%.*]] ={{.*}}tosa.const
188+ // CHECK-NEXT: [[MUL:%.*]] ={{.*}}tosa.mul{{.*}}[[FIRST]], [[SECOND]], [[SHIFT]]
179189 // CHECK-NEXT: return [[MUL]]
180190 %0 = " tosa.const" () {value =
181191 dense <[-17 , 4 , 0 , 0 ]> :
@@ -185,6 +195,7 @@ func.func @mul_fold_int_non_zero_shift() -> tensor<4xi32> {
185195 dense <[-132 , -3 , 0 , 5 ]> :
186196 tensor <4 xi32 >
187197 } : () -> tensor <4 xi32 >
188- %2 = " tosa.mul" (%0 , %1 ) {shift = 1 : i8 } : (tensor <4 xi32 >, tensor <4 xi32 >) -> tensor <4 xi32 >
198+ %shift = " tosa.const" () <{value = dense <1 > : tensor <1 xi8 >}> : () -> tensor <1 xi8 >
199+ %2 = " tosa.mul" (%0 , %1 , %shift ) : (tensor <4 xi32 >, tensor <4 xi32 >, tensor <1 xi8 >) -> tensor <4 xi32 >
189200 return %2 : tensor <4 xi32 >
190201}
0 commit comments