@@ -187,6 +187,9 @@ SmallVector<unsigned> getOrder(SharedEncodingTrait layout,
187187 if (auto paddedEnc = dyn_cast<PaddedSharedEncodingAttr>(layout)) {
188188 return paddedEnc.getOrder ();
189189 }
190+ if (auto linearEnc = dyn_cast<SharedLinearEncodingAttr>(layout)) {
191+ return linearEnc.getOrder ();
192+ }
190193 if (auto sharedLayout = dyn_cast<NVMMASharedEncodingAttr>(layout)) {
191194 if (shape.size () == 1 ) {
192195 return {0 };
@@ -1566,6 +1569,181 @@ void SwizzledSharedEncodingAttr::print(AsmPrinter &printer) const {
15661569 printer << " }>" ;
15671570}
15681571
1572+ // ===----------------------------------------------------------------------===//
1573+ // SharedLinear encoding
1574+ // ===----------------------------------------------------------------------===//
1575+
1576+ LogicalResult
1577+ SharedLinearEncodingAttr::verify (function_ref<InFlightDiagnostic()> emitError,
1578+ LinearLayout linearLayout,
1579+ unsigned layoutAlignment) {
1580+ if (layoutAlignment == 0 || !llvm::isPowerOf2_32 (layoutAlignment)) {
1581+ return emitError () << " alignment must be a positive power of two" ;
1582+ }
1583+ static const auto expectedInDims =
1584+ SmallVector<std::string>({" offset" , " block" });
1585+ for (const auto &[index, dims] : llvm::enumerate (
1586+ llvm::zip (linearLayout.getInDimNames (), expectedInDims))) {
1587+ const auto &[dim, expected] = dims;
1588+ if (dim.str () != expected) {
1589+ return emitError () << " Expected input dimension " << index << " to be '"
1590+ << expected << " '. Got " << dim;
1591+ }
1592+ }
1593+
1594+ for (auto [i, dim] : llvm::enumerate (linearLayout.getOutDimNames ())) {
1595+ if (dim.str () != (" dim" + llvm::Twine (i)).str ()) {
1596+ return emitError ()
1597+ << " Expected output dimensions to be ['dim0', 'dim1', ...]. Got "
1598+ << dim << " at position " << i;
1599+ }
1600+ }
1601+
1602+ SmallVector<StringAttr> outDimNames =
1603+ llvm::to_vector (linearLayout.getOutDimNames ());
1604+ if (outDimNames.empty ()) {
1605+ return emitError ()
1606+ << " SharedLinearEncodingAttr requires at least one output"
1607+ " dimension." ;
1608+ }
1609+
1610+ auto *ctx = outDimNames.front ().getContext ();
1611+ auto kOffset = StringAttr::get (ctx, " offset" );
1612+ auto kBlock = StringAttr::get (ctx, " block" );
1613+
1614+ if (!linearLayout.isSurjective ()) {
1615+ return emitError () << " The layout must be surjective" ;
1616+ }
1617+
1618+ LinearLayout withoutBroadcast =
1619+ linearLayout.removeZeroBasesAlongDim (kOffset ).removeZeroBasesAlongDim (
1620+ kBlock );
1621+ if (!withoutBroadcast.isInvertible ()) {
1622+ return emitError ()
1623+ << " After removing the zero bases the layout must be bijective" ;
1624+ }
1625+
1626+ return success ();
1627+ }
1628+
1629+ void SharedLinearEncodingAttr::print (AsmPrinter &printer) const {
1630+ printer << " <{" ;
1631+ auto layout = getLinearLayout ();
1632+ auto kBlock = StringAttr::get (getContext (), " block" );
1633+ auto kOffset = StringAttr::get (getContext (), " offset" );
1634+ if (layout.getBases ().lookup (kBlock ).empty ()) {
1635+ layout =
1636+ layout.sublayout ({kOffset }, llvm::to_vector (layout.getOutDimNames ()));
1637+ }
1638+ printLinearLayout (printer, layout);
1639+ printer << " }, alignment = " << getAlignment () << " }>" ;
1640+ }
1641+
1642+ Attribute SharedLinearEncodingAttr::parse (AsmParser &parser, Type type) {
1643+ if (parser.parseLess ().failed ())
1644+ return {};
1645+
1646+ DictionaryAttr layoutDictRaw;
1647+ if (parser.parseAttribute (layoutDictRaw).failed ())
1648+ return {};
1649+
1650+ if (layoutDictRaw.get (" alignment" )) {
1651+ parser.emitError (parser.getCurrentLocation ())
1652+ << " alignment must be specified outside of the linear layout braces" ;
1653+ return {};
1654+ }
1655+
1656+ NamedAttrList layoutAttrList (layoutDictRaw.getValue ());
1657+ auto *ctx = parser.getContext ();
1658+ auto kBlock = StringAttr::get (ctx, " block" );
1659+ if (!layoutAttrList.get (kBlock )) {
1660+ layoutAttrList.push_back ({kBlock , ArrayAttr::get (ctx, {})});
1661+ }
1662+
1663+ DictionaryAttr layoutDict = layoutAttrList.getDictionary (ctx);
1664+
1665+ // Parse alignment
1666+ unsigned layoutAlignment;
1667+ if (parser.parseComma ().failed ())
1668+ return {};
1669+ if (parser.parseKeyword (" alignment" ).failed () || parser.parseEqual ().failed ())
1670+ return {};
1671+ if (parser.parseInteger (layoutAlignment).failed ())
1672+ return {};
1673+
1674+ if (parser.parseGreater ().failed ())
1675+ return {};
1676+
1677+ std::vector<std::string> inDimNames = {" offset" , " block" };
1678+ auto maybeLL = parseLinearLayout (layoutDict, parser, inDimNames);
1679+ if (!maybeLL.has_value ())
1680+ return {};
1681+
1682+ // Special case for cleaner errors
1683+ if (layoutDict.get (" alignment" )) {
1684+ parser.emitError (parser.getCurrentLocation ())
1685+ << " alignment must be specified outside of the linear layout braces" ;
1686+ return {};
1687+ }
1688+
1689+ if (layoutDict.size () != 2 ) {
1690+ parser.emitError (parser.getCurrentLocation ())
1691+ << " SharedLinearEncodingAttr must have exactly two attributes: offset "
1692+ " and block" ;
1693+ return {};
1694+ }
1695+
1696+ return parser.getChecked <SharedLinearEncodingAttr>(
1697+ parser.getContext (), std::move (*maybeLL), layoutAlignment);
1698+ }
1699+
1700+ SmallVector<unsigned >
1701+ SharedLinearEncodingAttr::basesPerDim (StringAttr dimName,
1702+ bool skipBroadcast) const {
1703+ auto ll = getLinearLayout ();
1704+ auto rank = ll.getNumOutDims ();
1705+ return basesPerDimImpl (ll.getBases (), dimName, rank, skipBroadcast);
1706+ }
1707+
1708+ SmallVector<unsigned >
1709+ SharedLinearEncodingAttr::orderPerDim (StringAttr dimName,
1710+ ArrayRef<unsigned > defaultOrder) const {
1711+ return orderPerDimImpl (getLinearLayout (), dimName, defaultOrder);
1712+ }
1713+
1714+ SmallVector<unsigned > SharedLinearEncodingAttr::getOrder () const {
1715+ auto ll = getLinearLayout ();
1716+ auto rank = ll.getNumOutDims ();
1717+ SmallVector<unsigned > defaultOrder (rank);
1718+ std::iota (defaultOrder.rbegin (), defaultOrder.rend (), 0 );
1719+ return orderPerDim (StringAttr::get (getContext (), " offset" ), defaultOrder);
1720+ }
1721+
1722+ SmallVector<unsigned > SharedLinearEncodingAttr::getCTAsPerCGA () const {
1723+ return basesPerDim (StringAttr::get (getContext (), " block" ),
1724+ /* skipBroadcast=*/ false );
1725+ }
1726+
1727+ SmallVector<unsigned > SharedLinearEncodingAttr::getCTAOrder () const {
1728+ return orderPerDim (StringAttr::get (getContext (), " block" ), getOrder ());
1729+ }
1730+
1731+ SmallVector<unsigned > SharedLinearEncodingAttr::getCTASplitNum () const {
1732+ return basesPerDim (StringAttr::get (getContext (), " block" ));
1733+ }
1734+
1735+ LinearLayout
1736+ SharedLinearEncodingAttr::toLinearLayout (ArrayRef<int64_t > shape) const {
1737+ auto ll = getLinearLayout ();
1738+ auto outDimNames = llvm::to_vector (ll.getOutDimNames ());
1739+ assert (shape.size () == outDimNames.size ());
1740+ // We don't support automatic broadcasting for shared linear layouts
1741+ for (auto [size, llSize] : llvm::zip (shape, ll.getOutDimSizes ())) {
1742+ assert (size == llSize);
1743+ }
1744+ return ll;
1745+ }
1746+
15691747// ===----------------------------------------------------------------------===//
15701748// PaddedShared encoding
15711749// ===----------------------------------------------------------------------===//
0 commit comments