Skip to content

Commit f8dc27b

Browse files
committed
Workaround
1 parent ccba545 commit f8dc27b

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

lib/AnalysisStructured/PtrAnalysis.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,8 @@ LogicalResult PtrState::mulState(const PtrState &lhsState,
231231
}
232232

233233
if (lhsState.scalar && rhsState.scalar) {
234-
scalar = builder.create<arith::MulIOp>(
235-
loc, lhsState.scalar, rhsState.scalar);
234+
scalar =
235+
builder.create<arith::MulIOp>(loc, lhsState.scalar, rhsState.scalar);
236236
}
237237

238238
for (uint64_t i = 0; i < lhs->sizes.size(); i++) {
@@ -687,9 +687,15 @@ LogicalResult PtrAnalysis::visitOperand(Value operand, PtrState &state,
687687
} else if (auto makeTensorOp = dyn_cast<triton::MakeTensorPtrOp>(op)) {
688688
llvm_unreachable("Unexpected operand defining operation tts.make_tptr");
689689
} else {
690+
op->dump();
691+
return failure();
690692
llvm_unreachable("Unexpected operand defining operation");
691693
}
692694
} else {
695+
OpBuilder::InsertionGuard guard(builder);
696+
builder.setInsertionPointToStart(operand.getParentBlock());
697+
state.scalar =
698+
builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(0));
693699
state.source = operand;
694700
return success();
695701
}

lib/Conversion/TritonToUnstructured/TritonToUnstructuredPass.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,18 @@ class TritonToUnstructuredPass
269269
}
270270
});
271271

272+
getOperation().walk([&](triton::IntToPtrOp op) {
273+
auto res = op.getResult();
274+
OpBuilder b(op);
275+
Value zero = b.create<arith::ConstantOp>(
276+
op.getLoc(),
277+
b.getIntegerAttr(IntegerType::get(&getContext(), defaultBitWidth),
278+
0));
279+
280+
offsetMap.insert({res, {res, res.getType(), defaultBitWidth, zero}});
281+
workList.push(res);
282+
});
283+
272284
llvm::SmallVector<Operation *> toDelete;
273285
llvm::SmallVector<Operation *> ptrUsers;
274286

@@ -350,6 +362,37 @@ class TritonToUnstructuredPass
350362

351363
return success();
352364
})
365+
.Case<triton::ExpandDimsOp>([&](ExpandDimsOp op) {
366+
auto res = op->getResult(0);
367+
auto resType = res.getType();
368+
369+
if (!isPtrTypeLike(resType)) {
370+
return success();
371+
}
372+
373+
auto ptr = op->getOperand(0);
374+
auto offsetInfo = offsetMap.at(ptr);
375+
376+
OpBuilder b{op};
377+
auto clone =
378+
b.create(op->getLoc(), op->getName().getIdentifier(),
379+
ValueRange{offsetInfo.offset},
380+
TypeRange{getPtrOffsetType(
381+
resType, offsetInfo.bitWidth)});
382+
383+
PtrOffset newOffsetInfo{offsetInfo.ptr, resType,
384+
offsetInfo.bitWidth,
385+
clone->getResult(0)};
386+
387+
offsetMap.insert({
388+
res,
389+
newOffsetInfo,
390+
});
391+
workList.push(res);
392+
toDelete.push_back(op);
393+
394+
return success();
395+
})
353396
.Case<triton::LoadOp, triton::StoreOp, triton::MakeTensorPtrOp,
354397
tts::MakeTensorPtrOp>([&](Operation *op) {
355398
// Special case:

0 commit comments

Comments
 (0)