@@ -163,85 +163,6 @@ void LayoutRematerialization::cleanup() {
     op->erase();
 }
 
-// Look ahead at the transitive uses and see if there is a convert to an mma
-// operation.
-bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
-  SmallVector<Value> queue = {op->getResult(0)};
-  SetVector<Operation *> forwardSlice;
-  llvm::SmallDenseSet<Value> seen;
-  while (!queue.empty()) {
-    Value currentValue = queue.back();
-    queue.pop_back();
-    getForwardSlice(currentValue, &forwardSlice);
-    for (Operation *op : forwardSlice) {
-      // HACK: Stop propagation if the ReduceOp is using an mma layout but is
-      // producing a tensor smaller than the layout we would like to
-      // propagate. This is to avoid stepping into a known bug.
-      if (isa<mlir::triton::ReduceOp>(op)) {
-        auto tensorType =
-            dyn_cast<RankedTensorType>(op->getOperand(0).getType());
-        if (tensorType &&
-            isa<NvidiaMmaEncodingAttr>(tensorType.getEncoding())) {
-          auto mmaInstrShape =
-              cast<NvidiaMmaEncodingAttr>(encoding).getInstrShape();
-          if (tensorType.getShape()[tensorType.getRank() - 2] <
-                  mmaInstrShape[0] ||
-              tensorType.getShape()[tensorType.getRank() - 1] <
-                  mmaInstrShape[1]) {
-            return false;
-          }
-        }
-      }
-
-      if (auto convertOp = dyn_cast<ConvertLayoutOp>(op)) {
-        Attribute dstEncoding = convertOp.getType().getEncoding();
-        if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(dstEncoding))
-          return (mmaLayout.getVersionMajor() > 1) ? true
-                                                   : mmaLayout == encoding;
-        if (isa<triton::gpu::AMDMfmaEncodingAttr,
-                triton::gpu::AMDWmmaEncodingAttr>(dstEncoding))
-          return true;
-        if (isa<triton::gpu::DotOperandEncodingAttr>(dstEncoding)) {
-          if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(encoding)) {
-            return mmaLayout.getVersionMajor() > 1;
-          } else {
-            assert((mlir::isa<triton::gpu::AMDMfmaEncodingAttr,
-                              triton::gpu::AMDWmmaEncodingAttr>(encoding)));
-            return true;
-          }
-        }
-      }
-      bool isMMAV3 =
-          isa<NvidiaMmaEncodingAttr>(encoding) &&
-          cast<NvidiaMmaEncodingAttr>(encoding).getVersionMajor() == 3;
-      if (isMMAV3 && (isa<LocalAllocOp>(op) || isa<LocalStoreOp>(op)))
-        return true;
-      auto yield = dyn_cast<scf::YieldOp>(op);
-      if (!yield)
-        continue;
-      if (auto ifOp = dyn_cast<scf::IfOp>(yield->getParentOp())) {
-        for (OpOperand &operand : yield->getOpOperands()) {
-          Operation *def = operand.get().getDefiningOp();
-          if (def &&
-              (forwardSlice.count(def) || operand.get() == currentValue) &&
-              seen.insert(operand.get()).second)
-            queue.push_back(ifOp.getResult(operand.getOperandNumber()));
-        }
-      }
-      auto forOp = dyn_cast<scf::ForOp>(yield.getOperation()->getParentOp());
-      if (!forOp)
-        continue;
-      for (OpOperand &operand : yield->getOpOperands()) {
-        Operation *def = operand.get().getDefiningOp();
-        if (def && (forwardSlice.count(def) || operand.get() == currentValue) &&
-            seen.insert(operand.get()).second)
-          queue.push_back(forOp.getRegionIterArg(operand.getOperandNumber()));
-      }
-    }
-  }
-  return false;
-}
-
 // Return true if the op is an op with a layout we don't want to change. We
 // will propagate the layout starting from anchor ops.
 bool isLayoutAnchor(Operation *op) {
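
Reviewer note: the deleted helper above is a standard MLIR forward-slice worklist. getForwardSlice collects every operation reachable through SSA uses of a queued value, and whenever the slice hits an scf.yield, the corresponding scf.if result or scf.for region iter_arg is re-queued so the walk continues across structured control flow. A minimal self-contained sketch of just that traversal, reduced from the deleted code (anyTransitiveUse and the pred callback are hypothetical names introduced here, not part of the pass):

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Walk the transitive uses of `root`; return true as soon as `pred` accepts
// an operation in the forward slice.
static bool anyTransitiveUse(Value root,
                             llvm::function_ref<bool(Operation *)> pred) {
  SmallVector<Value> queue = {root};
  SetVector<Operation *> slice;
  llvm::SmallDenseSet<Value> seen;
  while (!queue.empty()) {
    Value current = queue.pop_back_val();
    getForwardSlice(current, &slice);
    for (Operation *op : slice) {
      if (pred(op))
        return true;
      auto yield = dyn_cast<scf::YieldOp>(op);
      if (!yield)
        continue;
      // A yielded value escapes the region: follow it as the enclosing
      // scf.if's result, or as the scf.for iter_arg it becomes on the next
      // iteration.
      for (OpOperand &operand : yield->getOpOperands()) {
        Operation *def = operand.get().getDefiningOp();
        if (!def || (!slice.count(def) && operand.get() != current) ||
            !seen.insert(operand.get()).second)
          continue;
        unsigned idx = operand.getOperandNumber();
        if (auto ifOp = dyn_cast<scf::IfOp>(yield->getParentOp()))
          queue.push_back(ifOp->getResult(idx));
        else if (auto forOp = dyn_cast<scf::ForOp>(yield->getParentOp()))
          queue.push_back(forOp.getRegionIterArg(idx));
      }
    }
  }
  return false;
}

The seen set prevents re-queueing the same loop-carried value, which would otherwise make the walk cycle through a for-loop's iter_args indefinitely.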
@@ -262,18 +183,8 @@ bool isLayoutAnchor(Operation *op) {
 }
 
 void LayoutPropagation::initAnchorLayout() {
-  auto maybeAddAnchor = [&](Value v) {
+  auto addAnchor = [&](Value v) {
     if (auto tensorType = dyn_cast<RankedTensorType>(v.getType())) {
-      // Workaround: don't propagate an MMA layout unless there is a convert
-      // back to mma further down, to avoid generating a reduction with an
-      // MMA layout that may have lower performance.
-      // This can be improved with more aggressive backward propagation.
-      if (isa<MmaEncodingTrait>(tensorType.getEncoding()) &&
-          v.getDefiningOp() &&
-          !hasConvertToMMATransisitiveUse(v.getDefiningOp(),
-                                          tensorType.getEncoding())) {
-        return;
-      }
       layouts.insert({v, LayoutInfo(tensorType.getEncoding())});
     }
   };
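
For readability, here is the whole function as it reads after this change (reassembled from the visible context plus the call-site renames in the next hunk; the "// ..." marks comment lines elided between hunks in this view):

void LayoutPropagation::initAnchorLayout() {
  auto addAnchor = [&](Value v) {
    if (auto tensorType = dyn_cast<RankedTensorType>(v.getType())) {
      layouts.insert({v, LayoutInfo(tensorType.getEncoding())});
    }
  };

  // ...
  // you can pass a tensor with an encoding as an arg, instead of explicitly
  // calling tt.load.
  for (auto arg : funcOp.getArguments()) {
    addAnchor(arg);
  }

  funcOp.walk([&](Operation *op) {
    if (isLayoutAnchor(op)) {
      for (auto result : op->getResults()) {
        addAnchor(result);
      }
    }
  });
}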
@@ -282,13 +193,13 @@ void LayoutPropagation::initAnchorLayout() {
   // you can pass a tensor with an encoding as an arg, instead of explicitly
   // calling tt.load.
   for (auto arg : funcOp.getArguments()) {
-    maybeAddAnchor(arg);
+    addAnchor(arg);
   }
 
   funcOp.walk([&](Operation *op) {
     if (isLayoutAnchor(op)) {
       for (auto result : op->getResults()) {
-        maybeAddAnchor(result);
+        addAnchor(result);
       }
     }
   });
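
Net effect of the commit: MMA-encoded results of anchor ops are now always seeded as anchors; the lookahead that skipped them when no downstream convert back to an MMA layout existed is gone, along with its known-bug workaround for small ReduceOp shapes. To exercise the pass in isolation, it can be added to a pipeline. A sketch, assuming the factory spelling createTritonGPURemoveLayoutConversionsPass and the include path below (both vary across Triton versions; check TritonGPU's Passes.h for the generated names):

#include "mlir/Pass/PassManager.h"
// Assumed include path for Triton's TritonGPU pass declarations.
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

void addLayoutCleanup(mlir::PassManager &pm) {
  // Assumed factory name; the registered mnemonic is
  // tritongpu-remove-layout-conversions.
  pm.addPass(mlir::triton::gpu::createTritonGPURemoveLayoutConversionsPass());
}

From the command line, the same pass can typically be run with triton-opt via -tritongpu-remove-layout-conversions.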