Skip to content

Commit 2796336

Browse files
authored
[mlir][vector] Improve vector.gather description (#153278)
Improve/elaborate example describing semantics
1 parent 4936fc5 commit 2796336

File tree

1 file changed

+27
-14
lines changed

1 file changed

+27
-14
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2058,39 +2058,52 @@ def Vector_GatherOp :
20582058
Results<(outs AnyVectorOfNonZeroRank:$result)> {
20592059

20602060
let summary = [{
2061-
gathers elements from memory or ranked tensor into a vector as defined by an
2062-
index vector and a mask vector
2061+
Gathers elements from memory or ranked tensor into a vector as defined by an
2062+
index vector and a mask vector.
20632063
}];
20642064

20652065
let description = [{
20662066
The gather operation returns an n-D vector whose elements are either loaded
2067-
from memory or ranked tensor, or taken from a pass-through vector, depending
2067+
from a k-D memref or tensor, or taken from an n-D pass-through vector, depending
20682068
on the values of an n-D mask vector.
2069-
If a mask bit is set, the corresponding result element is defined by the base
2070-
with indices and the n-D index vector (each index is a 1-D offset on the base).
2071-
Otherwise, the corresponding element is taken from the n-D pass-through vector.
2072-
Informally the semantics are:
2069+
2070+
If a mask bit is set, the corresponding result element is taken from `base`
2071+
at an index defined by k indices and n-D `index_vec`. Otherwise, the element
2072+
is taken from the pass-through vector. As an example, suppose that `base` is
2073+
3-D and the result is 2-D:
2074+
2075+
```mlir
2076+
func.func @gather_3D_to_2D(
2077+
%base: memref<?x10x?xf32>, %i0: index, %i1: index, %i2: index,
2078+
%index_vec: vector<2x3xi32>, %mask: vector<2x3xi1>,
2079+
%fall_thru: vector<2x3xf32>) -> vector<2x3xf32> {
2080+
%result = vector.gather %base[%i0, %i1, %i2]
2081+
[%index_vec], %mask, %fall_thru : [...]
2082+
return %result : vector<2x3xf32>
2083+
}
20732084
```
2074-
result[0] := if mask[0] then base[index[0]] else pass_thru[0]
2075-
result[1] := if mask[1] then base[index[1]] else pass_thru[1]
2076-
etc.
2085+
2086+
The indexing semantics are then,
2087+
2088+
```
2089+
result[i,j] := if mask[i,j] then base[i0, i1, i2 + index_vec[i,j]]
2090+
else pass_thru[i,j]
20772091
```
2092+
The index into `base` only varies in the innermost ((k-1)-th) dimension.
20782093

20792094
If a mask bit is set and the corresponding index is out-of-bounds for the
20802095
given base, the behavior is undefined. If a mask bit is not set, the value
20812096
comes from the pass-through vector regardless of the index, and the index is
20822097
allowed to be out-of-bounds.
20832098

2084-
The gather operation can be used directly where applicable, or can be used
2085-
during progressively lowering to bring other memory operations closer to
2086-
hardware ISA support for a gather.
2087-
20882099
Examples:
20892100

20902101
```mlir
2102+
// 1-D memref gathered to 2-D vector.
20912103
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru
20922104
: memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
20932105

2106+
// 2-D memref gathered to 1-D vector.
20942107
%1 = vector.gather %base[%i, %j][%v], %mask, %pass_thru
20952108
: memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
20962109
```

0 commit comments

Comments
 (0)