Commit c50db3b

Adding ptensor.extract_slice (#389)
* adding ExtractSliceOp
* adding CreateOp
* adding InsertSliceOp
* separating dist-business from PTensorType-attributes using DistTensorType; no more dist in PTensorType
* function boundary handling for Dist
* adding and using EasyValue
* restructuring Dist-Ops to largely accept ValueRanges instead of memrefs
* enabling n-dimensional tensors
1 parent 7e4a623 commit c50db3b

File tree

26 files changed

+2050
-680
lines changed


docs/rfcs/20220804-ptensor/README.md

Lines changed: 11 additions & 13 deletions
@@ -26,9 +26,7 @@ Additionally we propose appropriate passes
 3. Converting __intel-sycl.device_region__ to appropriate runtime calls
 
 ### ptensor Type
-Since operations are expected to execute in the same location as its input tensors, it is necessary to carry the tensor-location from the point of its allocation to the point of the operation. For this, we introduce a type which logically extends the `mlir::tensor` type with two boolean attributes:
-* `device`: indicates if it should live on a device
-* `dist`: indicates if it should be distributed
+Since operations are expected to execute in the same location as their input tensors, it is necessary to carry the tensor location from the point of its allocation to the point of the operation. For this, we introduce a type which logically extends `mlir::MemRefType` with a boolean attribute `device`, indicating whether the tensor should live on a device.
 
 The actual device and distributed team can be assigned by the appropriate operands of creation operations (see below).

@@ -37,7 +35,7 @@ The tensors themselves are assumed to eventually lower to memrefs.
 Notice: By default, device and distribution support is disabled, which renders conventional host operations.
 
 ### __PTensor__ Operations
-The initial set of operations matches the requirements of the core of [array-API](https://data-apis.org/array-api/latest/API_specification/index.html). The operations in the PTensor dialect operate on ptensors. To allow operations on standard tensors and memrefs the PTensor dialect provides the operation `from_ranked` to convert MemRefs and RankedTensors to ptensors with default `device` and `team`.
+The initial set of operations matches the requirements of the core of the [array-API](https://data-apis.org/array-api/latest/API_specification/index.html). The operations in the PTensor dialect operate on ptensors. To allow operations on standard memrefs, the PTensor dialect provides the operation `from_ranked`, which converts MemRefs to ptensors with default `device` and `team`.
 
 Notice: some of the operations mutate existing ptensors.

@@ -50,7 +48,7 @@ It constitutes an error if an operation has multiple (input and output) argument
 Similarly, it constitutes an error if an operation has multiple (input and output) arguments of type ptensor and their `team` attribute is not the same on all ptensor arguments.
 
 #### Broadcasting/Ranked Tensors
-PTensor operates on ranked tensors. In rare cases the shape of input tensor(s) needs to be known as well. Unranked tensors are not supported.
+PTensor operates on MemRefs. In rare cases the shape of the input tensor(s) needs to be known as well. Unranked memrefs are not supported.
 
 PTensor operations follow the [broadcasting semantics of the array-API](https://data-apis.org/array-api/latest/API_specification/broadcasting.html).
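As an aside, the array-API broadcasting rule referenced in this hunk can be sketched in a few lines. This is a plain Python illustration of the specification, not IMEX code; the function name `broadcast_shapes` is my own:

```python
from itertools import zip_longest

def broadcast_shapes(a, b):
    """Broadcast two shapes per the array-API rules: shapes are
    right-aligned; each dimension pair must be equal, or one of
    them must be 1 (the 1 is then stretched to the other size)."""
    result = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or x == 1 or y == 1:
            result.append(max(x, y))
        else:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
    return tuple(reversed(result))
```

For example, `broadcast_shapes((8, 1, 6, 1), (7, 1, 5))` yields `(8, 7, 6, 5)`, while `(3,)` against `(4,)` raises an error.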

@@ -76,7 +74,7 @@ The below set of operations accrues from the following rules:
 * `$side = ['lower', 'upper']`
 * `delete(tensor) : (ptensor) -> void`
 * `from_dlpack(obj) : (ptr) -> ptensor.ptensor`
-* `from_ranked(ranked) : (Memref|RankedTensor) -> ptensor.ptensor`
+* `from_ranked(ranked) : (MemRef) -> ptensor.ptensor`
 * Tensor attributes
 * `shape(rhs) : (ptensor.ptensor) -> shape.shape`
 * `rank(rhs) : (ptensor.ptensor) -> int64`
@@ -131,18 +129,18 @@ The below set of operations accrues from the following rules:
 * `test{$rop}(rhs, axis) : (ptensor.ptensor, int) -> ptensor.ptensor`
 * `$rop = ['any', 'all']`
 * Utility functions not part of the array-API
-  * Get the (local) ranked tensor from a ptensor:
-    `extract_rtensor(tensor) : (ptensor.ptensor) -> RankedTensor`
-  * Initialize a ptensor value from a RankedTensor, device, team and handle:
-    `init_ptensor(rtensor, device, team, handle) {onDevice : bool, dist : bool} : (RankedTensor, AnyType, AnyType, AnyType, AnyType -> ptensor.ptensor`
+  * Get the (local) memref from a ptensor:
+    `extract_memref(tensor) : (ptensor.ptensor) -> MemRef`
+  * Initialize a ptensor value from a MemRef, device, team and handle:
+    `init_ptensor(memref, device, team, handle) {onDevice : bool, dist : bool} : (MemRef, AnyType, AnyType, AnyType) -> ptensor.ptensor`
 
 ### __Dist__ Dialect
 The Dist dialect provides operations dealing with tensors which are partitioned and distributed across multiple processes. The operations assume some kind of runtime which handles aspects like communication and partitioning.
 - `register_ptensor(shape) : (tensor<?xi64>) -> (int64)`
 - `unregister_ptensor(dtensor_id) : (i64) -> void`
 - `local_shape(dtensor_id) : (i64) -> (tensor<?xi64>)`
 - `local_offsets(dtensor_id) : (i64) -> (tensor<?xi64>)`
-- `allreduce(team, op, ltensor) : (i64, i64, RankedTensor) -> void`
+- `allreduce(team, op, ltensor) : (i64, i64, MemRef) -> void`
 
 Details will follow in a separate RFC.

@@ -163,9 +161,9 @@ All passes which consume `ptensor`s and -operations comply to compute-follows-da
 
 #### --convert-ptensor-to-linalg
 This pass completely lowers ptensor operations:
-- __Tensor__: `ptensor.ptensor` will be type-converted to a RankedTensor
+- __Tensor__: `ptensor.ptensor` will be type-converted to a MemRef
 - Within the pass each PTensor gets "instantiated" by an `init_ptensor` which also accepts `team`, `handle` and `device`. This allows accessing device and distributed runtime attributes during lowering.
-- function boundaries are currently not handled explicitly. device and dist information will be lost and normal RankedTensors are exchanged.
+- Function boundaries are currently not handled explicitly; device and dist information will be lost and normal MemRefs are exchanged.
 - __Linalg__: The actual functionality will be represented by one or more operations of the Linalg dialect.
 - __intel_sycl__: Appropriate `intel_sycl.device_region` will be put around operations which have inputs of type `ptensor.ptensor` with a non-null `device` attribute.
 - utility dialects like __memref__, __shape__, __affine__, __func__ and __arith__
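The `local_shape`/`local_offsets` queries in the Dist dialect above assume some partitioning policy chosen by the runtime. A minimal Python sketch of one plausible policy (an even block partitioning of the first dimension; the function name `local_partition` and the remainder-handling are my own assumptions, mirroring the `nprocs`/`prank` ops):

```python
def local_partition(nprocs, prank, g_shape):
    """Block-partition the first dimension of a tensor with global
    shape g_shape across nprocs processes; return (l_offsets, l_shape)
    for process prank. The first `rest` ranks absorb the remainder."""
    extent = g_shape[0]
    base, rest = divmod(extent, nprocs)
    # first `rest` ranks get one extra element each
    l_sz = base + (1 if prank < rest else 0)
    l_off = prank * base + min(prank, rest)
    l_offsets = (l_off,) + (0,) * (len(g_shape) - 1)
    l_shape = (l_sz,) + tuple(g_shape[1:])
    return l_offsets, l_shape
```

For instance, splitting a `10x3` tensor across 4 processes gives rank 1 the offsets `(3, 0)` and the local shape `(3, 3)`; only the first dimension is cut, matching the "cut only the first N dimensions" rule of the old ops.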

include/imex/Conversion/Passes.td

Lines changed: 8 additions & 7 deletions
@@ -43,7 +43,7 @@ def ConvertPTensorToLinalg : Pass<"convert-ptensor-to-linalg", "::mlir::ModuleOp
     #### Output IR
 
     - a PTensorType will be lowered to an unrealized_conversion_cast to a tuple of
-      * rtensor: RankedTensor
+      * tensor: RankedTensor
       * device: device where the tensor lives (AnyType, default=None)
       * team: a group of processes among which the tensor is distributed
         (AnyType, default=None)
@@ -59,9 +59,10 @@ def ConvertPTensorToLinalg : Pass<"convert-ptensor-to-linalg", "::mlir::ModuleOp
   let dependentDialects = ["::mlir::linalg::LinalgDialect",
                            "::mlir::AffineDialect",
                            "::mlir::func::FuncDialect",
-                           "::mlir::tensor::TensorDialect",
                            "::mlir::arith::ArithDialect",
-                           "::mlir::shape::ShapeDialect"];
+                           "::mlir::tensor::TensorDialect",
+                           "::mlir::memref::MemRefDialect",
+                           "::mlir::bufferization::BufferizationDialect"];
   let options = [];
 }
 
@@ -78,11 +79,11 @@ def ConvertDistToStandard: Pass<"convert-dist-to-standard", "::mlir::ModuleOp">
     Necessary prototypes of runtime functions will be added.
   }];
   let constructor = "::imex::createConvertDistToStandardPass()";
-  let dependentDialects = ["::mlir::linalg::LinalgDialect",
+  let dependentDialects = ["::imex::ptensor::PTensorDialect",
+                           "::mlir::linalg::LinalgDialect",
                            "::mlir::func::FuncDialect",
-                           "::mlir::tensor::TensorDialect",
-                           "::mlir::arith::ArithDialect",
-                           "::mlir::shape::ShapeDialect"];
+                           "::mlir::memref::MemRefDialect",
+                           "::mlir::arith::ArithDialect"];
   let options = [];
 }

include/imex/Dialect/Dist/IR/DistOps.h

Lines changed: 8 additions & 1 deletion
@@ -18,11 +18,18 @@
 #include <mlir/IR/BuiltinTypes.h>
 #include <mlir/IR/Dialect.h>
 #include <mlir/IR/OpDefinition.h>
+#include <mlir/IR/OpImplementation.h>
 #include <mlir/IR/Types.h>
 #include <mlir/Interfaces/SideEffectInterfaces.h>
 
 namespace imex {
-namespace dist {} // namespace dist
+namespace ptensor {
+class PTensorType;
+} // namespace ptensor
+namespace dist {
+enum INFO : int { GSHAPE, LTENSOR, LOFFSETS, TEAM, INFO_LAST };
+extern ::imex::ptensor::PTensorType getPTensorType(::mlir::Value t);
+} // namespace dist
 } // namespace imex
 
 #include <imex/Dialect/Dist/IR/DistOpsDialect.h.inc>

include/imex/Dialect/Dist/IR/DistOps.td

Lines changed: 195 additions & 31 deletions
@@ -30,12 +30,34 @@ def Dist_Dialect : Dialect {
 
   // A longer description of our dialect.
   let description = [{
-      The dist dialect describes interfaces for interacting with 
+      The dist dialect describes interfaces for interacting with
       a runtime which handles distributed aspects of PTensor operations.
-    }];
+  }];
+
+  let dependentDialects = [
+    "::imex::ptensor::PTensorDialect"
+  ];
 
   // The C++ namespace that the dialect class definition resides in.
   let cppNamespace = "::imex::dist";
+  let useDefaultTypePrinterParser = 1;
+}
+
+// common base classes for types in Dist dialect
+class Dist_Type<string name, string typeMnemonic, list<Trait> traits = []>
+    : TypeDef<Dist_Dialect, name, traits> {
+  let mnemonic = typeMnemonic;
+}
+
+def Dist_Tensor : Dist_Type<"DistTensor", "dtensor">
+{
+  let summary = "A type used to bind distributed information to a PTensor";
+  let description = [{
+      A distributed PTensor needs information like the offset and shape of the local partition.
+      The DistTensor type is used to define operations to carry and extract such information.
+  }];
+  let parameters = (ins "::imex::ptensor::PTensorType":$p_tensor_type);
+  let assemblyFormat = "`<` $p_tensor_type `>`";
 }
 
 // Base class for dialect operations. This operation inherits from the base
@@ -46,48 +68,190 @@ def Dist_Dialect : Dialect {
 class Dist_Op<string mnemonic, list<Trait> traits = []> :
     Op<Dist_Dialect, mnemonic, traits>;
 
-// Add function prototypes used for calling into distributed runtime
 def RuntimePrototypesOp : Dist_Op<"runtime_prototypes"> {
+  let summary = "Add function prototypes used for calling into distributed runtime";
+}
+
+def NProcsOp : Dist_Op<"nprocs", [Pure]> {
+  let summary = "Number of processes for given team";
+  let arguments = (ins AnyType:$team);
+  let results = (outs Index);
+  let builders = [
+    // auto-deduce return type
+    OpBuilder<(ins "::mlir::Value":$team), [{
+      build($_builder, $_state, $_builder.getIndexType(), team);
+    }]>,
+  ];
 }
 
-// Register a ptensor of given shape with a (potentially distributed) runtime.
-// Returns an id to uniquely identify the tensor instance in future interactino with the runtime.
-// The runtime does not own or manage any PTensor memory. When needed by an operation,
-// (local) data needs to be provided.
-def RegisterPTensorOp : Dist_Op<"register_ptensor", []> {
-  // Global shape needed for initial registration. Views are handled by a separate op.
-  let arguments = (ins AnyType: $shape);
+def PRankOp : Dist_Op<"prank", [Pure]> {
+  let summary = "Process rank in team";
+  let arguments = (ins AnyType:$team);
+  let results = (outs Index);
+  let builders = [
+    // auto-deduce return type
+    OpBuilder<(ins "::mlir::Value":$team), [{
+      build($_builder, $_state, $_builder.getIndexType(), team);
+    }]>,
+  ];
+}
 
-  // result is an Integer Id
-  let results = (outs I64);
+def InitDistTensorOp : Dist_Op<"init_dist_tensor", [SameVariadicOperandSize, Pure]> {
+  let summary = "Bind a PTensor to distributed meta information";
+  let description = [{
+      The attached PTensor is the local partition of the distributed PTensor.
+      The distributed meta information about a new PTensor provides
+      - the global shape
+      - the process-local offsets
+      - the distributed team
+  }];
+  let arguments = (ins Variadic<Index>:$g_shape, AnyType:$p_tensor, Variadic<Index>:$l_offsets, AnyType:$team);
+  let results = (outs Dist_Tensor);
+  let builders = [
+    // auto-deduce return type
+    OpBuilder<(ins "::mlir::ValueRange":$g_shape, "::mlir::Value":$p_tensor, "::mlir::ValueRange":$l_offsets, "::mlir::Value":$team), [{
+      build($_builder, $_state,
+            ::imex::dist::DistTensorType::get($_builder.getContext(), p_tensor.getType().dyn_cast<::imex::ptensor::PTensorType>()),
+            g_shape, p_tensor, l_offsets, team);
+    }]>,
+  ];
 }
 
-// Get the offsets (one for each dimension) of the local partition of a distributed PTensor in number of elements.
-// Partitionings can be N-dimensional but must cut only the first N dimensions.
-def LocalOffsetsOp : Dist_Op<"local_offsets", []> {
-  // Id of tensor as returned by RegisterPTensorOp
-  let arguments = (ins I64Attr: $rank, I64: $ptensor);
+def GlobalShapeOfOp : Dist_Op<"global_shape_of", []> {
+  let summary = "Get global shape of distributed tensor.";
+  let arguments = (ins AnyType:$d_tensor);
+  let results = (outs Variadic<Index>:$g_shape);
+  let builders = [
+    // auto-deduce return type from operands
+    OpBuilder<(ins "::mlir::Value":$d_tensor), [{
+      auto rank = d_tensor.getType().dyn_cast<::imex::dist::DistTensorType>().getPTensorType().getRank();
+      auto IndexType = $_builder.getIndexType();
+      ::mlir::SmallVector<::mlir::Type> rt(rank, IndexType);
+      build($_builder, $_state, ::mlir::TypeRange(rt), d_tensor);
+    }]>,
+  ];
+}
 
-  // result is a 1d memref
-  let results = (outs AnyType);
+def LocalOffsetsOfOp : Dist_Op<"local_offsets_of", []> {
+  let summary = "Get local offsets of distributed tensor.";
+  let arguments = (ins AnyType:$d_tensor);
+  let results = (outs Variadic<Index>:$l_offsets);
+  let builders = [
+    // auto-deduce return type from operands
+    OpBuilder<(ins "::mlir::Value":$d_tensor), [{
+      auto rank = d_tensor.getType().dyn_cast<::imex::dist::DistTensorType>().getPTensorType().getRank();
+      auto IndexType = $_builder.getIndexType();
+      ::mlir::SmallVector<::mlir::Type> rt(rank, IndexType);
+      build($_builder, $_state, ::mlir::TypeRange(rt), d_tensor);
+    }]>,
+  ];
 }
 
-// Get the shape (one size for each dimension) of the local partition of a distributed PTensor in number of elements.
-// Partitionings can be N-dimensional but must cut only the first N dimensions.
-def LocalShapeOp : Dist_Op<"local_shape", []> {
-  // Id of tensor as returned by RegisterPTensorOp
-  let arguments = (ins I64Attr: $rank, I64: $ptensor);
+def LocalTensorOfOp : Dist_Op<"local_tensor_of", []> {
+  let summary = "Get local tensor of distributed tensor.";
+  let arguments = (ins AnyType:$d_tensor);
+  let results = (outs AnyType:$l_tensor);
+  let builders = [
+    // auto-deduce return type from operands
+    OpBuilder<(ins "::mlir::Value":$d_tensor), [{
+      auto ttype = d_tensor.getType().dyn_cast<::imex::dist::DistTensorType>();
+      build($_builder, $_state, ttype.getPTensorType(), d_tensor);
+    }]>,
+  ];
+}
 
-  // result is a 1d memref
-  let results = (outs AnyType);
+def TeamOfOp : Dist_Op<"team_of", []> {
+  let summary = "Get team of distributed tensor.";
+  let arguments = (ins AnyType:$d_tensor);
+  let results = (outs AnyType:$team);
+  let builders = [
+    // auto-deduce return type from operands
+    OpBuilder<(ins "::mlir::Value":$d_tensor), [{
+      build($_builder, $_state, $_builder.getIndexType(), d_tensor);
+    }]>,
+  ];
 }
 
-// Inplace allreduce
-def AllReduceOp : Dist_Op<"allreduce", []> {
-  // reduction operation and and local tensor
-  let arguments = (ins AnyAttr: $op, AnyTensor: $tensor);
+def LocalPartitionOp : Dist_Op<"local_partition", [SameVariadicResultSize, Pure]> {
+  let summary = "Compute the shape and offsets of the local partition in number of elements (one for each dimension).";
+  let arguments = (ins Index:$num_procs, Index:$p_rank, Variadic<Index>:$g_shape);
+  let results = (outs Variadic<Index>:$l_offsets, Variadic<Index>:$l_shape);
+  let builders = [
+    // auto-deduce return type
+    OpBuilder<(ins "::mlir::Value":$num_procs, "::mlir::Value":$prank, "::mlir::ValueRange":$gshape), [{
+      auto IndexType = $_builder.getIndexType();
+      ::mlir::SmallVector<::mlir::Type> rt(gshape.size()*2, IndexType);
+      build($_builder,
+            $_state,
+            ::mlir::TypeRange(rt),
+            num_procs,
+            prank,
+            gshape);
+    }]>,
+  ];
+}
+
+def LocalOfSliceOp : Dist_Op<"local_of_slice",
+                             [SameVariadicOperandSize, SameVariadicResultSize, Pure]> {
+  let summary = "Compute local overlap of a distributed tensor and slice";
+  let description = [{
+      Slice and tensor operate on the global index space. This operation computes the
+      local part of the slice as owned by the local partition of the tensor. The operation
+      returns local offsets and sizes (e.g. relative to the local memref). Additionally,
+      it computes and returns the offsets of the resulting local slice relative to the global input slice.
+  }];
+
+  let arguments = (ins
+    AnyType:$d_tensor,
+    Variadic<Index>:$offsets,
+    Variadic<Index>:$sizes,
+    Variadic<Index>:$strides
+  );
+  let results = (outs Variadic<Index>:$l_offsets, Variadic<Index>:$l_sizes, Variadic<Index>:$g_offsets);
+
+  let assemblyFormat = [{
+    $d_tensor `[` $offsets `]``[` $sizes `]``[` $strides `]` attr-dict `:` qualified(type($d_tensor)) `to` `(`qualified(type(results))`)`
+  }];
+
+  let builders = [
+    // auto-deduce return type
+    OpBuilder<(ins "::mlir::Value":$d_tensor, "::mlir::ValueRange":$offsets, "::mlir::ValueRange":$sizes, "::mlir::ValueRange":$strides), [{
+      auto IndexType = $_builder.getIndexType();
+      ::mlir::SmallVector<::mlir::Type> rt(offsets.size()*3, IndexType);
+      build($_builder, $_state, ::mlir::TypeRange(rt), d_tensor, offsets, sizes, strides);
+    }]>,
+  ];
+}
 
-  // result is allreduced input tensor
+def LocalToGlobalOp : Dist_Op<"local_to_global", [Pure]> {
+  let summary = "Translate local indices into global indices";
+  let description = [{
+      Input indices are interpreted as relative to the local part of the given DTensor.
+  }];
+
+  let arguments = (ins AnyType:$d_tensor, Variadic<Index>:$l_indices);
+  let results = (outs Variadic<Index>:$g_indices);
+
+  let builders = [
+    // auto-deduce return type
+    OpBuilder<(ins "::mlir::Value":$d_tensor, "::mlir::ValueRange":$lindices), [{
+      auto IndexType = $_builder.getIndexType();
+      ::mlir::SmallVector<::mlir::Type> rt(lindices.size(), IndexType);
+      build($_builder, $_state, ::mlir::TypeRange(rt), d_tensor, lindices);
+    }]>,
+  ];
+  // let assemblyFormat = [{
+  //   $d_tensor attr-dict `:` qualified(type($source)) `to` `(`qualified(type(results))`)`
+  // }];
+}
+
+def AllReduceOp : Dist_Op<"allreduce", []> {
+  let summary = "Inplace allreduce";
+  let description = [{
+      Result is the allreduced input tensor.
+  }];
+  // reduction operation and local tensor
+  let arguments = (ins AnyAttr:$op, AnyMemRef:$data);
   let results = (outs AnyType);
 }
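The semantics described for `local_of_slice` can be modeled for a single dimension in plain Python. This is my own illustrative sketch of the op's description (not the actual lowering); `local_of_slice_1d` and its argument names are hypothetical:

```python
def local_of_slice_1d(l_off, l_sz, off, size, stride):
    """Model of dist.local_of_slice for one dimension.

    The local partition owns global indices [l_off, l_off + l_sz).
    The slice selects global indices off + i*stride for i in [0, size).
    Returns (local_offset, local_size, global_offset), where
    local_offset is relative to the local memref and global_offset is
    the position of the first locally owned element within the slice."""
    hi = l_off + l_sz
    # first slice index i whose global index reaches the local block
    i0 = max(0, -(-(l_off - off) // stride))  # ceiling division
    # number of slice indices whose global index lies below the block's end
    i1 = min(size, max(0, -(-(hi - off) // stride)))
    count = max(0, i1 - i0)
    if count == 0:
        return 0, 0, 0
    return off + i0 * stride - l_off, count, i0
```

For example, with a local block owning global indices `[10, 20)` and a slice `off=3, size=12, stride=2` (global indices 3, 5, ..., 25), the locally owned slice elements are 11, 13, 15, 17, 19: local offset 1, local size 5, and offset 4 within the global slice.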
