@@ -37,7 +37,7 @@ func.func @unaryscatter(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tenso
3737// CHECK-NEXT: return %7 : tensor<1024x1024xf32>
3838// CHECK-NEXT: }
3939
40- func.func @convertscatter (%arg0: tensor <4 xi64 >, %arg1: tensor <6 xi64 >, %arg2: tensor <1024 x1024 xf32 >) -> tensor <1024 x1024 xf32 > {
40+ func.func @expscatter (%arg0: tensor <4 xi64 >, %arg1: tensor <6 xi64 >, %arg2: tensor <1024 x1024 xf32 >) -> tensor <1024 x1024 xf32 > {
4141 %cst = stablehlo.constant dense <2.000000e+00 > : tensor <24 xf32 >
4242 %c = stablehlo.constant dense <1 > : tensor <24 x2 xi64 >
4343 %cst_0 = stablehlo.constant dense <0.000000e+00 > : tensor <1024 x1024 xf32 >
@@ -56,7 +56,7 @@ func.func @convertscatter(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: ten
5656 return %8 : tensor <1024 x1024 xf32 >
5757}
5858
59- // CHECK: func.func @convertscatter (%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
59+ // CHECK: func.func @expscatter (%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
6060// CHECK-NEXT: %cst = stablehlo.constant dense<7.3890562> : tensor<24xf32>
6161// CHECK-NEXT: %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<1024x1024xf32>
6262// CHECK-NEXT: %c = stablehlo.constant dense<1> : tensor<24x2xi64>
@@ -73,3 +73,49 @@ func.func @convertscatter(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: ten
7373// CHECK-NEXT: %7 = stablehlo.transpose %6, dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
7474// CHECK-NEXT: return %7 : tensor<1024x1024xf32>
7575// CHECK-NEXT: }
76+
// Test input: a boolean mask is built with a scatter, converted to f32, and
// multiplied into the transposed operand.
//   %arg0 : 5x4 f32 data tensor, %arg1 : 5 ui32 values used to derive indices.
// The expectations that follow this function show the mask/convert/multiply
// chain replaced by a gather from the transposed input plus an f32 scatter
// into a zero tensor.
// NOTE(review): the extra spaces inside types, inside the quoted scatter op
// name, and in "index _vector_dim" look like an extraction artifact of this
// diff view -- confirm against the real file before relying on exact spelling.
77+ func.func @convertscatter (%arg0: tensor <5 x4 xf32 >, %arg1: tensor <5 xui32 >) -> tensor <5 x4 xf32 > {
// Per-column moduli applied to the 5x2 index tensor (column 0 -> 4, column 1 -> 5),
// which match the extents of the 4x5 scatter operand.
78+ %c = stablehlo.constant dense <[[4 , 5 ], [4 , 5 ], [4 , 5 ], [4 , 5 ], [4 , 5 ]]> : tensor <5 x2 xi64 >
// Offsets added to the converted %arg1 values.
79+ %c_0 = stablehlo.constant dense <[-1 , 3 , 7 , 11 , 15 ]> : tensor <5 xi64 >
// Scatter updates: one `true` per index row.
80+ %c_1 = stablehlo.constant dense <true > : tensor <5 xi1 >
// Divisor used to derive the second index component.
81+ %c_2 = stablehlo.constant dense <4 > : tensor <5 xi64 >
// Scatter operand: all-false 4x5 mask.
82+ %c_3 = stablehlo.constant dense <false > : tensor <4 x5 xi1 >
83+ %0 = stablehlo.transpose %arg0 , dims = [1 , 0 ] : (tensor <5 x4 xf32 >) -> tensor <4 x5 xf32 >
// Index computation: widen to i64, add offsets, then pair (sum, sum / 4) as a 5x2 tensor.
84+ %1 = stablehlo.convert %arg1 : (tensor <5 xui32 >) -> tensor <5 xi64 >
85+ %2 = stablehlo.add %1 , %c_0 : tensor <5 xi64 >
86+ %3 = stablehlo.divide %2 , %c_2 : tensor <5 xi64 >
87+ %4 = stablehlo.reshape %2 : (tensor <5 xi64 >) -> tensor <5 x1 xi64 >
88+ %5 = stablehlo.reshape %3 : (tensor <5 xi64 >) -> tensor <5 x1 xi64 >
89+ %6 = stablehlo.concatenate %4 , %5 , dim = 1 : (tensor <5 x1 xi64 >, tensor <5 x1 xi64 >) -> tensor <5 x2 xi64 >
90+ %7 = stablehlo.remainder %6 , %c : tensor <5 x2 xi64 >
// Write `true` at each (row, col) index pair; the region returns the update
// value (%arg3), so the update replaces the existing element.
91+ %8 = " stablehlo.scatter" (%c_3 , %7 , %c_1 ) <{scatter_dimension_numbers = #stablehlo.scatter <inserted_window_dims = [0 , 1 ], scatter_dims_to_operand_dims = [0 , 1 ], index _vector_dim = 1 >}> ({
92+ ^bb0 (%arg2: tensor <i1 >, %arg3: tensor <i1 >):
93+ stablehlo.return %arg3 : tensor <i1 >
94+ }) : (tensor <4 x5 xi1 >, tensor <5 x2 xi64 >, tensor <5 xi1 >) -> tensor <4 x5 xi1 >
// Turn the i1 mask into f32 and apply it elementwise to the transposed input.
95+ %9 = stablehlo.convert %8 : (tensor <4 x5 xi1 >) -> tensor <4 x5 xf32 >
96+ %10 = stablehlo.multiply %0 , %9 : tensor <4 x5 xf32 >
// Undo the initial transpose so the result has the 5x4 shape of %arg0.
97+ %11 = stablehlo.transpose %10 , dims = [1 , 0 ] : (tensor <4 x5 xf32 >) -> tensor <5 x4 xf32 >
98+ return %11 : tensor <5 x4 xf32 >
99+ }
100+
101+ // CHECK: func.func @convertscatter(%arg0: tensor<5x4xf32>, %arg1: tensor<5xui32>) -> tensor<5x4xf32> {
102+ // CHECK-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<4x5xf32>
103+ // CHECK-NEXT{LITERAL}: %c = stablehlo.constant dense<[[4, 5], [4, 5], [4, 5], [4, 5], [4, 5]]> : tensor<5x2xi64>
104+ // CHECK-NEXT: %c_0 = stablehlo.constant dense<[-1, 3, 7, 11, 15]> : tensor<5xi64>
105+ // CHECK-NEXT: %c_1 = stablehlo.constant dense<4> : tensor<5xi64>
106+ // CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<5x4xf32>) -> tensor<4x5xf32>
107+ // CHECK-NEXT: %1 = stablehlo.convert %arg1 : (tensor<5xui32>) -> tensor<5xi64>
108+ // CHECK-NEXT: %2 = stablehlo.add %1, %c_0 : tensor<5xi64>
109+ // CHECK-NEXT: %3 = stablehlo.divide %2, %c_1 : tensor<5xi64>
110+ // CHECK-NEXT: %4 = stablehlo.reshape %2 : (tensor<5xi64>) -> tensor<5x1xi64>
111+ // CHECK-NEXT: %5 = stablehlo.reshape %3 : (tensor<5xi64>) -> tensor<5x1xi64>
112+ // CHECK-NEXT: %6 = stablehlo.concatenate %4, %5, dim = 1 : (tensor<5x1xi64>, tensor<5x1xi64>) -> tensor<5x2xi64>
113+ // CHECK-NEXT: %7 = stablehlo.remainder %6, %c : tensor<5x2xi64>
114+ // CHECK-NEXT: %8 = "stablehlo.gather"(%0, %7) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, slice_sizes = array<i64: 1, 1>}> : (tensor<4x5xf32>, tensor<5x2xi64>) -> tensor<5xf32>
115+ // CHECK-NEXT: %9 = "stablehlo.scatter"(%cst, %7, %8) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>}> ({
116+ // CHECK-NEXT: ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
117+ // CHECK-NEXT: stablehlo.return %arg3 : tensor<f32>
118+ // CHECK-NEXT: }) : (tensor<4x5xf32>, tensor<5x2xi64>, tensor<5xf32>) -> tensor<4x5xf32>
119+ // CHECK-NEXT: %10 = stablehlo.transpose %9, dims = [1, 0] : (tensor<4x5xf32>) -> tensor<5x4xf32>
120+ // CHECK-NEXT: return %10 : tensor<5x4xf32>
121+ // CHECK-NEXT: }
0 commit comments