|
| 1 | +// RUN: enzymexlamlir-opt --split-input-file --enzyme-diff-batch %s | FileCheck %s |
| 2 | +// RUN: enzymexlamlir-opt --split-input-file --enzyme-diff-batch --enzyme-batch-to-tensor %s | FileCheck %s --check-prefix=LEGAL |
| 3 | + |
| 4 | +//1. Scalar test |
| 5 | +module { |
| 6 | + func.func @square(%x : tensor<f64>) -> tensor<f64>{ |
| 7 | + %y = stablehlo.multiply %x, %x : tensor<f64> |
| 8 | + return %y : tensor<f64> |
| 9 | + } |
| 10 | + func.func @test1(%x : tensor<f64>, %dr1 : tensor<f64>, %dr2 : tensor<f64>) -> (tensor<f64>,tensor<f64>) { |
| 11 | + %r, %dx1 = enzyme.autodiff @square(%x, %dr1) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_active>] } : (tensor<f64>, tensor<f64>) -> (tensor<f64>,tensor<f64>) |
| 12 | + %r2, %dx2 = enzyme.autodiff @square(%x, %dr2) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_active>] } : (tensor<f64>, tensor<f64>) -> (tensor<f64>,tensor<f64>) |
| 13 | + return %dx1,%dx2 : tensor<f64>, tensor<f64> |
| 14 | + } |
| 15 | +} |
| 16 | + |
| 17 | +// CHECK-LABEL: func.func @test1 |
| 18 | +// CHECK-SAME: (%[[PRIMAL:.*]]: tensor<f64>, %[[DIFF1:.*]]: tensor<f64>, %[[DIFF2:.*]]: tensor<f64>) -> (tensor<f64>, tensor<f64>) |
| 19 | +// CHECK: %[[CONCAT:.*]] = enzyme.concat(%[[DIFF1]], %[[DIFF2]]) : (tensor<f64>, tensor<f64>) -> tensor<2xf64> |
| 20 | +// CHECK: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<f64>, tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>) |
| 21 | +// CHECK: %[[C0:.*]] = arith.constant 0 : index |
| 22 | +// CHECK-NEXT: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[%[[C0]]] : (tensor<2xf64>) -> tensor<f64> |
| 23 | +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index |
| 24 | +// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[%[[C1]]] : (tensor<2xf64>) -> tensor<f64> |
| 25 | +// CHECK-NEXT: return %[[RES0]], %[[RES1]] |
| 26 | + |
| 27 | +// LEGAL-LABEL: func.func @test1 |
| 28 | +// LEGAL-SAME: (%[[PRIMAL:.*]]: f64, %[[DIFF1:.*]]: f64, %[[DIFF2:.*]]: f64) -> (f64, f64) |
| 29 | +// LEGAL: %[[CONCAT:.*]] = tensor.from_elements %[[DIFF1]], %[[DIFF2]] : tensor<2xf64> |
| 30 | +// LEGAL: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (f64, tensor<2xf64>) -> (f64, tensor<2xf64>) |
| 31 | +// LEGAL: %[[C0:.*]] = arith.constant 0 : index |
| 32 | +// LEGAL-NEXT: %[[RES0:.*]] = tensor.extract %[[BATCHED_RES_BASE]]#1[%[[C0]]] : tensor<2xf64> |
| 33 | +// LEGAL-NEXT: %[[C1:.*]] = arith.constant 1 : index |
| 34 | +// LEGAL-NEXT: %[[RES1:.*]] = tensor.extract %[[BATCHED_RES_BASE]]#1[%[[C1]]] : tensor<2xf64> |
| 35 | +// LEGAL-NEXT: return %[[RES0]], %[[RES1]] |
| 36 | + |
| 37 | +// ----- |
| 38 | + |
| 39 | +//2. Tensor test |
| 40 | +module { |
| 41 | + func.func @square(%x : tensor<10xf64>) -> tensor<10xf64>{ |
| 42 | + %y = stablehlo.multiply %x, %x : tensor<10xf64> |
| 43 | + return %y : tensor<10xf64> |
| 44 | + } |
| 45 | + func.func @test2(%x : tensor<10xf64>, %dr1 : tensor<10xf64>, %dr2 : tensor<10xf64>) -> (tensor<10xf64>,tensor<10xf64>) { |
| 46 | + %r, %dx1 = enzyme.autodiff @square(%x, %dr1) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_active>]} : (tensor<10xf64>, tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>) |
| 47 | + %r2, %dx2 = enzyme.autodiff @square(%x, %dr2) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_active>]} : (tensor<10xf64>, tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>) |
| 48 | + return %dx1,%dx2 : tensor<10xf64>,tensor<10xf64> |
| 49 | + } |
| 50 | +} |
| 51 | + |
| 52 | + |
| 53 | +// CHECK-LABEL: func.func @test2 |
| 54 | +// CHECK-SAME: (%[[PRIMAL:.*]]: tensor<10xf64>, %[[DIFF1:.*]]: tensor<10xf64>, %[[DIFF2:.*]]: tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>) |
| 55 | +// CHECK: %[[CONCAT:.*]] = enzyme.concat(%[[DIFF1]], %[[DIFF2]]) : (tensor<10xf64>, tensor<10xf64>) -> tensor<2x10xf64> |
| 56 | +// CHECK: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<10xf64>, tensor<2x10xf64>) |
| 57 | +// CHECK: %[[C0:.*]] = arith.constant 0 : index |
| 58 | +// CHECK-NEXT: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[%[[C0]]] : (tensor<2x10xf64>) -> tensor<10xf64> |
| 59 | +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index |
| 60 | +// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[%[[C1]]] : (tensor<2x10xf64>) -> tensor<10xf64> |
| 61 | +// CHECK-NEXT: return %[[RES0]], %[[RES1]] |
| 62 | + |
| 63 | +// LEGAL-LABEL: func.func @test2 |
| 64 | +// LEGAL-SAME: (%[[PRIMAL:.*]]: tensor<10xf64>, %[[DIFF1:.*]]: tensor<10xf64>, %[[DIFF2:.*]]: tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>) |
| 65 | +// LEGAL: %[[EDIFF1:.*]] = tensor.expand_shape %[[DIFF1]] {{\[\[0, 1\]\]}} output_shape [1, 10] : tensor<10xf64> into tensor<1x10xf64> |
| 66 | +// LEGAL: %[[EDIFF2:.*]] = tensor.expand_shape %[[DIFF2]] {{\[\[0, 1\]\]}} output_shape [1, 10] : tensor<10xf64> into tensor<1x10xf64> |
| 67 | +// LEGAL: %[[CONCAT:.*]] = tensor.concat dim(0) %[[EDIFF1]], %[[EDIFF2]] : (tensor<1x10xf64>, tensor<1x10xf64>) -> tensor<2x10xf64> |
| 68 | +// LEGAL: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<10xf64>, tensor<2x10xf64>) |
| 69 | +// LEGAL: %[[C0:.*]] = arith.constant 0 : index |
| 70 | +// LEGAL-NEXT: %[[RES0:.*]] = tensor.extract_slice %[[BATCHED_RES_BASE]]#1[%[[C0]], 0] [1, 10] [1, 1] : tensor<2x10xf64> to tensor<10xf64> |
| 71 | +// LEGAL-NEXT: %[[C1:.*]] = arith.constant 1 : index |
| 72 | +// LEGAL-NEXT: %[[RES1:.*]] = tensor.extract_slice %[[BATCHED_RES_BASE]]#1[%[[C1]], 0] [1, 10] [1, 1] : tensor<2x10xf64> to tensor<10xf64> |
| 73 | +// LEGAL-NEXT: return %[[RES0]], %[[RES1]] |
0 commit comments