Skip to content

Commit 005fa11

Browse files
authored
[NDArray] Add "permute_dims" operation (#841)
Adding "permute_dims" operation to ndarray
1 parent 17f45dc commit 005fa11

File tree

18 files changed

+614
-21
lines changed

18 files changed

+614
-21
lines changed

include/imex/Dialect/DistRuntime/IR/DistRuntimeOps.td

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,4 +231,39 @@ def WaitOp : DistRuntime_Op<"wait", []> {
231231
let arguments = (ins DistRuntime_AsyncHandle:$handle);
232232
}
233233

234+
// Asynchronous runtime op: copies data of a distributed input array into the
// locally owned part of a new output array whose dimensions are permuted
// according to `axes`. Yields an async handle plus the new local array.
def CopyPermuteOp : DistRuntime_Op<"copy_permute",
    [Pure, DeclareOpInterfaceMethods<AsyncOpInterface>, AttrSizedOperandSegments]> {
    let summary = "Copy adequate data from input to a new permuted output";
    let description = [{
        Copy the necessary data from the distributed input array to the locally owned part of the new output array.
        The shape of the output array is assumed to be a permuted version of the input's shape.

        The local data is not modified.

        Arguments:

        - `team`: the distributed team owning the distributed array
        - `lArray`: the locally owned data
        - `gShape`: the global shape of the distributed array
        - `lOffsets`: the offset of the local data within the global array
        - `nlShape`: the local shape of the distributed output array
        - `nlOffsets`: the offsets of the locally owned output array
        - `axes`: the permutation of dimensions; dimension `i` of the output corresponds to dimension `axes[i]` of the input

        `gShape`, `lOffsets` are variadic arguments with same size `ri` where `ri` is the rank of the global input array (e.g., one number for each dimension of the global input array).
        `nlShape`, `nlOffsets` are variadic arguments with same size `ro` where `ro` is the rank of the global output array (e.g., one number for each dimension of the global output array).
    }];
    let arguments = (ins AnyAttr:$team,
                         AnyType:$lArray,
                         Variadic<Index>:$gShape,
                         Variadic<Index>:$lOffsets,
                         Variadic<Index>:$nlShape,
                         Variadic<Index>:$nlOffsets,
                         DenseI64ArrayAttr:$axes);
    let results = (outs DistRuntime_AsyncHandle:$handle, AnyType:$nlArray);
    let assemblyFormat = [{
        $lArray `g_shape` $gShape `l_offs` $lOffsets `to` `n_offs` $nlOffsets `n_shape` $nlShape `axes` $axes attr-dict `:` `(` type(operands) `)` `->` `(` qualified(type(results)) `)`
    }];
    let hasCanonicalizer = 1;
}
268+
234269
#endif // _DistRuntime_OPS_TD_INCLUDED_

include/imex/Dialect/NDArray/IR/NDArrayOps.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,4 +948,25 @@ def CastElemTypeOp: NDArray_Op<"cast_elemtype", [Pure]> {
948948
let hasCanonicalizer = 1;
949949
}
950950

951+
// NDArray-level axis permutation: consumes a source array and an `axes`
// attribute and produces a new array. Declares both a canonicalizer and a
// verifier hook.
def PermuteDimsOp : NDArray_Op<"permute_dims", []> {
  let summary = "Permutes the axes (dimensions) of an array to a new array.";
  let description = [{
    Permutes the axes (dimensions) of an array.
    The output array is a new array.
  }];

  let arguments = (ins NDArray_NDArray:$source, DenseI64ArrayAttr:$axes);
  let results = (outs NDArray_NDArray);

  let assemblyFormat = [{
    $source $axes attr-dict `:` qualified(type($source)) `->` qualified(type(results))
  }];

  let hasVerifier = 1;
  let hasCanonicalizer = 1;
}
971+
951972
#endif // _NDARRAY_OPS_TD_INCLUDED_

lib/Conversion/DistToStandard/DistToStandard.cpp

Lines changed: 75 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1602,6 +1602,70 @@ struct RePartitionOpConverter
16021602
}
16031603
};
16041604

1605+
/// Convert a global ndarray::PermuteDimsOp on a distributed array
1606+
/// to ndarray::PermuteDimsOp on the local data.
1607+
/// If needed, adds a repartition op.
1608+
/// The local partition (e.g. a RankedTensor) is wrapped in a
1609+
/// non-distributed NDArray and re-applied to PermuteDimsOp.
1610+
/// op gets replaced with global distributed array
1611+
struct PermuteDimsOpConverter
1612+
: public ::mlir::OpConversionPattern<::imex::ndarray::PermuteDimsOp> {
1613+
using ::mlir::OpConversionPattern<
1614+
::imex::ndarray::PermuteDimsOp>::OpConversionPattern;
1615+
1616+
/// Initialize the pattern.
1617+
void initialize() {
1618+
/// Signal that this pattern safely handles recursive application.
1619+
setHasBoundedRewriteRecursion();
1620+
}
1621+
1622+
::mlir::LogicalResult
1623+
matchAndRewrite(::imex::ndarray::PermuteDimsOp op,
1624+
::imex::ndarray::PermuteDimsOp::Adaptor adaptor,
1625+
::mlir::ConversionPatternRewriter &rewriter) const override {
1626+
1627+
auto src = op.getSource();
1628+
auto dst = op.getResult();
1629+
auto srcType = mlir::dyn_cast<::imex::ndarray::NDArrayType>(src.getType());
1630+
auto dstType = mlir::dyn_cast<::imex::ndarray::NDArrayType>(dst.getType());
1631+
if (!(srcType && isDist(srcType) && dstType && isDist(dstType))) {
1632+
return ::mlir::failure();
1633+
}
1634+
1635+
auto loc = op.getLoc();
1636+
auto srcEnv = getDistEnv(srcType);
1637+
auto team = srcEnv.getTeam();
1638+
auto elementType = srcType.getElementType();
1639+
1640+
auto srcGShape = createGlobalShapeOf(loc, rewriter, src);
1641+
auto srcLParts = createPartsOf(loc, rewriter, src);
1642+
auto srcLArray = srcLParts.size() == 1 ? srcLParts[0] : srcLParts[1];
1643+
auto srcLOffsets = createLocalOffsetsOf(loc, rewriter, src);
1644+
1645+
auto dstGShape = createGlobalShapeOf(loc, rewriter, dst);
1646+
auto dstLPart = createDefaultPartition(loc, rewriter, team, dstGShape);
1647+
auto dstLOffsets = dstLPart.getLOffsets();
1648+
auto dstLShape = dstLPart.getLShape();
1649+
auto dstLShapeIndex = getShapeFromValues(dstLShape);
1650+
auto dstLType = ::imex::ndarray::NDArrayType::get(
1651+
dstLShapeIndex, elementType, getNonDistEnvs(dstType));
1652+
1653+
// call the dist runtime
1654+
auto handleType = ::imex::distruntime::AsyncHandleType::get(getContext());
1655+
auto distLArray = rewriter.create<::imex::distruntime::CopyPermuteOp>(
1656+
loc, ::mlir::TypeRange{handleType, dstLType}, team, srcLArray,
1657+
srcGShape, srcLOffsets, dstLOffsets, dstLShape, adaptor.getAxes());
1658+
(void)rewriter.create<::imex::distruntime::WaitOp>(loc,
1659+
distLArray.getHandle());
1660+
// finally init dist array
1661+
rewriter.replaceOp(
1662+
op, createDistArray(loc, rewriter, team, srcGShape, dstLOffsets,
1663+
::mlir::ValueRange{distLArray.getNlArray()}));
1664+
1665+
return ::mlir::success();
1666+
}
1667+
};
1668+
16051669
// *******************************
16061670
// ***** Pass infrastructure *****
16071671
// *******************************
@@ -1696,23 +1760,23 @@ struct ConvertDistToStandardPass
16961760
::imex::ndarray::ReductionOp, ::imex::ndarray::ToTensorOp,
16971761
::imex::ndarray::DeleteOp, ::imex::ndarray::CastElemTypeOp,
16981762
::imex::region::EnvironmentRegionOp,
1699-
::imex::region::EnvironmentRegionYieldOp>(
1763+
::imex::region::EnvironmentRegionYieldOp,
1764+
::imex::ndarray::PermuteDimsOp>(
17001765
[&](::mlir::Operation *op) { return typeConverter.isLegal(op); });
17011766
target.addLegalOp<::imex::dist::InitDistArrayOp>();
17021767

17031768
// All the dist conversion patterns/rewriter
17041769
::mlir::RewritePatternSet patterns(&ctxt);
17051770
// all these patterns are converted
1706-
patterns
1707-
.insert<LinSpaceOpConverter, CreateOpConverter, CopyOpConverter,
1708-
ReductionOpConverter, ToTensorOpConverter,
1709-
InsertSliceOpConverter, SubviewOpConverter, EWBinOpConverter,
1710-
EWUnyOpConverter, LocalBoundingBoxOpConverter,
1711-
LocalCoreOpConverter, RePartitionOpConverter,
1712-
ReshapeOpConverter, LocalTargetOfSliceOpConverter,
1713-
DefaultPartitionOpConverter, LocalOffsetsOfOpConverter,
1714-
PartsOfOpConverter, DeleteOpConverter, CastElemTypeOpConverter>(
1715-
typeConverter, &ctxt);
1771+
patterns.insert<
1772+
LinSpaceOpConverter, CreateOpConverter, CopyOpConverter,
1773+
ReductionOpConverter, ToTensorOpConverter, InsertSliceOpConverter,
1774+
SubviewOpConverter, EWBinOpConverter, EWUnyOpConverter,
1775+
LocalBoundingBoxOpConverter, LocalCoreOpConverter,
1776+
RePartitionOpConverter, ReshapeOpConverter,
1777+
LocalTargetOfSliceOpConverter, DefaultPartitionOpConverter,
1778+
LocalOffsetsOfOpConverter, PartsOfOpConverter, DeleteOpConverter,
1779+
CastElemTypeOpConverter, PermuteDimsOpConverter>(typeConverter, &ctxt);
17161780
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
17171781
typeConverter, patterns, target);
17181782
::imex::populateRegionTypeConversionPatterns(patterns, typeConverter);

lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,6 +1217,39 @@ struct ReductionOpLowering
12171217
}
12181218
};
12191219

1220+
/// Convert NDArray's permute_dims operations and their return type to
1221+
/// Linalg/tensor.
1222+
struct PermuteDimsOpLowering
1223+
: public ::mlir::OpConversionPattern<::imex::ndarray::PermuteDimsOp> {
1224+
using OpConversionPattern::OpConversionPattern;
1225+
1226+
::mlir::LogicalResult
1227+
matchAndRewrite(::imex::ndarray::PermuteDimsOp op,
1228+
::imex::ndarray::PermuteDimsOp::Adaptor adaptor,
1229+
::mlir::ConversionPatternRewriter &rewriter) const override {
1230+
1231+
auto loc = op->getLoc();
1232+
auto srcTnsr = adaptor.getSource();
1233+
1234+
// convert src array to memref
1235+
auto srcArType = mlir::dyn_cast_or_null<::imex::ndarray::NDArrayType>(
1236+
op.getSource().getType());
1237+
if (!srcArType)
1238+
return mlir::failure();
1239+
auto srcMRType = srcArType.getMemRefType(srcTnsr);
1240+
auto srcMR = createToMemRef(loc, rewriter, srcTnsr, srcMRType);
1241+
1242+
auto perm = ::mlir::AffineMapAttr::get(::mlir::AffineMap::getPermutationMap(
1243+
adaptor.getAxes(), rewriter.getContext()));
1244+
mlir::memref::TransposeOp transposeOp =
1245+
rewriter.create<mlir::memref::TransposeOp>(loc, srcMR, perm);
1246+
1247+
rewriter.replaceOp(op, transposeOp.getResult());
1248+
1249+
return ::mlir::success();
1250+
}
1251+
};
1252+
12201253
// *******************************
12211254
// ***** Pass infrastructure *****
12221255
// *******************************
@@ -1320,13 +1353,14 @@ struct ConvertNDArrayToLinalgPass
13201353
[&](mlir::Operation *op) { return typeConverter.isLegal(op); });
13211354

13221355
::mlir::RewritePatternSet patterns(&ctxt);
1323-
patterns.insert<
1324-
ToTensorLowering, SubviewLowering, ExtractSliceLowering,
1325-
InsertSliceLowering, ImmutableInsertSliceLowering, LinSpaceLowering,
1326-
LoadOpLowering, CreateLowering, EWBinOpLowering, DimOpLowering,
1327-
EWUnyOpLowering, ReductionOpLowering, ReshapeLowering, CastLowering,
1328-
CopyLowering, DeleteLowering, CastElemTypeLowering, FromMemRefLowering>(
1329-
typeConverter, &ctxt);
1356+
patterns.insert<ToTensorLowering, SubviewLowering, ExtractSliceLowering,
1357+
InsertSliceLowering, ImmutableInsertSliceLowering,
1358+
LinSpaceLowering, LoadOpLowering, CreateLowering,
1359+
EWBinOpLowering, DimOpLowering, EWUnyOpLowering,
1360+
ReductionOpLowering, ReshapeLowering, CastLowering,
1361+
CopyLowering, DeleteLowering, CastElemTypeLowering,
1362+
FromMemRefLowering, PermuteDimsOpLowering>(typeConverter,
1363+
&ctxt);
13301364
::imex::populateRegionTypeConversionPatterns(patterns, typeConverter);
13311365

13321366
// populate function boundaries using our special type converter

lib/Dialect/Dist/Transforms/DistCoalesce.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ struct DistCoalescePass : public ::imex::DistCoalesceBase<DistCoalescePass> {
9999
if (auto op =
100100
isDefByAnyOf<::imex::dist::InitDistArrayOp, ::imex::dist::EWBinOp,
101101
::imex::dist::EWUnyOp, ::imex::ndarray::ReshapeOp,
102+
::imex::ndarray::PermuteDimsOp,
102103
::mlir::UnrealizedConversionCastOp,
103104
::imex::ndarray::CopyOp>(val)) {
104105
return op;

lib/Dialect/DistRuntime/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_imex_dialect_library(IMEXDistRuntimeDialect
22
DistRuntimeOps.cpp
33
GetHaloOp.cpp
44
CopyReshapeOp.cpp
5+
CopyPermuteOp.cpp
56

67
ADDITIONAL_HEADER_DIRS
78
${PROJECT_SOURCE_DIR}/include/imex/Dialect/DistRuntime
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
//===- CopyPermuteOp.cpp - distruntime dialect -----------------*- C++ -*-===//
2+
//
3+
// Copyright 2023 Intel Corporation
4+
// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
///
10+
/// \file
11+
/// This file implements the CopyPermuteOp of the DistRuntime dialect.
12+
///
13+
//===----------------------------------------------------------------------===//
14+
15+
#include <imex/Dialect/DistRuntime/IR/DistRuntimeOps.h>
16+
#include <imex/Dialect/NDArray/IR/NDArrayOps.h>
17+
#include <imex/Utils/PassUtils.h>
18+
19+
namespace imex {
20+
namespace distruntime {
21+
22+
/// Returns a one-element vector holding this op's `nlArray` result.
/// NOTE(review): presumably these are the values that must not be used before
/// the async handle has been waited on; confirm against AsyncOpInterface.
::mlir::SmallVector<::mlir::Value> CopyPermuteOp::getDependent() {
  ::mlir::SmallVector<::mlir::Value> dependents;
  dependents.push_back(getNlArray());
  return dependents;
}
25+
26+
} // namespace distruntime
27+
} // namespace imex
28+
29+
namespace {
30+
31+
/// Pattern to replace dynamically shaped result types
32+
/// by statically shaped result types.
33+
class CopyPermuteOpResultCanonicalizer final
34+
: public mlir::OpRewritePattern<::imex::distruntime::CopyPermuteOp> {
35+
public:
36+
using mlir::OpRewritePattern<
37+
::imex::distruntime::CopyPermuteOp>::OpRewritePattern;
38+
39+
mlir::LogicalResult
40+
matchAndRewrite(::imex::distruntime::CopyPermuteOp op,
41+
::mlir::PatternRewriter &rewriter) const override {
42+
43+
// check input type
44+
auto dstArray = op.getNlArray();
45+
auto dstType =
46+
mlir::dyn_cast<::imex::ndarray::NDArrayType>(dstArray.getType());
47+
if (!dstType) {
48+
return ::mlir::failure();
49+
}
50+
auto dstShape = dstType.getShape();
51+
if (!::mlir::ShapedType::isDynamicShape(dstShape)) {
52+
return ::mlir::failure();
53+
}
54+
55+
auto src = op.getLArray();
56+
auto srcType = mlir::dyn_cast<::imex::ndarray::NDArrayType>(src.getType());
57+
if (!srcType) {
58+
return ::mlir::failure();
59+
}
60+
auto srcShape = srcType.getShape();
61+
if (::mlir::ShapedType::isDynamicShape(srcShape)) {
62+
return ::mlir::failure();
63+
}
64+
65+
auto axes = op.getAxes();
66+
if (axes.size() != dstShape.size()) {
67+
return ::mlir::failure();
68+
}
69+
70+
::mlir::SmallVector<int64_t> permutedShape;
71+
for (auto i = 0u; i < srcShape.size(); ++i) {
72+
permutedShape.push_back(srcShape[axes[i]]);
73+
}
74+
75+
auto elType = dstType.getElementType();
76+
auto nType = dstType.cloneWith(permutedShape, elType);
77+
auto hType = ::imex::distruntime::AsyncHandleType::get(getContext());
78+
79+
auto newOp = rewriter.create<::imex::distruntime::CopyPermuteOp>(
80+
op.getLoc(), ::mlir::TypeRange{hType, nType}, op.getTeamAttr(),
81+
op.getLArray(), op.getGShape(), op.getLOffsets(), op.getNlOffsets(),
82+
op.getNlShape(), op.getAxes());
83+
84+
// cast to original types and replace op
85+
auto res = rewriter.create<imex::ndarray::CastOp>(op.getLoc(), dstType,
86+
newOp.getNlArray());
87+
rewriter.replaceOp(op, {newOp.getHandle(), res});
88+
89+
return ::mlir::success();
90+
}
91+
};
92+
93+
/// Pattern to rewrite a subview op with CastOp arguments.
94+
/// Ported from mlir::tensor::ExtractSliceOp
95+
class CopyPermuteCastFolder final
96+
: public mlir::OpRewritePattern<::imex::distruntime::CopyPermuteOp> {
97+
public:
98+
using mlir::OpRewritePattern<
99+
::imex::distruntime::CopyPermuteOp>::OpRewritePattern;
100+
101+
mlir::LogicalResult
102+
matchAndRewrite(::imex::distruntime::CopyPermuteOp op,
103+
mlir::PatternRewriter &rewriter) const override {
104+
auto src = op.getLArray();
105+
auto castOp = mlir::dyn_cast<imex::ndarray::CastOp>(src.getDefiningOp());
106+
if (!castOp)
107+
return mlir::failure();
108+
109+
if (!imex::ndarray::canFoldIntoConsumerOp(castOp))
110+
return mlir::failure();
111+
112+
auto newOp = rewriter.create<::imex::distruntime::CopyPermuteOp>(
113+
op.getLoc(), op->getResultTypes(), op.getTeamAttr(), castOp.getSource(),
114+
op.getGShape(), op.getLOffsets(), op.getNlOffsets(), op.getNlShape(),
115+
op.getAxes());
116+
rewriter.replaceOp(op, newOp);
117+
118+
return mlir::success();
119+
}
120+
};
121+
122+
} // namespace
123+
124+
/// Registers the canonicalization patterns of CopyPermuteOp: static result
/// shape inference and folding of a producing CastOp.
void imex::distruntime::CopyPermuteOp::getCanonicalizationPatterns(
    mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
  results.add<CopyPermuteOpResultCanonicalizer>(context);
  results.add<CopyPermuteCastFolder>(context);
}

0 commit comments

Comments
 (0)