Skip to content

Commit 8e9f389

Browse files
author
Ivan Garcia
committed
Addressing feedback from Matthias Springer.
1 parent 2dac9bb commit 8e9f389

File tree

1 file changed

+131
-102
lines changed

1 file changed

+131
-102
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 131 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,54 +1332,33 @@ def MemRef_ReinterpretCastOp
13321332
Modify offset, sizes and strides of an unranked/ranked memref.
13331333

13341334
Example 1:
1335-
```mlir
1336-
memref.reinterpret_cast %ranked to
1337-
offset: [0],
1338-
sizes: [%size0, 10],
1339-
strides: [1, %stride1]
1340-
: memref<?x?xf32> to memref<?x10xf32, strided<[1, ?], offset: 0>>
1341-
1342-
memref.reinterpret_cast %unranked to
1343-
offset: [%offset],
1344-
sizes: [%size0, %size1],
1345-
strides: [%stride0, %stride1]
1346-
: memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
1347-
```
13481335

1349-
This operation creates a new memref descriptor using the base of the
1350-
source and applying the input arguments to the other metadata.
1351-
In other words:
1352-
```mlir
1353-
%dst = memref.reinterpret_cast %src to
1354-
offset: [%offset],
1355-
sizes: [%sizes],
1356-
strides: [%strides]
1357-
```
1358-
means that `%dst`'s descriptor will be:
1359-
```mlir
1360-
%dst.base = %src.base
1361-
%dst.aligned = %src.aligned
1362-
%dst.offset = %offset
1363-
%dst.sizes = %sizes
1364-
%dst.strides = %strides
1365-
```
1366-
1367-
Example 2:
1368-
1369-
Consecutive `reinterpret_cast` operations on memref's with static dimensions.
1336+
Consecutive `reinterpret_cast` operations on memref's with static
1337+
dimensions.
13701338

13711339
We distinguish between *underlying memory* — the sequence of elements as
1372-
they appear in the contiguous memory of the memref — and the *view*, which refers to
1373-
the underlying memory interpreted according to specified offsets, sizes, and strides.
1340+
they appear in the contiguous memory of the memref — and the
1341+
*strided memref*, which refers to the underlying memory interpreted
1342+
according to specified offsets, sizes, and strides.
13741343

13751344
```mlir
1376-
%result1 = memref.reinterpret_cast %arg0 to offset: [9], sizes: [4, 4], strides: [16, 2] : memref<8x8xf32, strided<[8, 1], offset: 0>> to memref<4x4xf32, strided<[16, 2], offset: 9>>
1377-
1378-
%result2 = memref.reinterpret_cast %result1 to offset: [0], sizes: [2, 2], strides: [4, 2] : memref<4x4xf32, strided<[16, 2], offset: 9>> to memref<2x2xf32, strided<[4, 2], offset: 0>>
1345+
%result1 = memref.reinterpret_cast %arg0 to
1346+
offset: [9],
1347+
sizes: [4, 4],
1348+
strides: [16, 2]
1349+
: memref<8x8xf32, strided<[8, 1], offset: 0>> to
1350+
memref<4x4xf32, strided<[16, 2], offset: 9>>
1351+
1352+
%result2 = memref.reinterpret_cast %result1 to
1353+
offset: [0],
1354+
sizes: [2, 2],
1355+
strides: [4, 2]
1356+
: memref<4x4xf32, strided<[16, 2], offset: 9>> to
1357+
memref<2x2xf32, strided<[4, 2], offset: 0>>
13791358
```
13801359

1381-
The input memref `%arg0` has the following view. The underlying memory consists
1382-
of a linear sequence of integers from 1 to 64:
1360+
The underlying memory of `%arg0` consists of a linear sequence of integers
1361+
from 1 to 64. Its memref has the following 8x8 elements:
13831362

13841363
```mlir
13851364
[[1, 2, 3, 4, 5, 6, 7, 8],
@@ -1392,7 +1371,8 @@ def MemRef_ReinterpretCastOp
13921371
[57, 58, 59, 60, 61, 62, 63, 64]]
13931372
```
13941373

1395-
Following the first `reinterpret_cast`, the view of `%result1` is:
1374+
Following the first `reinterpret_cast`, the strided memref elements
1375+
of `%result1` are:
13961376

13971377
```mlir
13981378
[[10, 12, 14, 16],
@@ -1401,20 +1381,60 @@ def MemRef_ReinterpretCastOp
14011381
[58, 60, 62, 64]]
14021382
```
14031383

1404-
Note: The offset and strides are relative to the underlying memory of `%arg0`.
1384+
Note: The offset and strides are relative to the underlying memory of
1385+
`%arg0`.
14051386

1406-
The second `reinterpret_cast` results in the following view for `%result2`:
1387+
The second `reinterpret_cast` results in the following strided memref
1388+
for `%result2`:
14071389

14081390
```mlir
14091391
[[1, 3],
14101392
[5, 7]]
14111393
```
14121394

1413-
It is important to observe that the offset and stride are relative to the base underlying
1414-
memory of the memref, starting at 1, not at 10 as seen in the output of `%result1`.
1415-
This behavior contrasts with the `subview` operator, where values are relative to the view of
1416-
the memref (refer to `subview` examples). Consequently, the second `reinterpret_cast` behaves
1417-
as if `%arg0` were passed directly as its argument.
1395+
Notice that it does not matter if you use %result1 or %arg0 as a source
1396+
for the second `reinterpret_cast` operation. Only the underlying memory
1397+
pointers will be reused.
1398+
1399+
The offset and stride are relative to the base underlying memory of the
1400+
memref, starting at 1, not at 10 as seen in the output of `%result1`.
1401+
This behavior contrasts with the `subview` operator, where values are
1402+
relative to the strided memref (refer to `subview` examples).
1403+
Consequently, the second `reinterpret_cast` behaves as if `%arg0` were
1404+
passed directly as its argument.
1405+
1406+
Example 2:
1407+
```mlir
1408+
memref.reinterpret_cast %ranked to
1409+
offset: [0],
1410+
sizes: [%size0, 10],
1411+
strides: [1, %stride1]
1412+
: memref<?x?xf32> to memref<?x10xf32, strided<[1, ?], offset: 0>>
1413+
1414+
memref.reinterpret_cast %unranked to
1415+
offset: [%offset],
1416+
sizes: [%size0, %size1],
1417+
strides: [%stride0, %stride1]
1418+
: memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
1419+
```
1420+
1421+
This operation creates a new memref descriptor using the base of the
1422+
source and applying the input arguments to the other metadata.
1423+
In other words:
1424+
```mlir
1425+
%dst = memref.reinterpret_cast %src to
1426+
offset: [%offset],
1427+
sizes: [%sizes],
1428+
strides: [%strides]
1429+
```
1430+
means that `%dst`'s descriptor will be:
1431+
```mlir
1432+
%dst.base = %src.base
1433+
%dst.aligned = %src.aligned
1434+
%dst.offset = %offset
1435+
%dst.sizes = %sizes
1436+
%dst.strides = %strides
1437+
```
14181438
}];
14191439

14201440
let arguments = (ins Arg<AnyRankedOrUnrankedMemRef, "", []>:$source,
@@ -1950,6 +1970,64 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
19501970

19511971
Example 1:
19521972

1973+
Consecutive `subview` operations on memref's with static dimensions.
1974+
1975+
We distinguish between *underlying memory* — the sequence of elements as
1976+
they appear in the contiguous memory of the memref — and the
1977+
*strided memref*, which refers to the underlying memory interpreted
1978+
according to specified offsets, sizes, and strides.
1979+
1980+
```mlir
1981+
%result1 = memref.subview %arg0[1, 1][4, 4][2, 2]
1982+
: memref<8x8xf32, strided<[8, 1], offset: 0>> to
1983+
memref<4x4xf32, strided<[16, 2], offset: 9>>
1984+
1985+
%result2 = memref.subview %result1[1, 1][2, 2][2, 2]
1986+
: memref<4x4xf32, strided<[16, 2], offset: 9>> to
1987+
memref<2x2xf32, strided<[32, 4], offset: 27>>
1988+
```
1989+
1990+
The underlying memory of `%arg0` consists of a linear sequence of integers
1991+
from 1 to 64. Its memref has the following 8x8 elements:
1992+
1993+
```mlir
1994+
[[1, 2, 3, 4, 5, 6, 7, 8],
1995+
[9, 10, 11, 12, 13, 14, 15, 16],
1996+
[17, 18, 19, 20, 21, 22, 23, 24],
1997+
[25, 26, 27, 28, 29, 30, 31, 32],
1998+
[33, 34, 35, 36, 37, 38, 39, 40],
1999+
[41, 42, 43, 44, 45, 46, 47, 48],
2000+
[49, 50, 51, 52, 53, 54, 55, 56],
2001+
[57, 58, 59, 60, 61, 62, 63, 64]]
2002+
```
2003+
2004+
Following the first `subview`, the strided memref elements of `%result1`
2005+
are:
2006+
2007+
```mlir
2008+
[[10, 12, 14, 16],
2009+
[26, 28, 30, 32],
2010+
[42, 44, 46, 48],
2011+
[58, 60, 62, 64]]
2012+
```
2013+
2014+
Note: The offset and strides are relative to the strided memref of `%arg0`
2015+
(compare to the corresponding `reinterpret_cast` example).
2016+
2017+
The second `subview` results in the following strided memref for
2018+
`%result2`:
2019+
2020+
```mlir
2021+
[[28, 32],
2022+
[60, 64]]
2023+
```
2024+
2025+
Unlike the `reinterpret_cast`, the values are relative to the strided
2026+
memref of the input (`%result1` in this case) and not its
2027+
underlying memory.
2028+
2029+
Example 2:
2030+
19532031
```mlir
19542032
// Subview of static memref with strided layout at static offsets, sizes
19552033
// and strides.
@@ -1958,7 +2036,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
19582036
memref<8x2xf32, strided<[21, 18], offset: 137>>
19592037
```
19602038

1961-
Example 2:
2039+
Example 3:
19622040

19632041
```mlir
19642042
// Subview of static memref with identity layout at dynamic offsets, sizes
@@ -1967,7 +2045,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
19672045
: memref<64x4xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
19682046
```
19692047

1970-
Example 3:
2048+
Example 4:
19712049

19722050
```mlir
19732051
// Subview of dynamic memref with strided layout at dynamic offsets and
@@ -1977,7 +2055,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
19772055
memref<4x4xf32, strided<[?, ?], offset: ?>>
19782056
```
19792057

1980-
Example 4:
2058+
Example 5:
19812059

19822060
```mlir
19832061
// Rank-reducing subviews.
@@ -1987,62 +2065,13 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
19872065
: memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>>
19882066
```
19892067

1990-
Example 5:
2068+
Example 6:
19912069

19922070
```mlir
19932071
// Identity subview. The subview is the full source memref.
19942072
%1 = memref.subview %0[0, 0, 0] [8, 16, 4] [1, 1, 1]
19952073
: memref<8x16x4xf32> to memref<8x16x4xf32>
19962074
```
1997-
Example 6:
1998-
1999-
Consecutive `subview` operations on memref's with static dimensions.
2000-
2001-
We distinguish between *underlying memory* — the sequence of elements as
2002-
they appear in the contiguous memory of the memref — and the *view*, which refers to
2003-
the underlying memory interpreted according to specified offsets, sizes, and strides.
2004-
2005-
```mlir
2006-
%result1 = memref.subview %arg0[1, 1][4, 4][2, 2] : memref<8x8xf32, strided<[8, 1], offset: 0>> to memref<4x4xf32, strided<[16, 2], offset: 9>>
2007-
2008-
%result2 = memref.subview %result1[1, 1][2, 2][2, 2] : memref<4x4xf32, strided<[16, 2], offset: 9>> to memref<2x2xf32, strided<[32, 4], offset: 27>>
2009-
```
2010-
2011-
The input memref `%arg0` has the following view. The underlying memory for this input
2012-
memref is a linear sequence of integers from 1 to 64:
2013-
2014-
```mlir
2015-
[[1, 2, 3, 4, 5, 6, 7, 8],
2016-
[9, 10, 11, 12, 13, 14, 15, 16],
2017-
[17, 18, 19, 20, 21, 22, 23, 24],
2018-
[25, 26, 27, 28, 29, 30, 31, 32],
2019-
[33, 34, 35, 36, 37, 38, 39, 40],
2020-
[41, 42, 43, 44, 45, 46, 47, 48],
2021-
[49, 50, 51, 52, 53, 54, 55, 56],
2022-
[57, 58, 59, 60, 61, 62, 63, 64]]
2023-
```
2024-
2025-
Following the first `subview`, the view of `%result1` is:
2026-
2027-
```mlir
2028-
[[10, 12, 14, 16],
2029-
[26, 28, 30, 32],
2030-
[42, 44, 46, 48],
2031-
[58, 60, 62, 64]]
2032-
```
2033-
2034-
Note: The offset and strides are relative to the memref view of `%arg0` (compare to the
2035-
corresponding `reinterpret_cast` example).
2036-
2037-
The second `subview` results in the following view for `%result2`:
2038-
2039-
```mlir
2040-
[[28, 32],
2041-
[60, 64]]
2042-
```
2043-
2044-
Unlike the `reinterpret_cast`, the values are relative to the view of the input memref
2045-
(`%result1` in this case) and not its underlying memory.
20462075
}];
20472076

20482077
let arguments = (ins AnyMemRef:$source,

0 commit comments

Comments
 (0)