Skip to content

Commit 199a45f

Browse files
authored
[XeTile][WgtoSg] Add lowering patterns for arith/math ops. (#1076)
Support for arith/math ops added to WGToSG pass Author: Shani Abarbanel <[email protected]>
1 parent e24b2b6 commit 199a45f

File tree

1 file changed

+72
-36
lines changed

1 file changed

+72
-36
lines changed

lib/Dialect/XeTile/Transforms/WgToSg.cpp

Lines changed: 72 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,24 +1202,62 @@ void populateXeTileWgToSgPatterns(mlir::RewritePatternSet &patterns,
12021202
WGToSGVectorShapeCast, WGToSGVectorMultiDimReductionOp,
12031203
WGToSGLoadGatherOpPattern, WGToSGStoreScatterOpPattern,
12041204
WGToSGVectorCreateMask, UnrealizedConversionCastOpPattern>(patterns.getContext());
1205-
patterns.insert<WGToSGElementWiseOpSameArgAndResultTypePattern<math::ExpOp, 1>,
1206-
WGToSGElementWiseOpSameArgAndResultTypePattern<math::SqrtOp, 1>,
1207-
WGToSGElementWiseOpSameArgAndResultTypePattern<arith::AddFOp, 2>,
1208-
WGToSGArithDifferentResultTypePattern<arith::TruncFOp>,
1209-
WGToSGArithDifferentResultTypePattern<arith::TruncIOp>,
1210-
WGToSGArithDifferentResultTypePattern<arith::ExtFOp>,
1211-
WGToSGArithDifferentResultTypePattern<arith::ExtSIOp>,
1212-
WGToSGArithDifferentResultTypePattern<arith::ExtUIOp>,
1213-
WGToSGArithDifferentResultTypePattern<arith::SIToFPOp>,
1214-
WGToSGArithDifferentResultTypePattern<arith::UIToFPOp>,
1215-
WGToSGArithDifferentResultTypePattern<arith::FPToSIOp>,
1216-
WGToSGArithDifferentResultTypePattern<arith::FPToUIOp>,
1217-
WGToSGArithDifferentResultTypePattern<arith::IndexCastUIOp>,
1218-
WGToSGArithDifferentResultTypePattern<arith::IndexCastOp>,
1219-
WGToSGArithDifferentResultTypePattern<arith::BitcastOp>,
1220-
WGToSGElementWiseOpComparisonOpsPattern<arith::CmpIOp>,
1221-
WGToSGElementWiseOpComparisonOpsPattern<arith::CmpFOp>,
1222-
WGToSGArithConstantOpPattern>(patterns.getContext());
1205+
patterns.insert<
1206+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::ExpOp, 1>,
1207+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::SqrtOp, 1>,
1208+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::AbsFOp, 1>,
1209+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::CosOp, 1>,
1210+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::CoshOp, 1>,
1211+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::AcosOp, 1>,
1212+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::AcoshOp, 1>,
1213+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::SinOp, 1>,
1214+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::SinhOp, 1>,
1215+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::AsinOp, 1>,
1216+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::AsinhOp, 1>,
1217+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::TanOp, 1>,
1218+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::TanhOp, 1>,
1219+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::AtanOp, 1>,
1220+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::Atan2Op, 2>,
1221+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::AtanhOp, 1>,
1222+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::ErfOp, 1>,
1223+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::LogOp, 1>,
1224+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::Log2Op, 1>,
1225+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::FloorOp, 1>,
1226+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::CeilOp, 1>,
1227+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::PowFOp, 2>,
1228+
WGToSGElementWiseOpSameArgAndResultTypePattern<math::RsqrtOp, 1>,
1229+
WGToSGElementWiseOpSameArgAndResultTypePattern<arith::NegFOp, 1>,
1230+
WGToSGElementWiseOpSameArgAndResultTypePattern<arith::AddFOp, 2>,
1231+
WGToSGElementWiseOpSameArgAndResultTypePattern<arith::AddIOp, 2>,
1232+
WGToSGElementWiseOpSameArgAndResultTypePattern<arith::SubFOp, 2>,
1233+
WGToSGElementWiseOpSameArgAndResultTypePattern<arith::SubIOp, 2>,
1234+
WGToSGElementWiseOpSameArgAndResultTypePattern<arith::MulFOp, 2>,
1235+
WGToSGElementWiseOpSameArgAndResultTypePattern<arith::MulIOp, 2>,
1236+
WGToSGElementWiseOpSameArgAndResultTypePattern<arith::ShLIOp, 2>,
1237+
WGToSGElementWiseOpSameArgAndResultTypePattern<arith::ShRSIOp, 2>,
1238+
WGToSGElementWiseOpSameArgAndResultTypePattern<arith::ShRUIOp, 2>,
1239+
WGToSGElementWiseOpSameArgAndResultTypePattern<arith::DivFOp, 2>,
1240+
WGToSGElementWiseOpSameArgAndResultTypePattern<arith::DivSIOp, 2>,
1241+
WGToSGElementWiseOpSameArgAndResultTypePattern<arith::DivUIOp, 2>,
1242+
WGToSGElementWiseOpSameArgAndResultTypePattern<arith::MaximumFOp, 2>,
1243+
WGToSGElementWiseOpSameArgAndResultTypePattern<arith::MinimumFOp, 2>,
1244+
WGToSGElementWiseOpSameArgAndResultTypePattern<arith::RemSIOp, 2>,
1245+
WGToSGElementWiseOpSameArgAndResultTypePattern<arith::RemUIOp, 2>,
1246+
WGToSGArithDifferentResultTypePattern<arith::TruncFOp>,
1247+
WGToSGArithDifferentResultTypePattern<arith::TruncIOp>,
1248+
WGToSGArithDifferentResultTypePattern<arith::ExtFOp>,
1249+
WGToSGArithDifferentResultTypePattern<arith::ExtSIOp>,
1250+
WGToSGArithDifferentResultTypePattern<arith::ExtUIOp>,
1251+
WGToSGArithDifferentResultTypePattern<arith::SIToFPOp>,
1252+
WGToSGArithDifferentResultTypePattern<arith::UIToFPOp>,
1253+
WGToSGArithDifferentResultTypePattern<arith::FPToSIOp>,
1254+
WGToSGArithDifferentResultTypePattern<arith::FPToUIOp>,
1255+
WGToSGArithDifferentResultTypePattern<arith::IndexCastUIOp>,
1256+
WGToSGArithDifferentResultTypePattern<arith::IndexCastOp>,
1257+
WGToSGArithDifferentResultTypePattern<arith::BitcastOp>,
1258+
WGToSGElementWiseOpComparisonOpsPattern<arith::CmpIOp>,
1259+
WGToSGElementWiseOpComparisonOpsPattern<arith::CmpFOp>,
1260+
WGToSGArithConstantOpPattern>(patterns.getContext());
12231261
}
12241262

12251263
// Transforms WG XeTile IR to SG XeTile
@@ -1364,14 +1402,6 @@ class XeTileWgToSgPass
13641402
return false;
13651403
});
13661404

1367-
target.addDynamicallyLegalOp<xetile::LoadGatherOp>(
1368-
[&](xetile::LoadGatherOp op) -> bool {
1369-
if (!op.getTile().getType().getWgMap())
1370-
return true;
1371-
else
1372-
return false;
1373-
});
1374-
13751405
target.addDynamicallyLegalOp<xetile::TileMMAOp>(
13761406
[&](xetile::TileMMAOp op) -> bool {
13771407
auto mapAttr = llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(
@@ -1406,16 +1436,22 @@ class XeTileWgToSgPass
14061436
return false;
14071437
});
14081438

1409-
target.addDynamicallyLegalOp<mlir::arith::ConstantOp, mlir::arith::AddFOp,
1410-
mlir::math::ExpOp, mlir::math::SqrtOp, mlir::arith::ExtFOp,
1411-
mlir::arith::ExtSIOp, mlir::arith::ExtUIOp, mlir::arith::FPToSIOp,
1412-
mlir::arith::FPToUIOp, mlir::arith::UIToFPOp, mlir::arith::SIToFPOp,
1413-
mlir::arith::TruncFOp, mlir::arith::TruncIOp, mlir::arith::CmpIOp,
1414-
mlir::arith::CmpFOp, mlir::arith::IndexCastUIOp, mlir::arith::SelectOp,
1415-
mlir::math::FPowIOp,mlir::arith::IndexCastOp, mlir::arith::BitcastOp,
1416-
mlir::vector::TransposeOp, mlir::vector::BroadcastOp,
1417-
mlir::vector::MultiDimReductionOp,mlir::vector::ShapeCastOp,
1418-
mlir::vector::CreateMaskOp>(
1439+
target.addDynamicallyLegalOp<
1440+
arith::ConstantOp, arith::AddFOp, arith::AddIOp, arith::SubFOp,
1441+
arith::SubIOp, arith::MulFOp, arith::MulIOp, arith::ShLIOp,
1442+
arith::ShRSIOp, arith::ShRUIOp, arith::DivFOp, arith::DivSIOp,
1443+
arith::DivUIOp, arith::MaximumFOp, arith::MinimumFOp, arith::RemSIOp,
1444+
arith::RemUIOp, arith::NegFOp, math::ExpOp, math::SqrtOp, math::AbsFOp,
1445+
math::AcosOp, math::AcoshOp, math::SinOp, math::SinhOp, math::AsinOp,
1446+
math::AsinhOp, math::TanOp, math::TanhOp, math::AtanOp, math::Atan2Op,
1447+
math::AtanhOp, math::CosOp, math::CoshOp, math::ErfOp, math::LogOp,
1448+
math::Log2Op, math::FloorOp, math::CeilOp, math::PowFOp, math::RsqrtOp,
1449+
arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp, arith::FPToSIOp,
1450+
arith::FPToUIOp, arith::UIToFPOp, arith::SIToFPOp, arith::TruncFOp,
1451+
arith::TruncIOp, arith::CmpIOp, arith::CmpFOp, arith::IndexCastUIOp,
1452+
arith::SelectOp, math::FPowIOp, arith::IndexCastOp, arith::BitcastOp,
1453+
vector::TransposeOp, vector::BroadcastOp, vector::MultiDimReductionOp,
1454+
vector::ShapeCastOp, vector::CreateMaskOp>(
14191455
[&](mlir::Operation *op) -> bool {
14201456
auto mapAttr = llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(
14211457
op->getAttr("map"));

0 commit comments

Comments
 (0)