Skip to content

Commit cd10cc0

Browse files
authored
Fix update (#1849)
1 parent 30eb4e7 commit cd10cc0

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2361,7 +2361,9 @@ struct DUSDUSConcat final
2361 2361   idxs[0] < idxs[1] + tys[1].getShape()[diffidx] &&
2362 2362   idxs[0] + tys[0].getShape()[diffidx] >
2363 2363   idxs[1] + tys[1].getShape()[diffidx] &&
2364      - allStatic) {
     2364 + allStatic &&
     2365 + idxs[0] + tys[0].getShape()[diffidx] <=
     2366 + dus.getOperand().getType().getShape()[diffidx]) {
2365 2367   // the new update overlaps, following the old update
2366 2368
2367 2369   // Case 5:
@@ -2386,7 +2388,9 @@ struct DUSDUSConcat final
2386 2388   idxs[1] < idxs[0] + tys[0].getShape()[diffidx] &&
2387 2389   idxs[0] + tys[0].getShape()[diffidx] <
2388 2390   idxs[1] + tys[1].getShape()[diffidx] &&
2389      - allStatic) {
     2391 + allStatic &&
     2392 + idxs[1] + tys[1].getShape()[diffidx] <=
     2393 + dus.getOperand().getType().getShape()[diffidx]) {
2390 2394   // the new update overlaps, following the old update
2391 2395
2392 2396   // Case 5:

test/lit_tests/dusdus2.mlir

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,20 @@
 1 + // RUN: enzymexlamlir-opt %s --enzyme-hlo-opt --split-input-file | FileCheck %s
 2 +
 3 + module {
 4 + func.func @fuse(%op: tensor<4x1x8x8xcomplex<f32>>, %19: tensor<4x1x4x8xcomplex<f32>>, %22: tensor<4x1x4x8xcomplex<f32>>) -> tensor<4x1x8x8xcomplex<f32>> {
 5 + %c_7 = stablehlo.constant dense<4> : tensor<i32>
 6 + %c_6 = stablehlo.constant dense<0> : tensor<i32>
 7 + %23 = stablehlo.dynamic_update_slice %op, %19, %c_6, %c_6, %c_6, %c_6 : (tensor<4x1x8x8xcomplex<f32>>, tensor<4x1x4x8xcomplex<f32>>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<4x1x8x8xcomplex<f32>>
 8 + %26 = stablehlo.dynamic_update_slice %23, %22, %c_6, %c_6, %c_6, %c_7 : (tensor<4x1x8x8xcomplex<f32>>, tensor<4x1x4x8xcomplex<f32>>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<4x1x8x8xcomplex<f32>>
 9 + func.return %26: tensor<4x1x8x8xcomplex<f32>>
10 + }
11 + }
12 +
13 + // CHECK: func.func @fuse(%arg0: tensor<4x1x8x8xcomplex<f32>>, %arg1: tensor<4x1x4x8xcomplex<f32>>, %arg2: tensor<4x1x4x8xcomplex<f32>>) -> tensor<4x1x8x8xcomplex<f32>> {
14 + // CHECK-NEXT: %c = stablehlo.constant dense<4> : tensor<i32>
15 + // CHECK-NEXT: %c_0 = stablehlo.constant dense<0> : tensor<i32>
16 + // CHECK-NEXT: %0 = stablehlo.slice %arg0 [0:4, 0:1, 4:8, 0:8] : (tensor<4x1x8x8xcomplex<f32>>) -> tensor<4x1x4x8xcomplex<f32>>
17 + // CHECK-NEXT: %1 = stablehlo.dynamic_update_slice %arg1, %arg2, %c_0, %c_0, %c_0, %c : (tensor<4x1x4x8xcomplex<f32>>, tensor<4x1x4x8xcomplex<f32>>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<4x1x4x8xcomplex<f32>>
18 + // CHECK-NEXT: %2 = stablehlo.concatenate %1, %0, dim = 2 : (tensor<4x1x4x8xcomplex<f32>>, tensor<4x1x4x8xcomplex<f32>>) -> tensor<4x1x8x8xcomplex<f32>>
19 + // CHECK-NEXT: return %2 : tensor<4x1x8x8xcomplex<f32>>
20 + // CHECK-NEXT: }

0 commit comments

Comments
 (0)