Skip to content

Commit db7a20d

Browse files
[Intel] Fix build failures from 63cecbd
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 8dc9e23 commit db7a20d

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

third_party/intel/lib/Analysis/DPAS.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ DPASAnalysis::getDPASType(OpTy op) {
139139
}
140140

141141
if constexpr (std::is_same_v<OpTy, DotScaledOp>) {
142-
aTy = cast<RankedTensorType>(op.getLhs().getType());
143-
bTy = cast<RankedTensorType>(op.getRhs().getType());
142+
aTy = cast<RankedTensorType>(op.getA().getType());
143+
bTy = cast<RankedTensorType>(op.getB().getType());
144144
aElemTy = aTy.getElementType();
145145
bElemTy = bTy.getElementType();
146146

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -214,16 +214,16 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
214214
if (rank == 3)
215215
return rewriter.notifyMatchFailure(scaledDotOp, "NYI: 3d case");
216216

217-
TensorValue a = scaledDotOp.getLhs();
218-
TensorValue b = scaledDotOp.getRhs();
219-
TensorValue aScale = scaledDotOp.getLhsScale();
220-
TensorValue bScale = scaledDotOp.getRhsScale();
217+
TensorValue a = scaledDotOp.getA();
218+
TensorValue b = scaledDotOp.getB();
219+
TensorValue aScale = scaledDotOp.getAScale();
220+
TensorValue bScale = scaledDotOp.getBScale();
221221
if (aScale && bScale)
222222
return rewriter.notifyMatchFailure(scaledDotOp,
223223
"NYI: both LHS and RHS scale");
224224

225-
tt::ScaleDotElemType aElemType = scaledDotOp.getLhsType();
226-
tt::ScaleDotElemType bElemType = scaledDotOp.getRhsType();
225+
tt::ScaleDotElemType aElemType = scaledDotOp.getAElemType();
226+
tt::ScaleDotElemType bElemType = scaledDotOp.getBElemType();
227227
auto supportsTypes = [](tt::ScaleDotElemType elemType) {
228228
return elemType == tt::ScaleDotElemType::E2M1 ||
229229
elemType == tt::ScaleDotElemType::E4M3 ||
@@ -363,10 +363,10 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
363363
getDPASEncoding(tt::DotScaledOp scaledDotOp,
364364
PatternRewriter &rewriter) const {
365365
auto mod = scaledDotOp->getParentOfType<ModuleOp>();
366-
TensorValue a = scaledDotOp.getLhs();
367-
TensorValue b = scaledDotOp.getRhs();
368-
TensorValue aScale = scaledDotOp.getLhsScale();
369-
TensorValue bScale = scaledDotOp.getRhsScale();
366+
TensorValue a = scaledDotOp.getA();
367+
TensorValue b = scaledDotOp.getB();
368+
TensorValue aScale = scaledDotOp.getAScale();
369+
TensorValue bScale = scaledDotOp.getBScale();
370370
assert((!aScale || !bScale) && "NYI: both LHS and RHS scale");
371371

372372
Type elemType =
@@ -613,22 +613,22 @@ static void sinkTransposeOp(tt::TransOp input) {
613613
}
614614

615615
static tt::TransOp transposeDotOp(tt::DotScaledOp dotOp) {
616-
assert(dotOp.getLhsScale() == nullptr && dotOp.getRhsScale() != nullptr &&
616+
assert(dotOp.getAScale() == nullptr && dotOp.getBScale() != nullptr &&
617617
"Transpose DotOp expects scale on RHS");
618618
OpBuilder builder(dotOp);
619-
Value lhs = dotOp.getLhs();
619+
Value lhs = dotOp.getA();
620620
std::array<int, 2> transOrder = {1, 0};
621621
auto lhsTransposed =
622622
builder.create<tt::TransOp>(lhs.getLoc(), lhs, transOrder);
623-
Value rhs = dotOp.getRhs();
623+
Value rhs = dotOp.getB();
624624
auto rhsTransposed =
625625
builder.create<tt::TransOp>(rhs.getLoc(), rhs, transOrder);
626626
Value c = dotOp.getC();
627627
auto cTransposed = builder.create<tt::TransOp>(c.getLoc(), c, transOrder);
628628
auto result = builder.create<tt::DotScaledOp>(
629629
dotOp.getLoc(), cTransposed.getType(), rhsTransposed, lhsTransposed,
630-
cTransposed, dotOp.getRhsScale(), dotOp.getLhsScale(), dotOp.getRhsType(),
631-
dotOp.getLhsType(), dotOp.getFastMath());
630+
cTransposed, dotOp.getBScale(), dotOp.getAScale(), dotOp.getBElemType(),
631+
dotOp.getAElemType(), dotOp.getFastMath());
632632
auto transOp =
633633
builder.create<tt::TransOp>(result.getLoc(), result, transOrder);
634634
dotOp.replaceAllUsesWith(transOp.getOperation());
@@ -639,7 +639,7 @@ static tt::TransOp transposeDotOp(tt::DotScaledOp dotOp) {
639639
static void transposeDots(ModuleOp m) {
640640
SmallVector<tt::DotScaledOp> toTranspose;
641641
m.walk([&](tt::DotScaledOp dotOp) -> void {
642-
if (dotOp.getLhsScale() == nullptr && dotOp.getRhsScale() != nullptr)
642+
if (dotOp.getAScale() == nullptr && dotOp.getBScale() != nullptr)
643643
toTranspose.push_back(dotOp);
644644
});
645645
SmallVector<tt::TransOp> transposes;

0 commit comments

Comments
 (0)