Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -881,10 +881,10 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
let description = [{
This pass converts communication operations from the Mesh dialect to the
MPI dialect.
If it finds a global named "static_mpi_rank" it will use that splat value
instead of calling MPI_Comm_rank. This allows optimizations like constant
shape propagation and fusion because shard/partition sizes depend on the
rank.
If it finds the DLTI attribute "MPI:comm_world-rank" on the module it will
use that integer value instead of calling MPI_Comm_rank. This allows
optimizations like constant shape propagation and fusion because
shard/partition sizes depend on the rank.
}];
let dependentDialects = [
"memref::MemRefDialect",
Expand Down
22 changes: 15 additions & 7 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -345,24 +345,32 @@ def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
}];
}

def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
let summary = "Get the shard shape of a given process/device.";
def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [
Pure, AttrSizedOperandSegments,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "Get the shard shape for a given process/device.";
let description = [{
The device/process id is a linearized id of the device/process in the mesh.
The device/process id is a multi-index of the device/process in the mesh.
This operation might be used during spmdization when the shard shape depends
on (non-constant) values used in `mesh.sharding`.
}];
let arguments = (ins
DenseI64ArrayAttr:$shape,
DenseI64ArrayAttr:$dims,
Variadic<Index>:$dims_dynamic,
Mesh_Sharding:$sharding,
Index:$device
DenseI64ArrayAttr:$device,
Variadic<Index>:$device_dynamic
);
let results = (outs Variadic<Index>:$result);
let assemblyFormat = [{
custom<DimensionList>($shape) $sharding $device attr-dict `:` type($result)
`dims` `=` custom<DynamicIndexList>($dims_dynamic, $dims)
`sharding` `=` $sharding
`device` `=` custom<DynamicIndexList>($device_dynamic, $device)
attr-dict `:` type(results)
}];
let builders = [
OpBuilder<(ins "ArrayRef<int64_t>":$shape, "Value":$sharding, "Value":$device)>
OpBuilder<(ins "ArrayRef<int64_t>":$dims, "ArrayRef<Value>":$dims_dyn, "Value":$sharding, "ValueRange":$device)>
];
}

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMeshToMPI
Core

LINK_LIBS PUBLIC
MLIRDLTIDialect
MLIRFuncDialect
MLIRIR
MLIRLinalgTransforms
Expand Down
Loading