@@ -656,110 +656,8 @@ class DecomposeScaledBlocked
   }
 };
 
-static void updateValueType(Value v, Attribute encoding,
-                            ArrayRef<int64_t> shape) {
-  auto tensorType = cast<RankedTensorType>(v.getType());
-  auto newType =
-      RankedTensorType::get(shape, tensorType.getElementType(), encoding);
-  v.setType(newType);
-}
-
-static TransOp updateUsers(Value result, const SetVector<Operation *> &slice) {
-  TransOp transOp;
-  if (llvm::any_of(result.getUsers(),
-                   [&](Operation *user) { return slice.count(user) == 0; })) {
-    OpBuilder builder(result.getContext());
-    builder.setInsertionPointAfterValue(result);
-    transOp =
-        builder.create<TransOp>(result.getLoc(), result, ArrayRef({1, 0}));
-    result.replaceUsesWithIf(transOp.getResult(), [&](OpOperand &operand) {
-      return operand.getOwner() != transOp.getOperation() &&
-             slice.count(operand.getOwner()) == 0;
-    });
-  }
-  return transOp;
-}
-
-// Sink the transpose in the IR. This is done to avoid generating a convert
-// layout when there is a transpose right after a dot, as the mma layout
-// cannot be propagated through a transpose op. Once we have layouts that can
-// represent a transposed MMA we can remove this transformation.
-static void sinkTransposeOp(TransOp input) {
-  SmallVector<TransOp> queue = {input};
-  while (!queue.empty()) {
-    TransOp transOp = queue.back();
-    Value currentValue = transOp.getResult();
-    queue.pop_back();
-    mlir::ForwardSliceOptions options;
-    options.filter = [](Operation *op) {
-      if (op->hasTrait<OpTrait::Elementwise>() && op->getNumOperands() == 1)
-        return true;
-      if (isa<scf::YieldOp>(op))
-        return isa<scf::ForOp>(op->getParentOp());
-      if (isa<ConvertLayoutOp>(op))
-        return true;
-      return false;
-    };
-    SetVector<Operation *> slice;
-    mlir::getForwardSlice(currentValue, &slice, options);
-    for (Operation *op : slice) {
-      if (op->hasTrait<OpTrait::Elementwise>()) {
-        // Update users of transpose op.
-        if (op->getOperand(0) == transOp.getResult())
-          op->setOperand(0, transOp.getOperand());
-        // Update the type of the result.
-        for (Value result : op->getResults()) {
-          auto srcType = cast<RankedTensorType>(op->getOperand(0).getType());
-          updateValueType(result, srcType.getEncoding(), srcType.getShape());
-          updateUsers(result, slice);
-        }
-        continue;
-      }
-      if (auto cvtOp = dyn_cast<ConvertLayoutOp>(op)) {
-        // Update users of transpose op.
-        if (op->getOperand(0) == transOp.getResult())
-          op->setOperand(0, transOp.getOperand());
-        auto resultEncoding = cvtOp.getType().getEncoding();
-        auto newDstEncoding = inferSrcEncoding(transOp, resultEncoding);
-        assert(newDstEncoding);
-        auto srcType = cast<RankedTensorType>(cvtOp.getOperand().getType());
-        updateValueType(cvtOp.getResult(), newDstEncoding, srcType.getShape());
-        updateUsers(cvtOp.getResult(), slice);
-        continue;
-      }
-      assert(isa<scf::YieldOp>(op));
-      auto forOp = dyn_cast<scf::ForOp>(op->getParentOp());
-      assert(forOp);
-      for (OpOperand &operand : op->getOpOperands()) {
-        Operation *def = operand.get().getDefiningOp();
-        if (def && (slice.count(def) || def == transOp.getOperation())) {
-          if (def == transOp.getOperation())
-            operand.set(transOp.getOperand());
-          Type newType = operand.get().getType();
-          forOp.getResult(operand.getOperandNumber()).setType(newType);
-          TransOp retTrans =
-              updateUsers(forOp.getResult(operand.getOperandNumber()), slice);
-          // Recursively try to propagate the new transpose inserted.
-          if (retTrans)
-            queue.push_back(retTrans);
-          forOp.getRegionIterArg(operand.getOperandNumber()).setType(newType);
-          TransOp argTrans = updateUsers(
-              forOp.getRegionIterArg(operand.getOperandNumber()), slice);
-          if (argTrans)
-            queue.push_back(argTrans);
-          OpBuilder builder(forOp);
-          OpOperand &init = forOp.getInitsMutable()[operand.getOperandNumber()];
-          Value initTranspose = builder.create<TransOp>(
-              forOp.getLoc(), init.get(), ArrayRef({1, 0}));
-          init.set(initTranspose);
-        }
-      }
-    }
-  }
-}
-
 // Transpose scaled_dot ops that have a scale on rhs only, moving it to lhs.
-static Operation *transposeDotOp(DotScaledOp dotOp) {
+static void transposeDotOp(DotScaledOp dotOp) {
   OpBuilder builder(dotOp);
   Value lhs = dotOp.getLhs();
   std::array<int, 2> transOrder = {1, 0};
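
The removed `sinkTransposeOp` above relies on the fact that a unary elementwise op commutes with transpose, i.e. `f(transpose(A)) == transpose(f(A))`, so a transpose feeding such an op can be rewired past it and re-applied to the result while only the tensor types change. A minimal standalone C++ sketch of that invariant (illustrative only, not part of this patch; `Mat2`, `transposed`, and `elementwise` are hypothetical helpers):

```cpp
// Sketch: unary elementwise ops commute with transpose, the invariant
// that lets sinkTransposeOp move a TransOp past elementwise users.
#include <array>
#include <cassert>

using Mat2 = std::array<std::array<double, 2>, 2>;

static Mat2 transposed(const Mat2 &a) {
  return {{{a[0][0], a[1][0]}, {a[0][1], a[1][1]}}};
}

static Mat2 elementwise(const Mat2 &a, double (*f)(double)) {
  Mat2 r{};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      r[i][j] = f(a[i][j]);
  return r;
}

int main() {
  Mat2 a = {{{1.0, 2.0}, {3.0, 4.0}}};
  auto square = [](double x) { return x * x; };
  // f(A^T) == f(A)^T for any elementwise f.
  assert(elementwise(transposed(a), square) ==
         transposed(elementwise(a, square)));
  return 0;
}
```
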
@@ -776,7 +674,6 @@ static Operation *transposeDotOp(DotScaledOp dotOp) {
       builder.create<TransOp>(result.getLoc(), result, transOrder);
   dotOp.replaceAllUsesWith(transposedResult);
   dotOp.erase();
-  return transposedResult;
 }
 
 static void transposeDots(ModuleOp m) {
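
`transposeDotOp` itself is justified by the identity `(A*B)^T == B^T * A^T`: computing the dot on the swapped, transposed operands (with their scales swapped accordingly) and then transposing the result reproduces the original product. A small self-contained C++ check of that identity (illustrative only, not part of this patch; `Mat2`, `transposed`, and `matmul` are hypothetical helpers):

```cpp
// Sketch: transpose(dot(B^T, A^T)) == dot(A, B), the algebraic identity
// behind rewriting a scaled dot so its scale ends up on the lhs.
#include <array>
#include <cassert>

using Mat2 = std::array<std::array<int, 2>, 2>;

static Mat2 transposed(const Mat2 &a) {
  return {{{a[0][0], a[1][0]}, {a[0][1], a[1][1]}}};
}

static Mat2 matmul(const Mat2 &a, const Mat2 &b) {
  Mat2 r{};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 2; ++k)
        r[i][j] += a[i][k] * b[k][j];
  return r;
}

int main() {
  Mat2 a = {{{1, 2}, {3, 4}}};
  Mat2 b = {{{5, 6}, {7, 8}}};
  assert(transposed(matmul(transposed(b), transposed(a))) == matmul(a, b));
  return 0;
}
```
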
@@ -787,14 +684,8 @@ static void transposeDots(ModuleOp m) {
     if (dotOp.getLhsScale() == nullptr && dotOp.getRhsScale() != nullptr)
       toTranspose.push_back(dotOp);
   });
-  SmallVector<Operation *> transposes;
   for (DotScaledOp dotOp : toTranspose) {
-    Operation *transpose = transposeDotOp(dotOp);
-    transposes.push_back(transpose);
-  }
-
-  for (Operation *transpose : transposes) {
-    sinkTransposeOp(cast<TransOp>(transpose));
+    transposeDotOp(dotOp);
   }
 }
 
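
One structural note: `transposeDots` collects the matching ops during the walk and only rewrites them afterwards, because `transposeDotOp` erases the op and mutating the IR mid-traversal would invalidate the walk. A tiny sketch of the same collect-then-mutate pattern (illustrative only; `Node`, `rewriteNode`, and `runPass` are hypothetical stand-ins):

```cpp
// Sketch: collect matches first, then mutate from a stable worklist.
#include <vector>

struct Node {
  bool matches = false;
};

static void rewriteNode(Node &n) { n.matches = false; }

static void runPass(std::vector<Node> &ir) {
  // Phase 1: traverse and collect without mutating.
  std::vector<Node *> worklist;
  for (Node &n : ir)
    if (n.matches)
      worklist.push_back(&n);
  // Phase 2: mutate from the stable worklist.
  for (Node *n : worklist)
    rewriteNode(*n);
}

int main() {
  std::vector<Node> ir = {{true}, {false}, {true}};
  runPass(ir);
  return 0;
}
```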