@@ -1200,6 +1200,9 @@ struct ReshapeChangeLayout
1200
1200
}
1201
1201
};
1202
1202
1203
+ static constexpr llvm::StringLiteral
1204
+ kContigiousArraysAttr (" plier.contigious_arrays" );
1205
+
1203
1206
struct MakeStridedLayoutPass
1204
1207
: public mlir::PassWrapper<MakeStridedLayoutPass,
1205
1208
mlir::OperationPass<mlir::ModuleOp>> {
@@ -1210,23 +1213,48 @@ void MakeStridedLayoutPass::runOnOperation() {
1210
1213
auto context = &getContext ();
1211
1214
auto mod = getOperation ();
1212
1215
1216
+ mlir::OpBuilder builder (mod);
1217
+ auto loc = builder.getUnknownLoc ();
1218
+ auto attrStr = builder.getStringAttr (kContigiousArraysAttr );
1219
+
1220
+ llvm::SmallVector<bool > contigiousArrayArg;
1221
+
1222
+ auto isContigiousArrayArg = [&](unsigned i) {
1223
+ if (contigiousArrayArg.empty ())
1224
+ return false ;
1225
+
1226
+ assert (i < contigiousArrayArg.size ());
1227
+ return contigiousArrayArg[i];
1228
+ };
1229
+
1213
1230
llvm::SmallVector<mlir::Type> newArgTypes;
1214
1231
llvm::SmallVector<mlir::Type> newResTypes;
1215
1232
llvm::SmallVector<mlir::Value> newOperands;
1216
1233
for (auto func : mod.getOps <mlir::FuncOp>()) {
1217
- mlir::OpBuilder builder (func.body ());
1218
- auto loc = builder.getUnknownLoc ();
1234
+ auto contAttr = func->getAttr (attrStr).dyn_cast_or_null <mlir::ArrayAttr>();
1235
+ if (contAttr) {
1236
+ auto contAttrRange = contAttr.getAsValueRange <mlir::BoolAttr>();
1237
+ contigiousArrayArg.assign (contAttrRange.begin (), contAttrRange.end ());
1238
+ } else {
1239
+ contigiousArrayArg.clear ();
1240
+ }
1241
+
1219
1242
auto funcType = func.getType ();
1220
1243
auto argTypes = funcType.getInputs ();
1221
1244
auto resTypes = funcType.getResults ();
1222
1245
newArgTypes.assign (argTypes.begin (), argTypes.end ());
1223
1246
newResTypes.assign (resTypes.begin (), resTypes.end ());
1224
- bool hasBody = !func.getBody ().empty ();
1247
+ auto &body = func.getBody ();
1248
+ bool hasBody = !body.empty ();
1249
+ if (hasBody)
1250
+ builder.setInsertionPointToStart (&body.front ());
1251
+
1225
1252
for (auto it : llvm::enumerate (argTypes)) {
1226
1253
auto i = static_cast <unsigned >(it.index ());
1227
1254
auto type = it.value ();
1228
1255
auto memrefType = type.dyn_cast <mlir::MemRefType>();
1229
- if (!memrefType || !memrefType.getLayout ().isIdentity ())
1256
+ if (!memrefType || isContigiousArrayArg (i) ||
1257
+ !memrefType.getLayout ().isIdentity ())
1230
1258
continue ;
1231
1259
1232
1260
auto rank = static_cast <unsigned >(memrefType.getRank ());
@@ -1244,7 +1272,7 @@ void MakeStridedLayoutPass::runOnOperation() {
1244
1272
newArgTypes[i] = newMemrefType;
1245
1273
1246
1274
if (hasBody) {
1247
- auto arg = func. getBody () .front ().getArgument (i);
1275
+ auto arg = body .front ().getArgument (i);
1248
1276
arg.setType (newMemrefType);
1249
1277
auto dst =
1250
1278
builder.create <plier::ChangeLayoutOp>(loc, memrefType, arg);
@@ -2116,6 +2144,63 @@ void PostPlierToLinalgPass::runOnOperation() {
2116
2144
(void )mlir::applyPatternsAndFoldGreedily (getOperation (), std::move (patterns));
2117
2145
}
2118
2146
2147
+ template <typename F>
2148
+ static void visitTypeRecursive (mlir::Type type, F &&visitor) {
2149
+ if (auto tupleType = type.dyn_cast <mlir::TupleType>()) {
2150
+ for (auto t : tupleType.getTypes ())
2151
+ visitTypeRecursive (t, std::forward<F>(visitor));
2152
+ } else {
2153
+ visitor (type);
2154
+ }
2155
+ }
2156
+
2157
+ static bool isContigiousArray (mlir::Type type) {
2158
+ auto pyType = type.dyn_cast <plier::PyType>();
2159
+ if (!pyType)
2160
+ return false ;
2161
+
2162
+ auto name = pyType.getName ();
2163
+ auto desc = parseArrayDesc (name);
2164
+ if (!desc)
2165
+ return false ;
2166
+
2167
+ return desc->layout == ArrayLayout::C;
2168
+ }
2169
+
2170
+ struct MarkContigiousArraysPass
2171
+ : public mlir::PassWrapper<MarkContigiousArraysPass,
2172
+ mlir::OperationPass<mlir::FuncOp>> {
2173
+ void runOnOperation () override {
2174
+ auto func = getOperation ();
2175
+ auto funcType = func.getType ();
2176
+
2177
+ mlir::OpBuilder builder (&getContext ());
2178
+ auto attrStr = builder.getStringAttr (kContigiousArraysAttr );
2179
+ if (func->hasAttr (attrStr)) {
2180
+ markAllAnalysesPreserved ();
2181
+ return ;
2182
+ }
2183
+
2184
+ bool needAttr = false ;
2185
+ llvm::SmallVector<bool > result;
2186
+ result.reserve (funcType.getNumInputs ());
2187
+
2188
+ auto visitor = [&](mlir::Type type) {
2189
+ auto res = isContigiousArray (type);
2190
+ result.emplace_back (res);
2191
+ needAttr = needAttr || res;
2192
+ };
2193
+
2194
+ for (auto type : (func.getType ().getInputs ()))
2195
+ visitTypeRecursive (type, visitor);
2196
+
2197
+ if (needAttr)
2198
+ func->setAttr (attrStr, builder.getBoolArrayAttr (result));
2199
+
2200
+ markAllAnalysesPreserved ();
2201
+ }
2202
+ };
2203
+
2119
2204
template <typename Op>
2120
2205
struct ConvertAlloc : public mlir ::OpConversionPattern<Op> {
2121
2206
using mlir::OpConversionPattern<Op>::OpConversionPattern;
@@ -2601,6 +2686,7 @@ struct FixDeallocPlacementPass
2601
2686
void , FixDeallocPlacement> {};
2602
2687
2603
2688
static void populatePlierToLinalgGenPipeline (mlir::OpPassManager &pm) {
2689
+ pm.addNestedPass <mlir::FuncOp>(std::make_unique<MarkContigiousArraysPass>());
2604
2690
pm.addPass (std::make_unique<PlierToLinalgPass>());
2605
2691
pm.addPass (mlir::createCanonicalizerPass ());
2606
2692
pm.addPass (std::make_unique<NumpyCallsLoweringPass>());
0 commit comments