Skip to content

Commit 7be3da8

Browse files
committed
Merge remote-tracking branch 'upstream/main'
2 parents a320c6e + a6109b1 commit 7be3da8

File tree

15 files changed

+297
-73
lines changed

15 files changed

+297
-73
lines changed

include/imex/Dialect/Dist/Utils/Utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ inline ::mlir::ValueRange createLocalOffsetsOf(const ::mlir::Location &loc,
206206
return builder.create<::imex::dist::LocalOffsetsOfOp>(loc, ary).getLOffsets();
207207
}
208208

209-
// create operation returning halo of distributed array
209+
// create operation returning all parts (owned + halos) of distributed array
210210
inline ::mlir::ValueRange createPartsOf(const ::mlir::Location &loc,
211211
::mlir::OpBuilder &builder,
212212
::mlir::Value ary) {

include/imex/Dialect/DistRuntime/IR/DistRuntimeOps.td

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,37 @@ def GetHaloOp : DistRuntime_Op<"get_halo",
191191
let hasCanonicalizer = 1;
192192
}
193193

194+
def CopyReshapeOp : DistRuntime_Op<"copy_reshape",
195+
[Pure, DeclareOpInterfaceMethods<AsyncOpInterface>, AttrSizedOperandSegments]> {
196+
let summary = "Copy adequate data from input to a new reshaped output";
197+
let description = [{
198+
Copy the necessary data from the distributed input array to the locally owned part of the output array.
199+
The shape of the output array is assumed to be a reshaped version of the input's shape.
200+
201+
The local data is not modified.
202+
203+
Arguments:
204+
205+
- `team`: the distributed team owning the distributed array
206+
- `gShape`: the global shape of the distributed array
207+
- `lOffsets`: the offset of the local data within the global array
208+
- `lArray`: the locally owned data
209+
- `ngShape`: the global shape of the distributed output array
210+
- `nlOffsets`: the offsets of the locally owned output array
211+
212+
`gShape`, `lOffsets` are variadic arguments with same size `ri` where `ri` is the rank of the global input array (e.g., one number for each dimension of the global input array).
213+
`ngShape`, `nlOffsets` are variadic arguments with same size `ro` where `ro` is the rank of the global output array (e.g., one number for each dimension of the global output array).
214+
}];
215+
let arguments = (ins AnyAttr:$team,
216+
AnyType:$lArray, Variadic<Index>:$gShape, Variadic<Index>:$lOffsets,
217+
Variadic<Index>:$ngShape, Variadic<Index>:$nlOffsets, Variadic<Index>:$nlShape);
218+
let results = (outs DistRuntime_AsyncHandle:$handle, AnyType:$nlArray);
219+
let assemblyFormat = [{
220+
$lArray `g_shape` $gShape `l_offs` $lOffsets `to` `n_g_shape` $ngShape `n_offs` $nlOffsets `n_shape` $nlShape attr-dict `:` `(` type(operands) `)` `->` `(` qualified(type(results)) `)`
221+
}];
222+
let hasCanonicalizer = 1;
223+
}
224+
194225
def WaitOp : DistRuntime_Op<"wait", []> {
195226
let summary = "Wait for asynchronous operation to finish.";
196227
let description = [{

include/imex/Dialect/NDArray/IR/NDArrayOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -875,11 +875,11 @@ def ReshapeOp : NDArray_Op<"reshape", []> {
875875
See Array API.
876876
}];
877877

878-
let arguments = (ins AnyType:$src, Variadic<Index>:$shape, OptionalAttr<I1Attr>:$copy);
878+
let arguments = (ins AnyType:$source, Variadic<Index>:$shape, OptionalAttr<I1Attr>:$copy);
879879
let results = (outs AnyType);
880880

881881
let assemblyFormat = [{
882-
$src $shape attr-dict `:` qualified(type($src)) `->` qualified(type(results))
882+
$source $shape attr-dict `:` qualified(type($source)) `->` qualified(type(results))
883883
}];
884884
}
885885

lib/Conversion/DistToStandard/DistToStandard.cpp

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -591,9 +591,7 @@ struct ReshapeOpConverter
591591
::imex::ndarray::ReshapeOp::Adaptor adaptor,
592592
::mlir::ConversionPatternRewriter &rewriter) const override {
593593

594-
return ::mlir::failure();
595-
#ifdef FIXME_RESHAPE
596-
auto src = op.getArray();
594+
auto src = op.getSource();
597595
auto srcDistType = src.getType().dyn_cast<::imex::ndarray::NDArrayType>();
598596
auto retDistType =
599597
op.getResult().getType().dyn_cast<::imex::ndarray::NDArrayType>();
@@ -603,10 +601,13 @@ struct ReshapeOpConverter
603601
}
604602

605603
auto loc = op.getLoc();
604+
auto elType = srcDistType.getElementType();
606605
auto dEnv = getDistEnv(srcDistType);
607-
auto nShape = adaptor.getShape();
606+
auto ngShape = adaptor.getShape();
608607
auto gShape = createGlobalShapeOf(loc, rewriter, src);
609-
auto elType = srcDistType.getElementType();
608+
auto lParts = createPartsOf(loc, rewriter, src);
609+
auto lArray = lParts.size() == 1 ? lParts[0] : lParts[1];
610+
auto lOffs = createLocalOffsetsOf(loc, rewriter, src);
610611

611612
// Repartitioning is needed if any of the partitions' size is not a multiple
612613
// of the new chunksize.
@@ -615,37 +616,30 @@ struct ReshapeOpConverter
615616
assert(adaptor.getCopy().value_or(1) != 0 ||
616617
!"Distributed reshape currently requires copying");
617618

618-
auto team = dEnv.getTeam();
619-
// get function args
620-
auto lPart = createDefaultPartition(loc, rewriter, team, nShape);
621-
auto lOffs = lPart.getLOffsets();
622-
auto lShape = lPart.getLShape();
619+
// FIXME: Check return type: Check that static sizes are the same as the
620+
// default part sizes
623621

624-
// create output array with target size
625-
auto outArray = rewriter.create<::imex::ndarray::CreateOp>(
626-
loc, lShape, ::imex::ndarray::fromMLIR(elType), nullptr,
627-
getNonDistEnvs(srcDistType));
628-
auto lUMR = ::imex::ndarray::mkURMemRef(loc, rewriter, outArray);
629-
630-
auto idxType = rewriter.getIndexType();
631-
auto gShapeMR = createURMemRefFromElements(rewriter, loc, idxType, gShape);
632-
auto lOffsMR = createURMemRefFromElements(rewriter, loc, idxType, lOffs);
633-
auto nShapePtr = createURMemRefFromElements(rewriter, loc, idxType, nShape);
622+
auto team = dEnv.getTeam();
623+
auto nPart = createDefaultPartition(loc, rewriter, team, ngShape);
624+
auto nlOffs = nPart.getLOffsets();
625+
auto nlShape = nPart.getLShape();
626+
auto shp = getShapeFromValues(nlShape);
627+
auto lRetType = ::imex::ndarray::NDArrayType::get(
628+
shp, elType, getNonDistEnvs(retDistType));
634629

635630
// call the idt runtime
636-
auto fun = rewriter.getStringAttr(mkTypedFunc("_idtr_reshape", elType));
637-
auto teamC = rewriter.create<::mlir::arith::ConstantOp>(
638-
loc, team.cast<::mlir::IntegerAttr>());
639-
(void)rewriter.create<::mlir::func::CallOp>(
640-
loc, fun, ::mlir::TypeRange(),
641-
::mlir::ValueRange{teamC, gShapeMR, lOffsMR, lUMR, nShapePtr});
642-
631+
auto htype = ::imex::distruntime::AsyncHandleType::get(getContext());
632+
auto nlArray = rewriter.create<::imex::distruntime::CopyReshapeOp>(
633+
loc, ::mlir::TypeRange{htype, lRetType}, team, lArray, gShape, lOffs,
634+
ngShape, nlOffs, nlShape);
635+
(void)rewriter.create<::imex::distruntime::WaitOp>(loc,
636+
nlArray.getHandle());
643637
// finally init dist array
644638
rewriter.replaceOp(
645-
op, createDistArray(loc, rewriter, team, nShape, lOffs, {outArray}));
639+
op, createDistArray(loc, rewriter, team, ngShape, nlOffs,
640+
::mlir::ValueRange{nlArray.getNlArray()}));
646641

647642
return ::mlir::success();
648-
#endif
649643
}
650644
};
651645

lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -646,13 +646,13 @@ struct ReshapeLowering
646646
// check output type and get operands
647647
auto retArTyp = op.getType().dyn_cast<::imex::ndarray::NDArrayType>();
648648
auto srcArTyp =
649-
op.getSrc().getType().dyn_cast<::imex::ndarray::NDArrayType>();
649+
op.getSource().getType().dyn_cast<::imex::ndarray::NDArrayType>();
650650
if (!(retArTyp && srcArTyp)) {
651651
return ::mlir::failure();
652652
}
653653

654654
auto loc = op.getLoc();
655-
auto src = adaptor.getSrc();
655+
auto src = adaptor.getSource();
656656
auto srcTnsr = src.getType().cast<::mlir::TensorType>();
657657
auto shape = adaptor.getShape();
658658
auto elTyp = srcTnsr.getElementType();

lib/Dialect/DistRuntime/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_imex_dialect_library(IMEXDistRuntimeDialect
22
DistRuntimeOps.cpp
33
GetHaloOp.cpp
4+
CopyReshapeOp.cpp
45

56
ADDITIONAL_HEADER_DIRS
67
${PROJECT_SOURCE_DIR}/include/imex/Dialect/DistRuntime
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
//===- CopyReshapeOp.cpp - distruntime dialect -----------------*- C++ -*-===//
2+
//
3+
// Copyright 2023 Intel Corporation
4+
// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
///
10+
/// \file
11+
/// This file implements the CopyReshapeOp of the DistRuntime dialect.
12+
///
13+
//===----------------------------------------------------------------------===//
14+
15+
#include <imex/Dialect/DistRuntime/IR/DistRuntimeOps.h>
16+
#include <imex/Dialect/NDArray/IR/NDArrayOps.h>
17+
#include <imex/Utils/PassUtils.h>
18+
19+
namespace imex {
20+
namespace distruntime {
21+
22+
::mlir::SmallVector<::mlir::Value> CopyReshapeOp::getDependent() {
23+
return {getNlArray()};
24+
}
25+
26+
} // namespace distruntime
27+
} // namespace imex
28+
29+
namespace {
30+
31+
/// Pattern to replace dynamically shaped result types
32+
/// by statically shaped result types.
33+
class CopyReshapeOpResultCanonicalizer final
34+
: public mlir::OpRewritePattern<::imex::distruntime::CopyReshapeOp> {
35+
public:
36+
using mlir::OpRewritePattern<
37+
::imex::distruntime::CopyReshapeOp>::OpRewritePattern;
38+
39+
mlir::LogicalResult
40+
matchAndRewrite(::imex::distruntime::CopyReshapeOp op,
41+
::mlir::PatternRewriter &rewriter) const override {
42+
43+
// check input type
44+
auto nlArray = op.getNlArray();
45+
auto nlType = nlArray.getType().dyn_cast<::imex::ndarray::NDArrayType>();
46+
if (!nlType) {
47+
return ::mlir::failure();
48+
}
49+
auto resShape = nlType.getShape();
50+
if (!::mlir::ShapedType::isDynamicShape(resShape)) {
51+
return ::mlir::failure();
52+
}
53+
54+
// we compare result shape with expected shape and bail out if no new
55+
// static info found
56+
auto nlShape = ::imex::getShapeFromValues(op.getNlShape());
57+
bool found = false;
58+
for (auto i = 0u; i < nlShape.size(); ++i) {
59+
if (::mlir::ShapedType::isDynamic(resShape[i]) &&
60+
!::mlir::ShapedType::isDynamic(nlShape[i])) {
61+
found = true;
62+
}
63+
}
64+
if (!found) {
65+
return ::mlir::failure();
66+
}
67+
68+
auto elType = nlType.getElementType();
69+
auto nType = nlType.cloneWith(nlShape, elType);
70+
auto hType = ::imex::distruntime::AsyncHandleType::get(getContext());
71+
72+
auto newOp = rewriter.create<::imex::distruntime::CopyReshapeOp>(
73+
op.getLoc(), ::mlir::TypeRange{hType, nType}, op.getTeamAttr(),
74+
op.getLArray(), op.getGShape(), op.getLOffsets(), op.getNgShape(),
75+
op.getNlOffsets(), op.getNlShape());
76+
77+
// cast to original types and replace op
78+
auto res = rewriter.create<imex::ndarray::CastOp>(op.getLoc(), nlType,
79+
newOp.getNlArray());
80+
rewriter.replaceOp(op, {newOp.getHandle(), res});
81+
82+
return ::mlir::success();
83+
}
84+
};
85+
86+
} // namespace
87+
88+
void imex::distruntime::CopyReshapeOp::getCanonicalizationPatterns(
89+
mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
90+
results.add<CopyReshapeOpResultCanonicalizer>(context);
91+
}

lib/Dialect/DistRuntime/IR/GetHaloOp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- GetHaloOp.cpp - NDArray dialect ----------------------*- C++ -*-===//
1+
//===- GetHaloOp.cpp - distruntime dialect ---------------------*- C++ -*-===//
22
//
33
// Copyright 2023 Intel Corporation
44
// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.

lib/Dialect/DistRuntime/Transforms/DistRuntimeToIDTR.cpp

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -96,18 +96,20 @@ struct RuntimePrototypes {
9696
requireFunc(loc, builder, module, "_idtr_prank", {i64Type}, {indexType});
9797
requireFunc(loc, builder, module, "_idtr_reduce_all", {dataMRType, opType},
9898
{});
99-
requireFunc(loc, builder, module, "_idtr_reshape",
100-
// team, gshape, loffs, ownOutpart, nshape
101-
{i64Type, idxMRType, idxMRType, dataMRType, idxMRType}, {});
99+
requireFunc(loc, builder, module, "_idtr_copy_reshape",
100+
// team, gshape, loffs, lPart, ngshape, nloffs, nPart
101+
{i64Type, idxMRType, idxMRType, dataMRType, idxMRType,
102+
idxMRType, dataMRType},
103+
{i64Type});
102104
requireFunc(loc, builder, module, "_idtr_update_halo",
103105
// team, gshape, loffs, lPart, bbOffset, bbShape,
104106
// lHalo, rHalo, key
105107
{i64Type, idxMRType, idxMRType, dataMRType, idxMRType,
106108
idxMRType, dataMRType, dataMRType, i64Type},
107109
{i64Type});
108110
requireFunc(loc, builder, module, "_idtr_wait",
109-
// handle, lHalo, rHalo
110-
{i64Type, dataMRType, dataMRType}, {});
111+
// handle
112+
{i64Type}, {});
111113
}
112114
};
113115

@@ -151,6 +153,58 @@ struct TeamMemberOpPattern
151153
}
152154
};
153155

156+
struct CopyReshapeOpPattern
157+
: public ::mlir::OpRewritePattern<::imex::distruntime::CopyReshapeOp> {
158+
using ::mlir::OpRewritePattern<
159+
::imex::distruntime::CopyReshapeOp>::OpRewritePattern;
160+
161+
::mlir::LogicalResult
162+
matchAndRewrite(::imex::distruntime::CopyReshapeOp op,
163+
::mlir::PatternRewriter &rewriter) const override {
164+
auto lArray = op.getLArray();
165+
auto arType = lArray.getType().dyn_cast<::imex::ndarray::NDArrayType>();
166+
auto resType =
167+
op.getNlArray().getType().dyn_cast<::imex::ndarray::NDArrayType>();
168+
if (!arType || !resType) {
169+
return ::mlir::failure();
170+
}
171+
172+
auto loc = op.getLoc();
173+
auto elType = resType.getElementType();
174+
auto team = op.getTeam();
175+
auto gShape = op.getGShape();
176+
auto lOffs = op.getLOffsets();
177+
auto ngShape = op.getNgShape();
178+
auto nlOffs = op.getNlOffsets();
179+
auto nlShape = op.getNlShape();
180+
181+
// create output array with target size
182+
auto nlArray = rewriter.create<::imex::ndarray::CreateOp>(
183+
loc, nlShape, ::imex::ndarray::fromMLIR(elType), nullptr,
184+
resType.getEnvironments());
185+
186+
auto idxType = rewriter.getIndexType();
187+
auto teamC = rewriter.create<::mlir::arith::ConstantOp>(
188+
loc, team.cast<::mlir::IntegerAttr>());
189+
auto gShapeMR = createURMemRefFromElements(rewriter, loc, idxType, gShape);
190+
auto lOffsMR = createURMemRefFromElements(rewriter, loc, idxType, lOffs);
191+
auto lArrayMR = ::imex::ndarray::mkURMemRef(loc, rewriter, lArray);
192+
auto ngShapeMR =
193+
createURMemRefFromElements(rewriter, loc, idxType, ngShape);
194+
auto nlOffsMR = createURMemRefFromElements(rewriter, loc, idxType, nlOffs);
195+
auto nlArrayMR = ::imex::ndarray::mkURMemRef(loc, rewriter, nlArray);
196+
197+
auto fun =
198+
rewriter.getStringAttr(mkTypedFunc("_idtr_copy_reshape", elType));
199+
auto handle = rewriter.create<::mlir::func::CallOp>(
200+
loc, fun, rewriter.getI64Type(),
201+
::mlir::ValueRange{teamC, gShapeMR, lOffsMR, lArrayMR, ngShapeMR,
202+
nlOffsMR, nlArrayMR});
203+
rewriter.replaceOp(op, {handle.getResult(0), nlArray});
204+
return ::mlir::success();
205+
}
206+
};
207+
154208
/// @brief lower GetHaloOp
155209
/// Determine sizes of halos, alloc halos and call idtr.
156210
/// Before accessing/reading from returned halos, the caller must
@@ -291,24 +345,9 @@ struct WaitOpPattern
291345
::mlir::LogicalResult
292346
matchAndRewrite(::imex::distruntime::WaitOp op,
293347
::mlir::PatternRewriter &rewriter) const override {
294-
auto loc = op.getLoc();
295-
auto handle = op.getHandle();
296-
auto uhOp = handle.getDefiningOp<::imex::distruntime::GetHaloOp>();
297-
assert(uhOp);
298-
auto lHalo = uhOp.getLHalo();
299-
auto rHalo = uhOp.getRHalo();
300-
301-
auto arTyp = lHalo.getType().dyn_cast<::imex::ndarray::NDArrayType>();
302-
if (!arTyp)
303-
return ::mlir::failure();
304-
auto elTyp = arTyp.getElementType();
305-
auto lHaloMR = ::imex::ndarray::mkURMemRef(loc, rewriter, lHalo);
306-
auto rHaloMR = ::imex::ndarray::mkURMemRef(loc, rewriter, rHalo);
307-
308-
auto fsa = rewriter.getStringAttr(mkTypedFunc("_idtr_wait", elTyp));
348+
auto fsa = rewriter.getStringAttr("_idtr_wait");
309349
rewriter.replaceOpWithNewOp<::mlir::func::CallOp>(
310-
op, fsa, ::mlir::TypeRange(),
311-
::mlir::ValueRange{handle, lHaloMR, rHaloMR});
350+
op, fsa, ::mlir::TypeRange(), ::mlir::ValueRange{op.getHandle()});
312351

313352
return ::mlir::success();
314353
}
@@ -325,8 +364,9 @@ struct DistRuntimeToIDTRPass
325364
RuntimePrototypes::add_prototypes(builder, this->getOperation());
326365

327366
::mlir::FrozenRewritePatternSet patterns;
328-
insertPatterns<TeamSizeOpPattern, TeamMemberOpPattern, GetHaloOpPattern,
329-
AllReduceOpPattern, WaitOpPattern>(getContext(), patterns);
367+
insertPatterns<CopyReshapeOpPattern, TeamSizeOpPattern, TeamMemberOpPattern,
368+
GetHaloOpPattern, AllReduceOpPattern, WaitOpPattern>(
369+
getContext(), patterns);
330370
(void)::mlir::applyPatternsAndFoldGreedily(this->getOperation(), patterns);
331371
}; // runOnOperation()
332372

test/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ add_lit_testsuite(check-static "Running the IMEX regression tests"
8282
set_target_properties(check-static PROPERTIES FOLDER "Tests")
8383

8484
add_custom_target(check-imex
85-
DEPENDS check-gen check-static
85+
DEPENDS check-static
8686
)
8787

8888
add_lit_testsuites(IMEX ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${IMEX_TEST_DEPENDS})

0 commit comments

Comments
 (0)