Skip to content

Commit af2bfe5

Browse files
committed
fix lit tests for concatenateOp
1 parent bb22ea2 commit af2bfe5

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

test/lit_tests/adbatching/bwd_batch.mlir

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ module {
2626

2727
// LEGAL-LABEL: func.func @test1
2828
// LEGAL-SAME: (%[[PRIMAL:.*]]: f64, %[[DIFF1:.*]]: f64, %[[DIFF2:.*]]: f64) -> (f64, f64)
29-
// LEGAL: %[[CONCAT:.*]] = tensor.from_elements %[[DIFF1]], %[[DIFF2]] : tensor<2xf64>
29+
// LEGAL: %[[EDIFF1:.*]] = stablehlo.reshape %[[DIFF1]] : (tensor<f64>) -> tensor<1xf64>
30+
// LEGAL: %[[EDIFF2:.*]] = stablehlo.reshape %[[DIFF2]] : (tensor<f64>) -> tensor<1xf64>
31+
// LEGAL: %[[CONCAT:.*]] = stablehlo.concatenate %[[EDIFF1]], %[[EDIFF2]], dim = 0 : (tensor<1xf64>, tensor<1xf64>) -> tensor<2xf64>
3032
// LEGAL: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (f64, tensor<2xf64>) -> (f64, tensor<2xf64>)
3133
// LEGAL: %[[C0:.*]] = arith.constant 0 : index
3234
// LEGAL-NEXT: %[[RES0:.*]] = tensor.extract %[[BATCHED_RES_BASE]]#1[%[[C0]]] : tensor<2xf64>
@@ -62,9 +64,9 @@ module {
6264

6365
// LEGAL-LABEL: func.func @test2
6466
// LEGAL-SAME: (%[[PRIMAL:.*]]: tensor<10xf64>, %[[DIFF1:.*]]: tensor<10xf64>, %[[DIFF2:.*]]: tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
65-
// LEGAL: %[[EDIFF1:.*]] = tensor.expand_shape %[[DIFF1]] {{\[\[0, 1\]\]}} output_shape [1, 10] : tensor<10xf64> into tensor<1x10xf64>
66-
// LEGAL: %[[EDIFF2:.*]] = tensor.expand_shape %[[DIFF2]] {{\[\[0, 1\]\]}} output_shape [1, 10] : tensor<10xf64> into tensor<1x10xf64>
67-
// LEGAL: %[[CONCAT:.*]] = tensor.concat dim(0) %[[EDIFF1]], %[[EDIFF2]] : (tensor<1x10xf64>, tensor<1x10xf64>) -> tensor<2x10xf64>
67+
// LEGAL: %[[EDIFF1:.*]] = stablehlo.reshape %[[DIFF1]] : (tensor<10xf64>) -> tensor<1x10xf64>
68+
// LEGAL: %[[EDIFF2:.*]] = stablehlo.reshape %[[DIFF2]] : (tensor<10xf64>) -> tensor<1x10xf64>
69+
// LEGAL: %[[CONCAT:.*]] = stablehlo.concatenate %[[EDIFF1]], %[[EDIFF2]], dim = 0 : (tensor<1x10xf64>, tensor<1x10xf64>) -> tensor<2x10xf64>
6870
// LEGAL: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<10xf64>, tensor<2x10xf64>)
6971
// LEGAL: %[[C0:.*]] = arith.constant 0 : index
7072
// LEGAL-NEXT: %[[RES0:.*]] = tensor.extract_slice %[[BATCHED_RES_BASE]]#1[%[[C0]], 0] [1, 10] [1, 1] : tensor<2x10xf64> to tensor<10xf64>

test/lit_tests/adbatching/fwd_batch.mlir

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ module {
2525

2626
// LEGAL-LABEL: func.func @test1
2727
// LEGAL-SAME: (%[[PRIMAL:.*]]: tensor<f64>, %[[DIFF1:.*]]: tensor<f64>, %[[DIFF2:.*]]: tensor<f64>) -> (tensor<f64>, tensor<f64>)
28-
// LEGAL: %[[CONCAT:.*]] = tensor.from_elements %[[DIFF1]], %[[DIFF2]] : tensor<2xf64>
28+
// LEGAL: %[[EDIFF1:.*]] = stablehlo.reshape %[[DIFF1]] : (tensor<f64>) -> tensor<1xf64>
29+
// LEGAL: %[[EDIFF2:.*]] = stablehlo.reshape %[[DIFF2]] : (tensor<f64>) -> tensor<1xf64>
30+
// LEGAL: %[[CONCAT:.*]] = stablehlo.concatenate %[[EDIFF1]], %[[EDIFF2]], dim = 0 : (tensor<1xf64>, tensor<1xf64>) -> tensor<2xf64>
2931
// LEGAL: %[[BATCHED_RES:.*]] = enzyme.fwddiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<f64>, tensor<2xf64>) -> tensor<2xf64>
3032
// LEGAL: %[[C0:.*]] = arith.constant 0 : index
3133
// LEGAL-NEXT: %[[RES0:.*]] = tensor.extract %[[BATCHED_RES]][%[[C0]]] : tensor<2xf64>
@@ -60,9 +62,9 @@ module {
6062

6163
// LEGAL-LABEL: func.func @test2
6264
// LEGAL-SAME: (%[[PRIMAL:.*]]: tensor<10xf64>, %[[DIFF1:.*]]: tensor<10xf64>, %[[DIFF2:.*]]: tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
63-
// LEGAL: %[[EDIFF1:.*]] = tensor.expand_shape %[[DIFF1]] {{\[\[0, 1\]\]}} output_shape [1, 10] : tensor<10xf64> into tensor<1x10xf64>
64-
// LEGAL: %[[EDIFF2:.*]] = tensor.expand_shape %[[DIFF2]] {{\[\[0, 1\]\]}} output_shape [1, 10] : tensor<10xf64> into tensor<1x10xf64>
65-
// LEGAL: %[[CONCAT:.*]] = tensor.concat dim(0) %[[EDIFF1]], %[[EDIFF2]] : (tensor<1x10xf64>, tensor<1x10xf64>) -> tensor<2x10xf64>
65+
// LEGAL: %[[EDIFF1:.*]] = stablehlo.reshape %[[DIFF1]] : (tensor<10xf64>) -> tensor<1x10xf64>
66+
// LEGAL: %[[EDIFF2:.*]] = stablehlo.reshape %[[DIFF2]] : (tensor<10xf64>) -> tensor<1x10xf64>
67+
// LEGAL: %[[CONCAT:.*]] = stablehlo.concatenate %[[EDIFF1]], %[[EDIFF2]], dim = 0 : (tensor<1x10xf64>, tensor<1x10xf64>) -> tensor<2x10xf64>
6668
// LEGAL: %[[BATCHED_RES:.*]] = enzyme.fwddiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64>
6769
// LEGAL: %[[C0:.*]] = arith.constant 0 : index
6870
// LEGAL-NEXT: %[[RES0:.*]] = tensor.extract_slice %[[BATCHED_RES]][%[[C0]], 0] [1, 10] [1, 1] : tensor<2x10xf64> to tensor<10xf64>

0 commit comments

Comments
 (0)