@@ -671,7 +671,7 @@ void getTransitiveUsers(Value root,
671671void collectAsyncChannels (SmallVector<std::unique_ptr<Channel>> &channels,
672672 triton::FuncOp &funcOp, unsigned numBuffers) {
673673 funcOp.walk ([&](Operation *op) {
674- if (isa<tt::LoadOp, tt::ExperimentalDescriptorLoadOp >(op) ||
674+ if (isa<tt::LoadOp, tt::DescriptorLoadOp >(op) ||
675675 isa<mlir::triton::DotOpInterface>(op)) {
676676 auto producerTaskIds = getAsyncTaskIds (op);
677677 if (producerTaskIds.empty () || producerTaskIds.size () > 1 ) {
@@ -1611,7 +1611,7 @@ DenseMap<Channel *, DenseMap<int, Value>> createToken(
16111611 auto copyOp = copyOpMap.find (channel)->second .first ;
16121612 if (isa<ttg::AsyncCopyGlobalToLocalOp>(copyOp)) {
16131613 tokenLoadType = ttng::TokenLoadType::AsyncLoadOp;
1614- } else if (isa<ExperimentalDescriptorLoadOp >(copyOp)) {
1614+ } else if (isa<DescriptorLoadOp >(copyOp)) {
16151615 tokenLoadType = ttng::TokenLoadType::TMALoadOp;
16161616 } else if (isa<LocalStoreOp>(copyOp)) {
16171617 tokenLoadType = ttng::TokenLoadType::LocalStoreOp;
@@ -1636,7 +1636,7 @@ DenseMap<Channel *, DenseMap<int, Value>> createToken(
16361636 }
16371637
16381638 auto producerOp = it->second .front ()->getSrcOp ();
1639- if (isa<tt::ExperimentalDescriptorLoadOp >(producerOp)) {
1639+ if (isa<tt::DescriptorLoadOp >(producerOp)) {
16401640 Value bAlloc = createBarrierAlloc (funcOp, channel->numBuffers );
16411641 // Channels in the group share the same set of tokens.
16421642 for (auto &c : it->second ) {
@@ -1863,7 +1863,7 @@ createLocalCopy(const DenseMap<Channel *, Value> &bufferMap, Channel *channel,
18631863 return {copy, sharedLoad};
18641864}
18651865
1866- static int getTMALoadSize (tt::ExperimentalDescriptorLoadOp &tmaLoad) {
1866+ static int getTMALoadSize (tt::DescriptorLoadOp &tmaLoad) {
18671867 auto tensorTy = cast<RankedTensorType>(tmaLoad->getResult (0 ).getType ());
18681868 int loadSize = product (tensorTy.getShape ());
18691869 return loadSize * tensorTy.getElementType ().getIntOrFloatBitWidth () / 8 ;
@@ -1921,7 +1921,7 @@ Value getBufferForPipelineStage(OpBuilderWithAsyncTaskIds &builder,
19211921
19221922Operation *
19231923optimizeTMALoads (OpBuilderWithAsyncTaskIds &builder,
1924- SmallVector<tt::ExperimentalDescriptorLoadOp > &tmaLoads,
1924+ SmallVector<tt::DescriptorLoadOp > &tmaLoads,
19251925 SmallVector<Value> &buffers, Value barrierAlloc,
19261926 Value bufferIdx, Value bufferIdxExtract, Value phase,
19271927 Operation *headProducer, Operation *headConsumer) {
@@ -2168,7 +2168,7 @@ void insertAsyncComm(
21682168
21692169 // Insert ProducerCommitOp if producer is LoadOp. For TMA, TMA lowering
21702170 // will handle the ProducerCommit.
2171- if (!isa<tt::ExperimentalDescriptorLoadOp >(headProducer)) {
2171+ if (!isa<tt::DescriptorLoadOp >(headProducer)) {
21722172 builder.setInsertionPointAfter (tailProducer);
21732173 builder.createWithAsyncTaskIds <ttng::ProducerCommitOp>(
21742174 tailProducer->getLoc (), token.second , bufferIdx);
@@ -2178,7 +2178,7 @@ void insertAsyncComm(
21782178 for (auto token : tokens) {
21792179 builder.setAsynTaskIdsFromArray (token.first );
21802180 // Insert ConsumerWaitOp
2181- if (!isa<tt::ExperimentalDescriptorLoadOp >(headProducer)) {
2181+ if (!isa<tt::DescriptorLoadOp >(headProducer)) {
21822182 auto consumerWaitPoint = getSameLevelOp (headProducer, headConsumer);
21832183 builder.setInsertionPoint (consumerWaitPoint);
21842184 builder.createWithAsyncTaskIds <ttng::ConsumerWaitOp>(
@@ -2193,13 +2193,13 @@ void insertAsyncComm(
21932193 consumerReleasePoint->getLoc (), token.second , bufferIdx);
21942194 }
21952195
2196- SmallVector<tt::ExperimentalDescriptorLoadOp > tmaLoads;
2196+ SmallVector<tt::DescriptorLoadOp > tmaLoads;
21972197 SmallVector<Value> buffers;
21982198 DenseMap<Operation *, Operation *> producerCopyMap;
21992199 // Go through all channels in this channel group.
22002200 for (auto &c : kv.second ) {
22012201 if (auto tmaLoad =
2202- dyn_cast<tt::ExperimentalDescriptorLoadOp >(c->getSrcOp ())) {
2202+ dyn_cast<tt::DescriptorLoadOp >(c->getSrcOp ())) {
22032203 tmaLoads.push_back (tmaLoad);
22042204 buffers.push_back (bufferMap.find (c)->second );
22052205 }
@@ -2278,7 +2278,7 @@ void insertAsyncCopy(
22782278
22792279 // No need to create async copy for TMA load which will be handled in
22802280 // insertAsyncComm.
2281- if (isa<tt::ExperimentalDescriptorLoadOp >(srcOp)) {
2281+ if (isa<tt::DescriptorLoadOp >(srcOp)) {
22822282 producerConsumerOps = {srcOp, domininatingChannel->getDstOp ()};
22832283 } else if (isa<triton::LoadOp>(srcOp)) {
22842284 SmallVector<AsyncTaskId> asyncTasksPC = getAsyncTaskIds (srcOp);
0 commit comments