
Commit 1110ac1

Groverkss and qedawkins authored
[VectorDistribute] Implement layout analysis for transfer_gather (#21164)
Also improves the implementation of mask layout inference for transfer_gather operations.

Co-authored-by: Quinn Dawkins <[email protected]>
1 parent 278e249 commit 1110ac1

4 files changed: +219 −36 lines

compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp

Lines changed: 61 additions & 36 deletions
@@ -845,30 +845,15 @@ static void enforceLayoutToTransferReadOp(
     return;
   }
 
-  // Build a transposed layout.
-  SmallVector<unsigned> permutation;
-  AffineMap permMap = read.getPermutationMap();
-  bool isSupportedPerm =
-      permMap.isPermutationOfMinorIdentityWithBroadcasting(permutation);
-  VectorLayoutInterface layout = result->getLayout();
-  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
-  if (isSupportedPerm) {
-    layout = layout.permute(transposePerm);
-    AffineMap toMinorIdentity =
-        AffineMap::getPermutationMap(permutation, permMap.getContext());
-    AffineMap orderedMap = toMinorIdentity.compose(permMap);
-    SmallVector<bool> droppedDims(layout.getRank(), false);
-    for (unsigned bdim : orderedMap.getBroadcastDims()) {
-      droppedDims[bdim] = true;
-    }
-    layout = layout.project(droppedDims);
+  DistributionLayout *maskLattice = operandLattices[0];
 
-    for (auto [index, operandLattice] : llvm::enumerate(operandLattices)) {
-      ChangeResult changed = operandLattice->resolveWithPossibleConflict(
-          layout, getOpOperand(read, index));
-      update(operandLattice, changed);
-    }
-  }
+  VectorLayoutInterface layout = result->getLayout();
+  AffineMap maskMap =
+      inversePermutation(compressUnusedDims(read.getPermutationMap()));
+  VectorLayoutInterface maskLayout = layout.apply(maskMap);
+  ChangeResult changed = maskLattice->resolveWithPossibleConflict(
+      maskLayout, getOpOperand(read, 0));
+  update(maskLattice, changed);
 }
 
 static void enforceLayoutToTransferWriteOp(
@@ -890,22 +875,56 @@ static void enforceLayoutToTransferWriteOp(
     return;
   }
 
-  // Build a transposed layout.
-  SmallVector<unsigned> permutation;
-  AffineMap permMap = write.getPermutationMap();
-  bool isSupportedPerm =
-      permMap.isPermutationOfMinorIdentityWithBroadcasting(permutation);
+  DistributionLayout *maskLattice = operandLattices[1];
+
   VectorLayoutInterface layout = writeOperand->getLayout();
-  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
-  if (isSupportedPerm) {
-    layout = layout.permute(transposePerm);
+  AffineMap maskMap =
+      inversePermutation(compressUnusedDims(write.getPermutationMap()));
+  VectorLayoutInterface maskLayout = layout.apply(maskMap);
+  ChangeResult changed = maskLattice->resolveWithPossibleConflict(
+      maskLayout, getOpOperand(write, 1));
+  update(maskLattice, changed);
+}
+
+static void enforceLayoutToTransferGatherOp(
+    TransferGatherOp gather, ArrayRef<DistributionLayout *> operandLattices,
+    ArrayRef<const DistributionLayout *> resultLattices,
+    std::function<void(DistributionLayout *, ChangeResult)> update) {
+  if (resultLattices.empty()) {
+    return;
+  }
+
+  // transfer_gather has only one vector result.
+  const DistributionLayout *result = resultLattices[0];
+  // Cannot enforce layout if result is uninitialized.
+  if (result->isUninitialized()) {
+    return;
+  }
+  VectorLayoutInterface layout = result->getLayout();
+
+  ArrayRef<DistributionLayout *> indexVecLattices =
+      operandLattices.slice(0, gather.getIndexVecs().size());
+  AffineMap sourceMap =
+      inverseAndBroadcastProjectedPermutation(gather.getPermutationMap());
+  VectorLayoutInterface sourceLayout = layout.apply(sourceMap);
+  for (auto [i, lattice, operand] :
+       llvm::enumerate(indexVecLattices, gather.getIndexVecsMutable())) {
+    AffineMap indexVecMap = gather.getIndexedMapsArray()[i];
+    VectorLayoutInterface indexVecLayout = sourceLayout.apply(indexVecMap);
+    ChangeResult changed =
+        lattice->resolveWithPossibleConflict(indexVecLayout, operand);
+    update(lattice, changed);
   }
 
-  for (auto [index, operandLattice] :
-       llvm::enumerate(operandLattices.slice(1))) {
-    ChangeResult changed = operandLattice->resolveWithPossibleConflict(
-        layout, getOpOperand(write, index + 1));
-    update(operandLattice, changed);
+  if (gather.getMask()) {
+    DistributionLayout *maskLattice =
+        operandLattices[gather.getIndexVecs().size()];
+    AffineMap maskMap =
+        inversePermutation(compressUnusedDims(gather.getPermutationMap()));
+    VectorLayoutInterface maskLayout = layout.apply(maskMap);
+    ChangeResult changed = maskLattice->resolveWithPossibleConflict(
+        maskLayout, gather.getMaskMutable()[0]);
+    update(maskLattice, changed);
   }
 }
 
@@ -964,6 +983,12 @@ void enforcementTransferFunction(
                                   update);
     return;
   }
+
+  if (auto gather = dyn_cast<TransferGatherOp>(op)) {
+    enforceLayoutToTransferGatherOp(gather, operandLattices, resultLattices,
+                                    update);
+    return;
+  }
 }
 
 /// ==========================================================================

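Instead of the old permute/project gymnastics, the mask operand's layout is now derived in one step as layout.apply(inversePermutation(compressUnusedDims(permutationMap))), the same map recipe used to compute a transfer op's mask shape from its permutation map. Below is a minimal standalone sketch of that derivation (plain C++ with no MLIR dependencies; MapResult, compressUnusedDims, inversePermutation, and applyToTiles here are simplified stand-ins that only model the element_tile component, not the real MLIR/IREE APIs):

// Standalone illustration only: models element_tile propagation through the
// mask map; these helpers are simplified stand-ins, not the MLIR/IREE APIs.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// One result of a transfer permutation map: a dim position, or nullopt for a
// broadcast (the constant-0 result in the real AffineMap).
using MapResult = std::optional<int64_t>;

// Drop dims that never appear in the results and renumber the rest in order.
static std::vector<MapResult> compressUnusedDims(int64_t numDims,
                                                 std::vector<MapResult> results) {
  std::vector<int64_t> newPos(numDims, -1);
  int64_t next = 0;
  for (int64_t d = 0; d < numDims; ++d)
    for (const MapResult &r : results)
      if (r && *r == d && newPos[d] == -1)
        newPos[d] = next++;
  for (MapResult &r : results)
    if (r)
      r = newPos[*r];
  return results;
}

// Invert the (possibly broadcasting) permutation: entry i of the inverse is
// the vector dimension that feeds mask dimension i.
static std::vector<int64_t> inversePermutation(const std::vector<MapResult> &results) {
  int64_t numDims = 0;
  for (const MapResult &r : results)
    if (r)
      numDims = std::max(numDims, *r + 1);
  std::vector<int64_t> inv(numDims, 0);
  for (int64_t i = 0, e = results.size(); i < e; ++i)
    if (results[i])
      inv[*results[i]] = i;
  return inv;
}

// The element_tile part of layout.apply(maskMap): mask dim i inherits the
// tile of the vector dim it maps to.
static std::vector<int64_t> applyToTiles(const std::vector<int64_t> &vectorTile,
                                         const std::vector<int64_t> &maskMap) {
  std::vector<int64_t> maskTile;
  for (int64_t pos : maskMap)
    maskTile.push_back(vectorTile[pos]);
  return maskTile;
}

int main() {
  // transfer_read_mask test added below: permutation_map = (d0, d1) -> (d1, 0)
  // and a result layout with element_tile = [16, 32].
  std::vector<MapResult> permMap = {int64_t(1), std::nullopt};
  std::vector<int64_t> resultTile = {16, 32};

  auto compressed = compressUnusedDims(/*numDims=*/2, permMap); // (d0) -> (d0, 0)
  auto maskMap = inversePermutation(compressed);                // mask dim 0 <- vector dim 0
  auto maskTile = applyToTiles(resultTile, maskMap);

  for (int64_t t : maskTile)
    std::cout << t << ' '; // prints "16", matching element_tile = [16]
  std::cout << '\n';
}

Running this for the transfer_read_mask test added below reproduces its expected-remark: the rank-2 result layout [16, 32] collapses to the rank-1 mask layout [16].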
compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir

Lines changed: 122 additions & 0 deletions
@@ -37,6 +37,71 @@ builtin.module attributes { transform.with_named_sequence } {
 
 // -----
 
+#layout = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1],
+  batch_tile = [1, 1],
+  outer_tile = [1, 1],
+  thread_tile = [1, 1],
+  element_tile = [16, 32],
+
+  subgroup_strides = [0, 0],
+  thread_strides = [0, 0]
+>
+
+builtin.module attributes { transform.with_named_sequence } {
+  func.func @transfer_read_mask(%arr: memref<16x32xf16>, %a: vector<16x32xf16>, %b: vector<16x32xf16>, %cond: i1) -> vector<16x32xf16> {
+    %c0 = arith.constant 0 : index
+    %c12 = arith.constant 12 : index
+    %mask = vector.create_mask %c12 : vector<16xi1>
+    // expected-remark @above {{element_tile = [16]}}
+    %cst_0 = arith.constant 0.0 : f16
+    %root = vector.transfer_read %arr[%c0, %c0], %cst_0, %mask {permutation_map = affine_map<(d0, d1) -> (d1, 0)>, in_bounds = [true, true]} : memref<16x32xf16>, vector<16x32xf16>
+    // expected-remark @above {{element_tile = [16, 32]}}
+    %rootl = iree_vector_ext.to_layout %root to layout(#layout) : vector<16x32xf16>
+    func.return %rootl : vector<16x32xf16>
+  }
+
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#layout = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1, 1],
+  batch_tile = [1, 1, 1],
+  outer_tile = [1, 1, 1],
+  thread_tile = [1, 1, 1],
+  element_tile = [16, 8, 4],
+
+  subgroup_strides = [0, 0, 0],
+  thread_strides = [0, 0, 0]
+>
+
+builtin.module attributes { transform.with_named_sequence } {
+  func.func @transfer_write_mask(%arr: memref<32x32x32x32xf16>, %d: vector<16x8x4xf16>) {
+    %c0 = arith.constant 0 : index
+    %c12 = arith.constant 12 : index
+    %mask = vector.create_mask %c12, %c12, %c12 : vector<8x16x4xi1>
+    // expected-remark @above {{element_tile = [8, 16, 4]}}
+    %dl = iree_vector_ext.to_layout %d to layout(#layout) : vector<16x8x4xf16>
+    vector.transfer_write %dl, %arr[%c0, %c0, %c0, %c0], %mask {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>, in_bounds = [true, true, true]} : vector<16x8x4xf16>, memref<32x32x32x32xf16>
+    return
+  }
+
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+
 #layout = #iree_vector_ext.nested_layout<
   subgroup_tile = [1, 1],
   batch_tile = [1, 1],
@@ -759,3 +824,60 @@ builtin.module attributes { transform.with_named_sequence } {
     transform.yield
   }
 }
+
+// -----
+
+#layout_1d = #iree_vector_ext.nested_layout<
+  subgroup_tile = [4],
+  batch_tile = [4],
+  outer_tile = [1],
+  thread_tile = [1],
+  element_tile = [1],
+
+  subgroup_strides = [1],
+  thread_strides = [0]
+>
+
+#layout = #iree_vector_ext.nested_layout<
+  subgroup_tile = [4, 1],
+  batch_tile = [4, 1],
+  outer_tile = [1, 1],
+  thread_tile = [1, 1],
+  element_tile = [1, 8],
+
+  subgroup_strides = [1, 0],
+  thread_strides = [0, 0]
+>
+
+builtin.module attributes { transform.with_named_sequence } {
+  func.func @paged_transfer_gather(%indices: vector<16xindex>,
+    %source: memref<4096x512x8xf16>) -> vector<16x8xf16> {
+
+    %cst0 = arith.constant 0.0 : f16
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant dense<1> : vector<16xindex>
+    // expected-remark @above {{element_tile = [1]}}
+    %c7 = arith.constant 7 : index
+    %dim = memref.dim %source, %c0 : memref<4096x512x8xf16>
+    %mask = vector.create_mask %c7, %c7 : vector<16x8xi1>
+    // expected-remark @above {{element_tile = [1, 8]}}
+    %indices1 = arith.addi %indices, %c1 : vector<16xindex>
+    // expected-remark @above {{element_tile = [1]}}
+    %out = iree_vector_ext.transfer_gather %source[%c0, %c0, %c0]
+    // expected-remark @above {{element_tile = [1, 8]}}
+      [None, %indices1: vector<16xindex>, None], %cst0, %mask { indexed_maps = [
+        affine_map<(d0, d1, d2) -> (d1)>],
+        permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>,
+        in_bounds = [true, true] }
+      : memref<4096x512x8xf16>, vector<16x8xf16>
+    %l_out = iree_vector_ext.to_layout %out to layout(#layout) : vector<16x8xf16>
+
+    return %l_out : vector<16x8xf16>
+  }
+
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
+    transform.yield
+  }
+}

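For reference, tracing the new enforcement logic through the @paged_transfer_gather test above gives the following (hand-derived from the VectorLayoutAnalysis.cpp changes; the intermediate names match the variables in that code):

  permutation_map = (d0, d1, d2) -> (d1, d2)

  sourceMap   = inverseAndBroadcastProjectedPermutation(permutation_map)
              = (d0, d1) -> (0, d0, d1)
  sourceLayout = result layout applied through sourceMap
              -> element_tile = [1, 1, 8]   (dim 0 is a broadcast: unit tile)

  indexVecMap for %indices1 = (d0, d1, d2) -> (d1)
  index vec layout = sourceLayout applied through indexVecMap
              -> element_tile = [1]

  maskMap     = inversePermutation(compressUnusedDims(permutation_map))
              = (d0, d1) -> (d0, d1)
  mask layout = result layout applied through maskMap
              -> element_tile = [1, 8]

which is exactly what the expected-remark lines on %indices1 and %mask check.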
compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp

Lines changed: 30 additions & 0 deletions
@@ -68,6 +68,36 @@ NestedLayoutAttr::project(ArrayRef<bool> droppedDims) const {
                                subgroupStrides, threadStrides);
 }
 
+VectorLayoutInterface NestedLayoutAttr::apply(AffineMap map) const {
+  assert(map.getNumDims() == getRank() &&
+         "map domain size must match layout rank");
+
+  SmallVector<int64_t> subgroupCount(map.getNumResults(), 1);
+  SmallVector<int64_t> batchCount(map.getNumResults(), 1);
+  SmallVector<int64_t> outerCount(map.getNumResults(), 1);
+  SmallVector<int64_t> threadCount(map.getNumResults(), 1);
+  SmallVector<int64_t> elementCount(map.getNumResults(), 1);
+  SmallVector<int64_t> subgroupStrides(map.getNumResults(), 0);
+  SmallVector<int64_t> threadStrides(map.getNumResults(), 0);
+
+  for (auto [idx, expr] : llvm::enumerate(map.getResults())) {
+    if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
+      int64_t pos = dim.getPosition();
+      subgroupCount[idx] = getSubgroupTile()[pos];
+      batchCount[idx] = getBatchTile()[pos];
+      outerCount[idx] = getOuterTile()[pos];
+      threadCount[idx] = getThreadTile()[pos];
+      elementCount[idx] = getElementTile()[pos];
+      subgroupStrides[idx] = getSubgroupStrides()[pos];
+      threadStrides[idx] = getThreadStrides()[pos];
+    }
+  }
+
+  return NestedLayoutAttr::get(getContext(), subgroupCount, batchCount,
+                               outerCount, threadCount, elementCount,
+                               subgroupStrides, threadStrides);
+}
+
 VectorLayoutInterface
 NestedLayoutAttr::permute(ArrayRef<int64_t> permutation) const {
   SmallVector<int64_t> invPerm = invertPermutationVector(permutation);

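NestedLayoutAttr::apply simply re-indexes every tile and stride array by the map's results: a dim result copies that dim's sizes and strides, while any other result (for example the constant 0 produced by broadcasting maps) keeps unit tiles and zero strides. A standalone model of that behavior is sketched below (plain C++, no MLIR; SimpleNestedLayout, applyMap, and the optional-based map encoding are hypothetical stand-ins that only carry the tile/stride fields of the real attribute):

// Illustration only: a simplified model of NestedLayoutAttr::apply, not the
// IREE implementation (which operates on the real AffineMap/attribute types).
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

struct SimpleNestedLayout {
  std::vector<int64_t> subgroupTile, batchTile, outerTile, threadTile,
      elementTile, subgroupStrides, threadStrides;
};

// A map result is either a dim position or nullopt (a constant/broadcast
// expression). Dim results copy the source layout's entry for that dim;
// everything else keeps the defaults: unit tiles, zero strides.
static SimpleNestedLayout applyMap(const SimpleNestedLayout &layout,
                                   const std::vector<std::optional<int64_t>> &results) {
  std::size_t n = results.size();
  SimpleNestedLayout out{std::vector<int64_t>(n, 1), std::vector<int64_t>(n, 1),
                         std::vector<int64_t>(n, 1), std::vector<int64_t>(n, 1),
                         std::vector<int64_t>(n, 1), std::vector<int64_t>(n, 0),
                         std::vector<int64_t>(n, 0)};
  for (std::size_t idx = 0; idx < n; ++idx) {
    if (!results[idx])
      continue; // broadcast: keep unit tile / zero stride
    std::size_t pos = static_cast<std::size_t>(*results[idx]);
    out.subgroupTile[idx] = layout.subgroupTile[pos];
    out.batchTile[idx] = layout.batchTile[pos];
    out.outerTile[idx] = layout.outerTile[pos];
    out.threadTile[idx] = layout.threadTile[pos];
    out.elementTile[idx] = layout.elementTile[pos];
    out.subgroupStrides[idx] = layout.subgroupStrides[pos];
    out.threadStrides[idx] = layout.threadStrides[pos];
  }
  return out;
}

int main() {
  // #layout from the paged_transfer_gather test: subgroup_tile [4, 1],
  // batch_tile [4, 1], element_tile [1, 8], subgroup_strides [1, 0].
  SimpleNestedLayout result{{4, 1}, {4, 1}, {1, 1}, {1, 1}, {1, 8}, {1, 0}, {0, 0}};

  // sourceMap = inverseAndBroadcastProjectedPermutation((d0, d1, d2) -> (d1, d2))
  //           = (d0, d1) -> (0, d0, d1): dim 0 of the source is a broadcast.
  SimpleNestedLayout source =
      applyMap(result, {std::nullopt, int64_t(0), int64_t(1)});

  // indexed map for %indices1 is (d0, d1, d2) -> (d1): pick source dim 1.
  SimpleNestedLayout indexVec = applyMap(source, {int64_t(1)});

  std::cout << "source element_tile    = [" << source.elementTile[0] << ", "
            << source.elementTile[1] << ", " << source.elementTile[2] << "]\n";
  std::cout << "index-vec element_tile = [" << indexVec.elementTile[0] << "]\n";
  // Prints [1, 1, 8] and [1], consistent with the expected-remarks above.
}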
compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtInterfaces.td

Lines changed: 6 additions & 0 deletions
@@ -35,6 +35,12 @@ def VectorLayoutInterface : AttrInterface<"VectorLayoutInterface"> {
       /*methodName=*/"project",
       /*args=*/(ins "::llvm::ArrayRef<bool>":$droppedDims)
     >,
+    InterfaceMethod<
+      /*description=*/"Apply the given AffineMap to the layout.",
+      /*retTy=*/"VectorLayoutInterface",
+      /*methodName=*/"apply",
+      /*args=*/(ins "::mlir::AffineMap":$map)
+    >,
     InterfaceMethod<
       /*description=*/"Get the expected undistributed shape for the given vector type.",
       /*retTy=*/"SmallVector<int64_t>",
