Skip to content

Commit 5ba2b5c

Browse files
committed
Fix dynamic shape pad
1 parent 11cef65 commit 5ba2b5c

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

src/enzyme_ad/jax/Passes/AffineToStableHLORaising.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2077,7 +2077,7 @@ tryRaisingOpToStableHLO(Operation *op, IRMapping &mapping, OpBuilder &builder,
20772077
int64_t pLow = 0;
20782078
int64_t pHigh = 0;
20792079

2080-
if (hasRange) {
2080+
if (hasRange && sz != ShapedType::kDynamic) {
20812081
if (start < 0) {
20822082
pLow = -start;
20832083
start = 0;
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: enzymexlamlir-opt --raise-affine-to-stablehlo --split-input-file %s | FileCheck %s
2+
3+
module {
4+
func.func @h(%62: !llvm.ptr) -> () {
5+
%c1 = arith.constant 1 : index
6+
%c256 = arith.constant 256 : index
7+
%c39063 = arith.constant 39063 : index
8+
%cst_8 = arith.constant 0.000000e+00 : f32
9+
%100 = "enzymexla.gpu_wrapper"(%c39063, %c1, %c1, %c256, %c1, %c1) ({
10+
affine.parallel (%arg0) = (0) to (10000000) {
11+
%234 = "enzymexla.pointer2memref"(%62) : (!llvm.ptr) -> memref<?xf32>
12+
affine.store %cst_8, %234[%arg0] : memref<?xf32>
13+
}
14+
"enzymexla.polygeist_yield"() : () -> ()
15+
}) : (index, index, index, index, index, index) -> index
16+
func.return
17+
}
18+
}
19+
20+
// CHECK: func.func @h(%arg0: !llvm.ptr) {
21+
// CHECK-NEXT: %c1 = arith.constant 1 : index
22+
// CHECK-NEXT: %c256 = arith.constant 256 : index
23+
// CHECK-NEXT: %c39063 = arith.constant 39063 : index
24+
// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
25+
// CHECK-NEXT: %0 = "enzymexla.pointer2memref"(%arg0) : (!llvm.ptr) -> memref<?xf32>
26+
// CHECK-NEXT: enzymexla.xla_wrapper @rxla$raised_0 (%0) : (memref<?xf32>) -> ()
27+
// CHECK-NEXT: return
28+
// CHECK-NEXT: }
29+
// CHECK: func.func private @rxla$raised_0(%arg0: tensor<?xf32>) -> tensor<?xf32> {
30+
// CHECK-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
31+
// CHECK-NEXT: %0 = stablehlo.iota dim = 0 : tensor<10000000xi64>
32+
// CHECK-NEXT: %c = stablehlo.constant dense<0> : tensor<10000000xi64>
33+
// CHECK-NEXT: %1 = stablehlo.add %0, %c : tensor<10000000xi64>
34+
// CHECK-NEXT: %c_0 = stablehlo.constant dense<1> : tensor<10000000xi64>
35+
// CHECK-NEXT: %2 = stablehlo.multiply %1, %c_0 : tensor<10000000xi64>
36+
// CHECK-NEXT: %c_1 = stablehlo.constant dense<0> : tensor<i64>
37+
// CHECK-NEXT: %3 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<10000000xf32>
38+
// CHECK-NEXT: %4 = stablehlo.dynamic_update_slice %arg0, %3, %c_1 : (tensor<?xf32>, tensor<10000000xf32>, tensor<i64>) -> tensor<?xf32>
39+
// CHECK-NEXT: return %4 : tensor<?xf32>
40+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)