Skip to content

Commit 6134cf9

Browse files
committed
add test for another broadcast slice case
1 parent 09e3e5e commit 6134cf9

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt=passses=131072 | FileCheck %s
2+
3+
module {
4+
func.func @main(%arg6: tensor<1536xf64>, %3931: tensor<1519x3056xf64>) -> (tensor<1519x3056xf64>, tensor<1x1519x3056xf64>) {
5+
%3184 = stablehlo.slice %arg6 [9:10] : (tensor<1536xf64>) -> tensor<1xf64>
6+
%3193 = stablehlo.broadcast_in_dim %3184, dims = [0] : (tensor<1xf64>) -> tensor<3056xf64>
7+
8+
%3194 = stablehlo.reshape %3193 : (tensor<3056xf64>) -> tensor<1x1x3056xf64>
9+
%3195 = stablehlo.reshape %3193 : (tensor<3056xf64>) -> tensor<1x3056xf64>
10+
11+
%3189 = stablehlo.slice %arg6 [9:1528] : (tensor<1536xf64>) -> tensor<1519xf64>
12+
%3190 = stablehlo.broadcast_in_dim %3189, dims = [0] : (tensor<1519xf64>) -> tensor<1519x3056xf64>
13+
%3191 = stablehlo.reshape %3190 : (tensor<1519x3056xf64>) -> tensor<1x1519x3056xf64>
14+
15+
%3932 = stablehlo.reshape %3931 : (tensor<1519x3056xf64>) -> tensor<1x1519x3056xf64>
16+
17+
%3196 = stablehlo.slice %3190 [1:1519, 0:3056] : (tensor<1519x3056xf64>) -> tensor<1518x3056xf64>
18+
%3197 = stablehlo.concatenate %3195, %3196, dim = 0 : (tensor<1x3056xf64>, tensor<1518x3056xf64>) -> tensor<1519x3056xf64>
19+
20+
%3936 = stablehlo.reshape %3931 : (tensor<1519x3056xf64>) -> tensor<1519x1x3056xf64>
21+
%3937 = stablehlo.slice %3936 [0:1, 0:1, 0:3056] : (tensor<1519x1x3056xf64>) -> tensor<1x1x3056xf64>
22+
%3943 = stablehlo.multiply %3194, %3937 : tensor<1x1x3056xf64>
23+
24+
%3942 = stablehlo.multiply %3191, %3932 : tensor<1x1519x3056xf64>
25+
%3944 = stablehlo.slice %3942 [0:1, 1:1519, 0:3056] : (tensor<1x1519x3056xf64>) -> tensor<1x1518x3056xf64>
26+
27+
%3945 = stablehlo.concatenate %3943, %3944, dim = 1 : (tensor<1x1x3056xf64>, tensor<1x1518x3056xf64>) -> tensor<1x1519x3056xf64>
28+
29+
return %3197, %3945 : tensor<1519x3056xf64>, tensor<1x1519x3056xf64>
30+
}
31+
32+
// CHECK: func.func @main(%arg0: tensor<1536xf64>, %arg1: tensor<1519x3056xf64>) -> (tensor<1519x3056xf64>, tensor<1x1519x3056xf64>) {
33+
// CHECK-NEXT: %0 = stablehlo.slice %arg0 [9:1528] : (tensor<1536xf64>) -> tensor<1519xf64>
34+
// CHECK-NEXT: %1 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<1519xf64>) -> tensor<1519x3056xf64>
35+
// CHECK-NEXT: %2 = stablehlo.reshape %arg1 : (tensor<1519x3056xf64>) -> tensor<1x1519x3056xf64>
36+
// CHECK-NEXT: %3 = stablehlo.broadcast_in_dim %0, dims = [1] : (tensor<1519xf64>) -> tensor<1x1519x3056xf64>
37+
// CHECK-NEXT: %4 = stablehlo.multiply %3, %2 : tensor<1x1519x3056xf64>
38+
// CHECK-NEXT: return %1, %4 : tensor<1519x3056xf64>, tensor<1x1519x3056xf64>
39+
// CHECK-NEXT: }
40+
}

0 commit comments

Comments
 (0)