@@ -214,16 +214,16 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
214214 if (rank == 3 )
215215 return rewriter.notifyMatchFailure (scaledDotOp, " NYI: 3d case" );
216216
217- TensorValue a = scaledDotOp.getLhs ();
218- TensorValue b = scaledDotOp.getRhs ();
219- TensorValue aScale = scaledDotOp.getLhsScale ();
220- TensorValue bScale = scaledDotOp.getRhsScale ();
217+ TensorValue a = scaledDotOp.getA ();
218+ TensorValue b = scaledDotOp.getB ();
219+ TensorValue aScale = scaledDotOp.getAScale ();
220+ TensorValue bScale = scaledDotOp.getBScale ();
221221 if (aScale && bScale)
222222 return rewriter.notifyMatchFailure (scaledDotOp,
223223 " NYI: both LHS and RHS scale" );
224224
225- tt::ScaleDotElemType aElemType = scaledDotOp.getLhsType ();
226- tt::ScaleDotElemType bElemType = scaledDotOp.getRhsType ();
225+ tt::ScaleDotElemType aElemType = scaledDotOp.getAElemType ();
226+ tt::ScaleDotElemType bElemType = scaledDotOp.getBElemType ();
227227 auto supportsTypes = [](tt::ScaleDotElemType elemType) {
228228 return elemType == tt::ScaleDotElemType::E2M1 ||
229229 elemType == tt::ScaleDotElemType::E4M3 ||
@@ -363,10 +363,10 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
363363 getDPASEncoding (tt::DotScaledOp scaledDotOp,
364364 PatternRewriter &rewriter) const {
365365 auto mod = scaledDotOp->getParentOfType <ModuleOp>();
366- TensorValue a = scaledDotOp.getLhs ();
367- TensorValue b = scaledDotOp.getRhs ();
368- TensorValue aScale = scaledDotOp.getLhsScale ();
369- TensorValue bScale = scaledDotOp.getRhsScale ();
366+ TensorValue a = scaledDotOp.getA ();
367+ TensorValue b = scaledDotOp.getB ();
368+ TensorValue aScale = scaledDotOp.getAScale ();
369+ TensorValue bScale = scaledDotOp.getBScale ();
370370 assert ((!aScale || !bScale) && " NYI: both LHS and RHS scale" );
371371
372372 Type elemType =
@@ -613,22 +613,22 @@ static void sinkTransposeOp(tt::TransOp input) {
613613}
614614
615615static tt::TransOp transposeDotOp (tt::DotScaledOp dotOp) {
616- assert (dotOp.getLhsScale () == nullptr && dotOp.getRhsScale () != nullptr &&
616+ assert (dotOp.getAScale () == nullptr && dotOp.getBScale () != nullptr &&
617617 " Transpose DotOp expects scale on RHS" );
618618 OpBuilder builder (dotOp);
619- Value lhs = dotOp.getLhs ();
619+ Value lhs = dotOp.getA ();
620620 std::array<int , 2 > transOrder = {1 , 0 };
621621 auto lhsTransposed =
622622 builder.create <tt::TransOp>(lhs.getLoc (), lhs, transOrder);
623- Value rhs = dotOp.getRhs ();
623+ Value rhs = dotOp.getB ();
624624 auto rhsTransposed =
625625 builder.create <tt::TransOp>(rhs.getLoc (), rhs, transOrder);
626626 Value c = dotOp.getC ();
627627 auto cTransposed = builder.create <tt::TransOp>(c.getLoc (), c, transOrder);
628628 auto result = builder.create <tt::DotScaledOp>(
629629 dotOp.getLoc (), cTransposed.getType (), rhsTransposed, lhsTransposed,
630- cTransposed, dotOp.getRhsScale (), dotOp.getLhsScale (), dotOp.getRhsType (),
631- dotOp.getLhsType (), dotOp.getFastMath ());
630+ cTransposed, dotOp.getBScale (), dotOp.getAScale (), dotOp.getBElemType (),
631+ dotOp.getAElemType (), dotOp.getFastMath ());
632632 auto transOp =
633633 builder.create <tt::TransOp>(result.getLoc (), result, transOrder);
634634 dotOp.replaceAllUsesWith (transOp.getOperation ());
@@ -639,7 +639,7 @@ static tt::TransOp transposeDotOp(tt::DotScaledOp dotOp) {
639639static void transposeDots (ModuleOp m) {
640640 SmallVector<tt::DotScaledOp> toTranspose;
641641 m.walk ([&](tt::DotScaledOp dotOp) -> void {
642- if (dotOp.getLhsScale () == nullptr && dotOp.getRhsScale () != nullptr )
642+ if (dotOp.getAScale () == nullptr && dotOp.getBScale () != nullptr )
643643 toTranspose.push_back (dotOp);
644644 });
645645 SmallVector<tt::TransOp> transposes;
0 commit comments