
Commit 667d6dc

Pad OnlineAttention
This PR makes OnlineAttention derive from IndexingMapOpInterface and makes it paddable with transform.structured.pad_tiling_interface. Additionally, it ensures that the dynamic case pads to a constant size before tiling and that shapes properly canonicalize to constants once AffineMin simplification kicks in. This requires integrating LLVM past llvm/llvm-project#145068 once it has landed.
1 parent (e90c0a5) · commit 667d6dc
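For context, the "pads to a constant" step boils down to the standard tensor.pad idiom: a boundary tile of dynamic extent is padded up to the static tile size, giving the op statically shaped operands. A minimal standalone sketch of that idiom, with names and shapes borrowed from the tests below (this is illustrative, not the transform's literal output):

func.func @pad_tile_to_static(%tile: tensor<192x?x64xf32>) -> tensor<192x128x64xf32> {
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32
  // Dynamic extent of the k2 boundary tile, somewhere in [1, 128].
  %sz = tensor.dim %tile, %c1 : tensor<192x?x64xf32>
  %high = affine.apply affine_map<()[s0] -> (128 - s0)>()[%sz]
  // Padding with the requested value (0.0, matching padding_values in the
  // tests below) yields a static tensor type even though the pad amount
  // is dynamic.
  %padded = tensor.pad %tile low[0, 0, 0] high[0, %high, 0] {
  ^bb0(%i: index, %j: index, %k: index):
    tensor.yield %cst : f32
  } : tensor<192x?x64xf32> to tensor<192x128x64xf32>
  return %padded : tensor<192x128x64xf32>
}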

File tree: 3 files changed, +116 −2 lines

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td

Lines changed: 3 additions & 1 deletion
@@ -843,7 +843,9 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention",
 def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention",
     [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      SingleBlockImplicitTerminator<"::mlir::iree_compiler::IREE::LinalgExt::YieldOp">,
-     DestinationStyleOpInterface, LinalgExtInterface,
+     DestinationStyleOpInterface,
+     IndexingMapOpInterface,
+     LinalgExtInterface,
      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
      DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
      DeclareOpInterfaceMethods<TilingInterface,
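Deriving from IndexingMapOpInterface is what lets a generic padding transform reason about the op: each operand's indexing map records which operand dimensions a given iteration dimension touches, so a request to pad iteration dimension 3 (k2) can be translated into padding of exactly the right operand dimensions. Sketched as a fragment with the maps used in the test below (the trailing comments are mine, not part of the commit):

#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>  // no k2: query is not padded
#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> // k2 is key dim 1: padded
#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>  // k2 is value dim 1: padded
#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>   // no k2: output is not padded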
Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+// RUN: iree-opt --iree-transform-dialect-interpreter --split-input-file --verify-diagnostics -canonicalize -cse %s | FileCheck %s
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+// CHECK-LABEL: online_attention
+func.func @online_attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x?x64xf32>, %value: tensor<192x?x64xf32>) -> tensor<192x1024x64xf32> {
+  %scale = arith.constant 1.0 : f32
+
+  %output_empty = tensor.empty() : tensor<192x1024x64xf32>
+  %row_red_empty = tensor.empty() : tensor<192x1024xf32>
+
+  %sum_ident = arith.constant 0.000000e+00 : f32
+  %max_ident = arith.constant -3.40282347E+38 : f32
+
+  %output_fill = linalg.fill ins(%sum_ident : f32) outs(%output_empty : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+  %acc_fill = linalg.fill ins(%max_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
+  %sum_fill = linalg.fill ins(%sum_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
+
+  // CHECK: iree_linalg_ext.online_attention ins(%{{.*}} : tensor<192x1024x64xf32>, tensor<192x128x64xf32>, tensor<192x128x64xf32>, f32)
+  %out:3 = iree_linalg_ext.online_attention
+      { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
+      ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x?x64xf32>, tensor<192x?x64xf32>, f32)
+      outs(%output_fill, %acc_fill, %sum_fill : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
+    ^bb0(%score: f32):
+      iree_linalg_ext.yield %score : f32
+  } -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
+
+  return %out#0 : tensor<192x1024x64xf32>
+}
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %online_attention = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+
+    // Tile then pad should give us a static shape.
+    %tiled_online_attention, %loops_l1 = transform.structured.tile_using_for %online_attention tile_sizes [0, 0, 0, 128]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    %padded, %pad = transform.structured.pad_tiling_interface %tiled_online_attention to padding_sizes [128] pad_to_multiple_of {
+        padding_dimensions = [3],
+        padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.affine.simplify_min_max_affine_ops %func : !transform.any_op
+
+    transform.yield
+  }
+}
+
+// -----
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+// CHECK-LABEL: online_attention
+func.func @online_attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x?x64xf32>, %value: tensor<192x?x64xf32>) -> tensor<192x1024x64xf32> {
+  %scale = arith.constant 1.0 : f32
+
+  %output_empty = tensor.empty() : tensor<192x1024x64xf32>
+  %row_red_empty = tensor.empty() : tensor<192x1024xf32>
+
+  %sum_ident = arith.constant 0.000000e+00 : f32
+  %max_ident = arith.constant -3.40282347E+38 : f32
+
+  %output_fill = linalg.fill ins(%sum_ident : f32) outs(%output_empty : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+  %acc_fill = linalg.fill ins(%max_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
+  %sum_fill = linalg.fill ins(%sum_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
+
+  // CHECK: iree_linalg_ext.online_attention ins(%{{.*}} : tensor<192x1024x64xf32>, tensor<192x128x64xf32>, tensor<192x128x64xf32>, f32)
+  %out:3 = iree_linalg_ext.online_attention
+      { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
+      ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x?x64xf32>, tensor<192x?x64xf32>, f32)
+      outs(%output_fill, %acc_fill, %sum_fill : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
+    ^bb0(%score: f32):
+      iree_linalg_ext.yield %score : f32
+  } -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
+
+  return %out#0 : tensor<192x1024x64xf32>
+}
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %online_attention = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+
+    // Pad then tile should give us a static shape.
+    %padded, %pad = transform.structured.pad_tiling_interface %online_attention to padding_sizes [128] pad_to_multiple_of {
+        padding_dimensions = [3],
+        padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    %tiled_online_attention, %loops_l1 = transform.structured.tile_using_for %padded tile_sizes [0, 0, 0, 128]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.affine.simplify_min_max_affine_ops %func : !transform.any_op
+
+    transform.yield
+  }
+}
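Why the pad-then-tile order also ends up static: after pad_to_multiple_of, the k2 extent is a multiple of 128, so every tile produced by the subsequent tiling is provably full, and transform.affine.simplify_min_max_affine_ops can fold the boundary affine.min to the constant 128. A minimal sketch of the IR shape involved, with assumed names (not the test's literal output):

func.func @full_tiles(%padded_key: tensor<192x?x64xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c128 = arith.constant 128 : index
  // By construction after padding, %k2 is a multiple of 128.
  %k2 = tensor.dim %padded_key, %c1 : tensor<192x?x64xf32>
  scf.for %iv = %c0 to %k2 step %c128 {
    // The usual boundary tile size, before simplification.
    %sz = affine.min affine_map<(d0)[s0] -> (128, s0 - d0)>(%iv)[%k2]
    // Since s0 is a multiple of 128 and d0 steps by 128, s0 - d0 >= 128
    // always holds, so %sz folds to 128 and canonicalization gives the
    // slice the static type tensor<192x128x64xf32>.
    %slice = tensor.extract_slice %padded_key[0, %iv, 0] [192, %sz, 64] [1, 1, 1]
        : tensor<192x?x64xf32> to tensor<192x?x64xf32>
  }
  return
}

In the tile-then-pad ordering of the first test case, it is instead the tensor.pad inserted on the boundary tile that produces the static type directly.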

third_party/llvm-project

Lines changed: 1 addition & 1 deletion (submodule pointer update)
