@@ -18,22 +18,20 @@ module {
1818// CHECK-SAME: (%[[PRIMAL:.*]]: tensor<f64>, %[[DIFF1:.*]]: tensor<f64>, %[[DIFF2:.*]]: tensor<f64>) -> (tensor<f64>, tensor<f64>)
1919// CHECK: %[[CONCAT:.*]] = enzyme.concat(%[[DIFF1]], %[[DIFF2]]) : (tensor<f64>, tensor<f64>) -> tensor<2xf64>
2020// CHECK: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<f64>, tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>)
21- // CHECK: %[[C0:.*]] = arith.constant 0 : index
22- // CHECK-NEXT: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[%[[C0]]] : (tensor<2xf64>) -> tensor<f64>
23- // CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
24- // CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[%[[C1]]] : (tensor<2xf64>) -> tensor<f64>
21+ // CHECK: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[0] : (tensor<2xf64>) -> tensor<f64>
22+ // CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[1] : (tensor<2xf64>) -> tensor<f64>
2523// CHECK-NEXT: return %[[RES0]], %[[RES1]]
2624
2725// LEGAL-LABEL: func.func @test1
28- // LEGAL-SAME: (%[[PRIMAL:.*]]: f64, %[[DIFF1:.*]]: f64, %[[DIFF2:.*]]: f64) -> (f64, f64)
26+ // LEGAL-SAME: (%[[PRIMAL:.*]]: tensor<f64>, %[[DIFF1:.*]]: tensor<f64>, %[[DIFF2:.*]]: tensor<f64>) -> (tensor<f64>, tensor<f64>)
2927// LEGAL: %[[EDIFF1:.*]] = stablehlo.reshape %[[DIFF1]] : (tensor<f64>) -> tensor<1xf64>
3028// LEGAL: %[[EDIFF2:.*]] = stablehlo.reshape %[[DIFF2]] : (tensor<f64>) -> tensor<1xf64>
3129// LEGAL: %[[CONCAT:.*]] = stablehlo.concatenate %[[EDIFF1]], %[[EDIFF2]], dim = 0 : (tensor<1xf64>, tensor<1xf64>) -> tensor<2xf64>
32- // LEGAL: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (f64, tensor<2xf64>) -> (f64, tensor<2xf64>)
33- // LEGAL: %[[C0:.*]] = arith.constant 0 : index
34- // LEGAL-NEXT: %[[RES0:.*]] = tensor.extract %[[BATCHED_RES_BASE]]#1[%[[C0]]] : tensor<2xf64>
35- // LEGAL-NEXT: %[[C1:.*]] = arith.constant 1 : index
36- // LEGAL-NEXT: %[[RES1:.*]] = tensor.extract %[[BATCHED_RES_BASE]]#1[%[[C1]]] : tensor<2xf64>
30+ // LEGAL: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<f64>, tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>)
31+ // LEGAL: %[[R0:.*]] = stablehlo.slice %[[BATCHED_RES_BASE]]#1 [0:1] : (tensor<2xf64>) -> tensor<1xf64>
32+ // LEGAL-NEXT: %[[RES0:.*]] = stablehlo.reshape %[[R0]] : (tensor<1xf64>) -> tensor<f64>
33+ // LEGAL-NEXT: %[[R1:.*]] = stablehlo.slice %[[BATCHED_RES_BASE]]#1 [1:2] : (tensor<2xf64>) -> tensor<1xf64>
34+ // LEGAL-NEXT: %[[RES1:.*]] = stablehlo.reshape %[[R1]] : (tensor<1xf64>) -> tensor<f64>
3735// LEGAL-NEXT: return %[[RES0]], %[[RES1]]
3836
3937// -----
@@ -56,10 +54,8 @@ module {
5654// CHECK-SAME: (%[[PRIMAL:.*]]: tensor<10xf64>, %[[DIFF1:.*]]: tensor<10xf64>, %[[DIFF2:.*]]: tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
5755// CHECK: %[[CONCAT:.*]] = enzyme.concat(%[[DIFF1]], %[[DIFF2]]) : (tensor<10xf64>, tensor<10xf64>) -> tensor<2x10xf64>
5856// CHECK: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<10xf64>, tensor<2x10xf64>)
59- // CHECK: %[[C0:.*]] = arith.constant 0 : index
60- // CHECK-NEXT: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[%[[C0]]] : (tensor<2x10xf64>) -> tensor<10xf64>
61- // CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
62- // CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[%[[C1]]] : (tensor<2x10xf64>) -> tensor<10xf64>
57+ // CHECK-NEXT: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[0] : (tensor<2x10xf64>) -> tensor<10xf64>
58+ // CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[1] : (tensor<2x10xf64>) -> tensor<10xf64>
6359// CHECK-NEXT: return %[[RES0]], %[[RES1]]
6460
6561// LEGAL-LABEL: func.func @test2
@@ -68,8 +64,8 @@ module {
6864// LEGAL: %[[EDIFF2:.*]] = stablehlo.reshape %[[DIFF2]] : (tensor<10xf64>) -> tensor<1x10xf64>
6965// LEGAL: %[[CONCAT:.*]] = stablehlo.concatenate %[[EDIFF1]], %[[EDIFF2]], dim = 0 : (tensor<1x10xf64>, tensor<1x10xf64>) -> tensor<2x10xf64>
7066// LEGAL: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<10xf64>, tensor<2x10xf64>)
71- // LEGAL: %[[C0:.*]] = arith.constant 0 : index
72- // LEGAL-NEXT: %[[RES0:.*]] = tensor.extract_slice %[[BATCHED_RES_BASE]]#1[%[[C0]], 0] [1, 10] [1, 1] : tensor<2x10xf64> to tensor<10xf64>
73- // LEGAL-NEXT: %[[C1:.*]] = arith.constant 1 : index
74- // LEGAL-NEXT: %[[RES1:.*]] = tensor.extract_slice %[[BATCHED_RES_BASE]]#1[%[[C1]], 0] [1, 10] [1, 1] : tensor<2x10xf64> to tensor<10xf64>
67+ // LEGAL: %[[R0:.*]] = stablehlo.slice %[[BATCHED_RES_BASE]]#1 [0:1, 0:10] : (tensor<2x10xf64>) -> tensor<1x10xf64>
68+ // LEGAL-NEXT: %[[RES0:.*]] = stablehlo.reshape %[[R0]] : (tensor<1x10xf64>) -> tensor<10xf64>
69+ // LEGAL-NEXT: %[[R1:.*]] = stablehlo.slice %[[BATCHED_RES_BASE]]#1 [1:2, 0:10] : (tensor<2x10xf64>) -> tensor<1x10xf64>
70+ // LEGAL-NEXT: %[[RES1:.*]] = stablehlo.reshape %[[R1]] : (tensor<1x10xf64>) -> tensor<10xf64>
7571// LEGAL-NEXT: return %[[RES0]], %[[RES1]]
0 commit comments