Commit 8a2a739
Pad OnlineAttention
This PR makes OnlineAttention derive from IndexingMapOpInterface and enables padding it with transform.structured.pad_tiling_interface. Additionally, it ensures that the dynamic case pads to a constant before tiling and properly canonicalizes to constant shapes once AffineMin simplification kicks in. This requires integrating LLVM past llvm/llvm-project#145068 once it has landed.
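For intuition, here is a hand-written sketch (editorial; not the transform's actual output, and all SSA names are hypothetical) of the computation pad_to_multiple_of performs on the dynamic k2 dimension:

// Pad the dynamic k2 dimension of the key up to the next multiple of 128 so
// that subsequent tiling by 128 only ever sees full, static tiles.
%c1 = arith.constant 1 : index
%k2 = tensor.dim %key, %c1 : tensor<192x?x64xf32>
// High padding amount: round k2 up to a multiple of 128, then subtract k2.
%high = affine.apply affine_map<()[s0] -> ((s0 ceildiv 128) * 128 - s0)>()[%k2]
%zero = arith.constant 0.0 : f32
%key_padded = tensor.pad %key low[0, 0, 0] high[0, %high, 0] {
^bb0(%i: index, %j: index, %k: index):
  tensor.yield %zero : f32
} : tensor<192x?x64xf32> to tensor<192x?x64xf32>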
1 parent e90c0a5 commit 8a2a739

File tree

3 files changed: +117 −2 lines


compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td

Lines changed: 3 additions & 1 deletion
@@ -843,7 +843,9 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention",
 def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention",
     [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      SingleBlockImplicitTerminator<"::mlir::iree_compiler::IREE::LinalgExt::YieldOp">,
-     DestinationStyleOpInterface, LinalgExtInterface,
+     DestinationStyleOpInterface,
+     IndexingMapOpInterface,
+     LinalgExtInterface,
      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
      DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
      DeclareOpInterfaceMethods<TilingInterface,
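
Note (editorial, not part of the diff): deriving from IndexingMapOpInterface is what lets the padding transform translate an iteration-space padding dimension into per-operand pad amounts. Using the indexing maps from the tests below, padding dimension 3 (k2) only affects operands whose maps mention k2:

// Iteration space: (batch, m, k1, k2, n); padding_dimensions = [3] selects k2.
// #mapK = (batch, m, k1, k2, n) -> (batch, k2, k1)  => key:   pad operand dim 1
// #mapV = (batch, m, k1, k2, n) -> (batch, k2, n)   => value: pad operand dim 1
// #mapQ = (batch, m, k1, k2, n) -> (batch, m, k1)   => query: unaffected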
Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
// RUN: iree-opt --iree-transform-dialect-interpreter --split-input-file --verify-diagnostics -canonicalize -cse %s | FileCheck %s

#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>

// CHECK-LABEL: online_attention
func.func @online_attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x?x64xf32>, %value: tensor<192x?x64xf32>) -> tensor<192x1024x64xf32> {
  %scale = arith.constant 1.0 : f32

  %output_empty = tensor.empty() : tensor<192x1024x64xf32>
  %row_red_empty = tensor.empty() : tensor<192x1024xf32>

  %sum_ident = arith.constant 0.000000e+00 : f32
  %max_ident = arith.constant -3.40282347E+38 : f32

  %output_fill = linalg.fill ins(%sum_ident : f32) outs(%output_empty : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
  %acc_fill = linalg.fill ins(%max_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
  %sum_fill = linalg.fill ins(%sum_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>

  // CHECK: iree_linalg_ext.online_attention ins(%{{.*}} : tensor<192x1024x64xf32>, tensor<192x128x64xf32>, tensor<192x128x64xf32>, f32)
  %out:3 = iree_linalg_ext.online_attention
      { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
      ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x?x64xf32>, tensor<192x?x64xf32>, f32)
      outs(%output_fill, %acc_fill, %sum_fill : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
  ^bb0(%score: f32):
    iree_linalg_ext.yield %score : f32
  } -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>

  return %out#0 : tensor<192x1024x64xf32>
}
module attributes { transform.with_named_sequence } {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %online_attention = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op

    // Tile then pad should give us a static shape.
    // TODO: this currently does not work, FIXME.
    %tiled_online_attention, %loops_l1 = transform.structured.tile_using_for %online_attention tile_sizes [0, 0, 0, 128]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    %padded, %pad = transform.structured.pad_tiling_interface %tiled_online_attention to padding_sizes [128] pad_to_multiple_of {
      padding_dimensions = [3],
      padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
    transform.affine.simplify_min_max_affine_ops %func : !transform.any_op

    transform.yield
  }
}
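
For context (an editorial sketch, not part of the test): tiling the dynamic k2 dimension by 128 produces a data-dependent slice size expressed as an affine.min, and this is exactly what simplify_min_max_affine_ops must fold to the constant 128 before the slice types can become static. SSA names here are hypothetical:

// Slice size produced by tiling a dynamic k2 by 128 at induction variable %iv:
%tile_k2 = affine.min affine_map<(d0)[s0] -> (128, -d0 + s0)>(%iv)[%k2]
// Once k2 is known to be a multiple of 128, min/max simplification folds
// %tile_k2 to the constant 128, and canonicalization then propagates static
// tensor<192x128x64xf32> slice types.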

// -----

#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>

// CHECK-LABEL: online_attention
func.func @online_attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x?x64xf32>, %value: tensor<192x?x64xf32>) -> tensor<192x1024x64xf32> {
  %scale = arith.constant 1.0 : f32

  %output_empty = tensor.empty() : tensor<192x1024x64xf32>
  %row_red_empty = tensor.empty() : tensor<192x1024xf32>

  %sum_ident = arith.constant 0.000000e+00 : f32
  %max_ident = arith.constant -3.40282347E+38 : f32

  %output_fill = linalg.fill ins(%sum_ident : f32) outs(%output_empty : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
  %acc_fill = linalg.fill ins(%max_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
  %sum_fill = linalg.fill ins(%sum_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>

  // CHECK: iree_linalg_ext.online_attention ins(%{{.*}} : tensor<192x1024x64xf32>, tensor<192x128x64xf32>, tensor<192x128x64xf32>, f32)
  %out:3 = iree_linalg_ext.online_attention
      { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
      ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x?x64xf32>, tensor<192x?x64xf32>, f32)
      outs(%output_fill, %acc_fill, %sum_fill : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
  ^bb0(%score: f32):
    iree_linalg_ext.yield %score : f32
  } -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>

  return %out#0 : tensor<192x1024x64xf32>
}

module attributes { transform.with_named_sequence } {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %online_attention = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op

    // Pad then tile should give us a static shape.
    %padded, %pad = transform.structured.pad_tiling_interface %online_attention to padding_sizes [128] pad_to_multiple_of {
      padding_dimensions = [3],
      padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    %tiled_online_attention, %loops_l1 = transform.structured.tile_using_for %padded tile_sizes [0, 0, 0, 128]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
    transform.affine.simplify_min_max_affine_ops %func : !transform.any_op

    transform.yield
  }
}
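
Why pad-then-tile yields static shapes, as an editorial sketch (hypothetical SSA names; not the actual pass output): once k2 has been padded to a multiple of 128, the tile size divides the loop bound evenly and every extracted slice has a static type.

%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
// k2 was padded up to a multiple of 128, so 128 divides the bound evenly.
%ub = affine.apply affine_map<()[s0] -> ((s0 ceildiv 128) * 128)>()[%k2]
scf.for %iv = %c0 to %ub step %c128 {
  // Every tile of the padded key has the full, static 128 extent; no
  // affine.min boundary case remains.
  %key_tile = tensor.extract_slice %key_padded[0, %iv, 0] [192, 128, 64] [1, 1, 1]
      : tensor<192x?x64xf32> to tensor<192x128x64xf32>
}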

third_party/llvm-project (submodule update)
