Skip to content

Commit 9b689ed

Browse files
authored
fix: restrict broadcast in dim check for auto-batching (#1444)
1 parent 67dc896 commit 9b689ed

File tree

3 files changed

+102
-3
lines changed

3 files changed

+102
-3
lines changed

src/enzyme_ad/jax/Passes/AutoBatching.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,11 +380,17 @@ bool ConcatInsertDimToBatchBase::validBroadcastInDimOpInsertDimForBatching(
380380
return false;
381381

382382
// If concat dim is present in broadcast dims, then it is not a valid insert
383-
for (auto bDim : broadcastInDimOp.getBroadcastDimensions()) {
383+
auto broadcastInDimOpDims =
384+
llvm::to_vector(broadcastInDimOp.getBroadcastDimensions());
385+
for (auto bDim : broadcastInDimOpDims) {
384386
if (bDim == dim)
385387
return false;
386388
}
387389

390+
// Broadcast dims must be sorted
391+
if (!llvm::is_sorted(broadcastInDimOpDims))
392+
return false;
393+
388394
// insert dim must be of size 1
389395
return outputType.getShape()[dim] == 1;
390396
}

test/lit_tests/autobatching/concatbcastdotgeneral.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ module @reactant_updates2 attributes {mhlo.num_partitions = 1 : i64, mhlo.num_re
4040

4141
// CHECK: module @reactant_updates2 attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
4242
// CHECK-NEXT: func.func @main(%arg0: tensor<3x32x32xf32>, %arg1: tensor<3x32x32xf32>, %arg2: tensor<3x32x32xf32>) -> (tensor<3x32x32xf32>, tensor<3x32x32xf32>, tensor<3x32x32xf32>) {
43-
// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [1] x [2], precision = [DEFAULT, DEFAULT] : (tensor<3x32x32xf32>, tensor<3x32x32xf32>) -> tensor<3x32x32xf32>
44-
// CHECK-NEXT: %1 = stablehlo.dot_general %arg0, %arg2, batching_dims = [0] x [0], contracting_dims = [1] x [2], precision = [DEFAULT, DEFAULT] : (tensor<3x32x32xf32>, tensor<3x32x32xf32>) -> tensor<3x32x32xf32>
43+
// CHECK-NEXT: %0 = stablehlo.dot_general %arg1, %arg0, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<3x32x32xf32>, tensor<3x32x32xf32>) -> tensor<3x32x32xf32>
44+
// CHECK-NEXT: %1 = stablehlo.dot_general %arg2, %arg0, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<3x32x32xf32>, tensor<3x32x32xf32>) -> tensor<3x32x32xf32>
4545
// CHECK-NEXT: return %arg0, %0, %1 : tensor<3x32x32xf32>, tensor<3x32x32xf32>, tensor<3x32x32xf32>
4646
// CHECK-NEXT: }
4747
// CHECK-NEXT: }
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// RUN: enzymexlamlir-opt --pass-pipeline="any(enzyme-hlo-generate-td{patterns=concat_insert_dim_elementwise},transform-interpreter,enzyme-hlo-remove-transform)"
2+
3+
module {
4+
func.func @mapped_sub(%arg0: tensor<3x5x10xf32>, %arg1: tensor<3x5x10xf32>) -> (tensor<5x3x10xf32>, tensor<3x5x10xf32>, tensor<3x5x10xf32>) {
5+
%0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x5x10xf32>) -> tensor<10x5x3xf32>
6+
%1 = stablehlo.transpose %arg1, dims = [2, 1, 0] : (tensor<3x5x10xf32>) -> tensor<10x5x3xf32>
7+
%2 = stablehlo.slice %0 [0:10, 0:1, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
8+
%3 = stablehlo.transpose %2, dims = [2, 1, 0] : (tensor<10x1x3xf32>) -> tensor<3x1x10xf32>
9+
%4 = stablehlo.reshape %3 : (tensor<3x1x10xf32>) -> tensor<3x10xf32>
10+
%5 = stablehlo.transpose %4, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32>
11+
%6 = stablehlo.convert %5 : tensor<10x3xf32>
12+
%7 = stablehlo.broadcast_in_dim %6, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32>
13+
%8 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32>
14+
%9 = stablehlo.slice %1 [0:10, 0:1, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
15+
%10 = stablehlo.transpose %9, dims = [2, 1, 0] : (tensor<10x1x3xf32>) -> tensor<3x1x10xf32>
16+
%11 = stablehlo.reshape %10 : (tensor<3x1x10xf32>) -> tensor<3x10xf32>
17+
%12 = stablehlo.transpose %11, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32>
18+
%13 = stablehlo.convert %12 : tensor<10x3xf32>
19+
%14 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32>
20+
%15 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32>
21+
%16 = stablehlo.subtract %8, %15 : tensor<10x3xf32>
22+
%17 = stablehlo.slice %0 [0:10, 1:2, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
23+
%18 = stablehlo.transpose %17, dims = [2, 1, 0] : (tensor<10x1x3xf32>) -> tensor<3x1x10xf32>
24+
%19 = stablehlo.reshape %18 : (tensor<3x1x10xf32>) -> tensor<3x10xf32>
25+
%20 = stablehlo.transpose %19, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32>
26+
%21 = stablehlo.convert %20 : tensor<10x3xf32>
27+
%22 = stablehlo.broadcast_in_dim %21, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32>
28+
%23 = stablehlo.broadcast_in_dim %22, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32>
29+
%24 = stablehlo.slice %1 [0:10, 1:2, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
30+
%25 = stablehlo.transpose %24, dims = [2, 1, 0] : (tensor<10x1x3xf32>) -> tensor<3x1x10xf32>
31+
%26 = stablehlo.reshape %25 : (tensor<3x1x10xf32>) -> tensor<3x10xf32>
32+
%27 = stablehlo.transpose %26, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32>
33+
%28 = stablehlo.convert %27 : tensor<10x3xf32>
34+
%29 = stablehlo.broadcast_in_dim %28, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32>
35+
%30 = stablehlo.broadcast_in_dim %29, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32>
36+
%31 = stablehlo.subtract %23, %30 : tensor<10x3xf32>
37+
%32 = stablehlo.slice %0 [0:10, 2:3, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
38+
%33 = stablehlo.transpose %32, dims = [2, 1, 0] : (tensor<10x1x3xf32>) -> tensor<3x1x10xf32>
39+
%34 = stablehlo.reshape %33 : (tensor<3x1x10xf32>) -> tensor<3x10xf32>
40+
%35 = stablehlo.transpose %34, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32>
41+
%36 = stablehlo.convert %35 : tensor<10x3xf32>
42+
%37 = stablehlo.broadcast_in_dim %36, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32>
43+
%38 = stablehlo.broadcast_in_dim %37, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32>
44+
%39 = stablehlo.slice %1 [0:10, 2:3, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
45+
%40 = stablehlo.transpose %39, dims = [2, 1, 0] : (tensor<10x1x3xf32>) -> tensor<3x1x10xf32>
46+
%41 = stablehlo.reshape %40 : (tensor<3x1x10xf32>) -> tensor<3x10xf32>
47+
%42 = stablehlo.transpose %41, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32>
48+
%43 = stablehlo.convert %42 : tensor<10x3xf32>
49+
%44 = stablehlo.broadcast_in_dim %43, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32>
50+
%45 = stablehlo.broadcast_in_dim %44, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32>
51+
%46 = stablehlo.subtract %38, %45 : tensor<10x3xf32>
52+
%47 = stablehlo.slice %0 [0:10, 3:4, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
53+
%48 = stablehlo.transpose %47, dims = [2, 1, 0] : (tensor<10x1x3xf32>) -> tensor<3x1x10xf32>
54+
%49 = stablehlo.reshape %48 : (tensor<3x1x10xf32>) -> tensor<3x10xf32>
55+
%50 = stablehlo.transpose %49, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32>
56+
%51 = stablehlo.convert %50 : tensor<10x3xf32>
57+
%52 = stablehlo.broadcast_in_dim %51, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32>
58+
%53 = stablehlo.broadcast_in_dim %52, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32>
59+
%54 = stablehlo.slice %1 [0:10, 3:4, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
60+
%55 = stablehlo.transpose %54, dims = [2, 1, 0] : (tensor<10x1x3xf32>) -> tensor<3x1x10xf32>
61+
%56 = stablehlo.reshape %55 : (tensor<3x1x10xf32>) -> tensor<3x10xf32>
62+
%57 = stablehlo.transpose %56, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32>
63+
%58 = stablehlo.convert %57 : tensor<10x3xf32>
64+
%59 = stablehlo.broadcast_in_dim %58, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32>
65+
%60 = stablehlo.broadcast_in_dim %59, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32>
66+
%61 = stablehlo.subtract %53, %60 : tensor<10x3xf32>
67+
%62 = stablehlo.slice %0 [0:10, 4:5, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
68+
%63 = stablehlo.transpose %62, dims = [2, 1, 0] : (tensor<10x1x3xf32>) -> tensor<3x1x10xf32>
69+
%64 = stablehlo.reshape %63 : (tensor<3x1x10xf32>) -> tensor<3x10xf32>
70+
%65 = stablehlo.transpose %64, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32>
71+
%66 = stablehlo.convert %65 : tensor<10x3xf32>
72+
%67 = stablehlo.broadcast_in_dim %66, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32>
73+
%68 = stablehlo.broadcast_in_dim %67, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32>
74+
%69 = stablehlo.slice %1 [0:10, 4:5, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
75+
%70 = stablehlo.transpose %69, dims = [2, 1, 0] : (tensor<10x1x3xf32>) -> tensor<3x1x10xf32>
76+
%71 = stablehlo.reshape %70 : (tensor<3x1x10xf32>) -> tensor<3x10xf32>
77+
%72 = stablehlo.transpose %71, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32>
78+
%73 = stablehlo.convert %72 : tensor<10x3xf32>
79+
%74 = stablehlo.broadcast_in_dim %73, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32>
80+
%75 = stablehlo.broadcast_in_dim %74, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32>
81+
%76 = stablehlo.subtract %68, %75 : tensor<10x3xf32>
82+
%77 = stablehlo.broadcast_in_dim %16, dims = [2, 1] : (tensor<10x3xf32>) -> tensor<1x3x10xf32>
83+
%78 = stablehlo.broadcast_in_dim %31, dims = [2, 1] : (tensor<10x3xf32>) -> tensor<1x3x10xf32>
84+
%79 = stablehlo.broadcast_in_dim %46, dims = [2, 1] : (tensor<10x3xf32>) -> tensor<1x3x10xf32>
85+
%80 = stablehlo.broadcast_in_dim %61, dims = [2, 1] : (tensor<10x3xf32>) -> tensor<1x3x10xf32>
86+
%81 = stablehlo.broadcast_in_dim %76, dims = [2, 1] : (tensor<10x3xf32>) -> tensor<1x3x10xf32>
87+
%82 = stablehlo.concatenate %77, %78, %79, %80, %81, dim = 0 : (tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>) -> tensor<5x3x10xf32>
88+
// CHECK: stablehlo.concatenate
89+
%83 = stablehlo.transpose %0, dims = [2, 1, 0] : (tensor<10x5x3xf32>) -> tensor<3x5x10xf32>
90+
%84 = stablehlo.transpose %1, dims = [2, 1, 0] : (tensor<10x5x3xf32>) -> tensor<3x5x10xf32>
91+
return %82, %83, %84 : tensor<5x3x10xf32>, tensor<3x5x10xf32>, tensor<3x5x10xf32>
92+
}
93+
}

0 commit comments

Comments (0)