@@ -49,19 +49,37 @@ SmallVector<ProducedValueInfo> getProducedValues(Operation *op, Block *loopBody,
49
49
return producedValues;
50
50
};
51
51
52
+ template <typename AllocOp, typename LoadOp>
53
+ std::optional<std::pair<AllocOp, LoadOp>> isLoadAndAlloc (Value result) {
54
+ auto alloc = result.getDefiningOp <AllocOp>();
55
+ if (!alloc)
56
+ return std::nullopt;
57
+ if (auto load = alloc.getSrc ().template getDefiningOp <LoadOp>()) {
58
+ return std::make_pair (alloc, load);
59
+ }
60
+ return std::nullopt;
61
+ }
62
+
63
+ // if result is defined by descriptor_load followed by alloc, return the alloc
64
+ // and the load ops as a pair.
65
+ template <typename AllocOp> auto isDescLoadAndAlloc (Value result) {
66
+ return isLoadAndAlloc<AllocOp, triton::DescriptorOpInterface>(result);
67
+ }
68
+
69
+ template <typename AllocOp> auto isGlobalLoadAndAlloc (Value result) {
70
+ return isLoadAndAlloc<AllocOp, triton::LoadOp>(result);
71
+ }
72
+
52
73
ArefCreateOp createAref (OpBuilder &builder, ProducedValueInfo &producedValue) {
53
74
auto result = producedValue.result ;
54
- MemDescType arefBufType;
55
75
56
- if (auto memDescType = dyn_cast<MemDescType>(result.getType ())) {
57
- arefBufType = getMultiBufferedType (memDescType, 1 );
58
- } else if (auto tensorType = dyn_cast<RankedTensorType>(result.getType ())) {
59
- // if result is a value, create memdesctype for location where value will
60
- // be stored
76
+ auto getSmemDescType = [](Value tensorResult) {
77
+ auto tensorType = cast<RankedTensorType>(tensorResult.getType ());
61
78
MemDescType memDescType;
62
79
Attribute SharedMemorySpace =
63
80
SharedMemorySpaceAttr::get (tensorType.getContext ());
64
- if (auto load = result.getDefiningOp <triton::DescriptorOpInterface>()) {
81
+ if (auto load =
82
+ tensorResult.getDefiningOp <triton::DescriptorOpInterface>()) {
65
83
// A use of TMA which is not immediately consumed by LocalAlloc
66
84
// This case applies, for example, when TMA is followed by SIMT ops
67
85
// or MMAv2 is used.
@@ -73,15 +91,25 @@ ArefCreateOp createAref(OpBuilder &builder, ProducedValueInfo &producedValue) {
73
91
} else {
74
92
llvm_unreachable (" Only TMA is expected for now." );
75
93
}
76
- arefBufType = getMultiBufferedType (memDescType, 1 );
94
+ return memDescType;
95
+ };
96
+
97
+ MemDescType memDescType;
98
+ if (isDescLoadAndAlloc<LocalAllocOp>(result)) {
99
+ memDescType = dyn_cast<MemDescType>(result.getType ());
100
+ } else if (auto opt = isDescLoadAndAlloc<TMEMAllocOp>(result)) {
101
+ auto descLoadResult = opt->first .getSrc ();
102
+ memDescType = getSmemDescType (descLoadResult);
103
+ } else if (isa<RankedTensorType>(result.getType ())) {
104
+ memDescType = getSmemDescType (result);
77
105
} else {
78
- std::string msg = " unsupported produced value type: " +
106
+ std::string msg = " createAref: unsupported produced value type: " +
79
107
mlir::debugString (result.getType ());
80
108
llvm::report_fatal_error (msg.c_str ());
81
109
}
82
110
83
- assert ( arefBufType &&
84
- (isa<SharedMemorySpaceAttr>(arefBufType.getMemorySpace () )));
111
+ MemDescType arefBufType = getMultiBufferedType (memDescType, 1 );
112
+ assert (isa<SharedMemorySpaceAttr>(arefBufType.getMemorySpace ()));
85
113
auto loc = result.getLoc ();
86
114
auto alloc = triton::nvws::createAlloc (builder, loc, arefBufType, Value ());
87
115
return createArefCreateOp (builder, {arefBufType}, {alloc->getResult (0 )}, loc);
@@ -127,26 +155,15 @@ void createNVWSDescriptorLoadOp(OpBuilder &builder, Operation *ttDescLoadOp,
127
155
}
128
156
}
129
157
130
- bool isDescLoadAndAlloc (Value result) {
131
- auto alloc = result.getDefiningOp <LocalAllocOp>();
132
- if (!alloc)
133
- return false ;
134
- return alloc.getSrc ().getDefiningOp <triton::DescriptorOpInterface>();
135
- }
136
-
137
- bool isGlobalLoadAndAlloc (Value result) {
138
- auto alloc = result.getDefiningOp <LocalAllocOp>();
139
- if (!alloc)
140
- return false ;
141
- return alloc.getSrc ().getDefiningOp <triton::LoadOp>();
142
- }
143
-
144
158
StageCluster getStageClusterForProducer (Value producedValue) {
145
- if (isDescLoadAndAlloc (producedValue) ||
146
- isGlobalLoadAndAlloc (producedValue)) {
147
- auto alloc = producedValue.getDefiningOp <LocalAllocOp>();
148
- auto loadOp = alloc.getSrc ().getDefiningOp ();
149
- return getStageCluster (loadOp);
159
+ if (auto opt = isDescLoadAndAlloc<LocalAllocOp>(producedValue)) {
160
+ return getStageCluster (opt->second );
161
+ } else if (auto opt = isDescLoadAndAlloc<TMEMAllocOp>(producedValue)) {
162
+ return getStageCluster (opt->second );
163
+ } else if (auto opt = isGlobalLoadAndAlloc<LocalAllocOp>(producedValue)) {
164
+ return getStageCluster (opt->second );
165
+ } else if (auto opt = isGlobalLoadAndAlloc<TMEMAllocOp>(producedValue)) {
166
+ return getStageCluster (opt->second );
150
167
}
151
168
return getStageCluster (producedValue.getDefiningOp ());
152
169
}
@@ -173,15 +190,21 @@ SmallVector<Operation *> createArefPut(PartitionBuilder &builder,
173
190
174
191
auto producerKind = AsyncOp::NONE;
175
192
SmallVector<Operation *> staleOps;
176
- if (isDescLoadAndAlloc (result)) {
177
- auto alloc = result.getDefiningOp <LocalAllocOp>();
178
- auto descOp = alloc.getSrc ().getDefiningOp ();
193
+ if (auto opt = isDescLoadAndAlloc<LocalAllocOp>(result)) {
194
+ auto [alloc, descOp] = *opt;
179
195
createNVWSDescriptorLoadOp (builder, descOp, dataBuf, producerPartition,
180
196
schedule, loc);
181
197
producerKind = AsyncOp::TMALoad;
182
198
staleOps.push_back (alloc);
183
199
staleOps.push_back (descOp);
184
- } else if (isGlobalLoadAndAlloc (result)) {
200
+ } else if (auto opt = isDescLoadAndAlloc<TMEMAllocOp>(result)) {
201
+ auto descOp = opt->second ;
202
+ createNVWSDescriptorLoadOp (builder, descOp, dataBuf, producerPartition,
203
+ schedule, loc);
204
+ producerKind = AsyncOp::TMALoad;
205
+ staleOps.push_back (descOp);
206
+ } else if (isGlobalLoadAndAlloc<LocalAllocOp>(result) ||
207
+ isGlobalLoadAndAlloc<TMEMAllocOp>(result)) {
185
208
llvm_unreachable (" cpasync not supported yet" );
186
209
} else if (auto tensorType = dyn_cast<RankedTensorType>(result.getType ())) {
187
210
if (auto descOp = result.getDefiningOp <triton::DescriptorOpInterface>()) {
@@ -197,7 +220,7 @@ SmallVector<Operation *> createArefPut(PartitionBuilder &builder,
197
220
llvm_unreachable (" Aref for values not supported yet" );
198
221
}
199
222
} else {
200
- std::string msg = " unsupported produced value type: " +
223
+ std::string msg = " createArefPut: unsupported produced value type: " +
201
224
mlir::debugString (result.getType ());
202
225
llvm::report_fatal_error (msg.c_str ());
203
226
}
@@ -327,26 +350,34 @@ void createArefGet(PartitionBuilder &builder, scf::ForOp loop,
327
350
Value token = getEnterOp.getToken ();
328
351
329
352
Operation *exitInsertPointAfter = nullptr ;
353
+
354
+ auto replaceUsesWithLocalLoad = [&](Value result, StageCluster stageCluster) {
355
+ auto localLoadOp = builder.createInto <LocalLoadOp>(
356
+ *consumerPartition, stageCluster, result.getType (), dataBuf);
357
+ result.replaceAllUsesWith (localLoadOp.getResult ());
358
+ schedule.insert (consumerPartition, localLoadOp);
359
+ if (consumers.size () == 1 ) {
360
+ // If there is only one consumer and we hit this code path, the empty
361
+ // barrier can be released after local load.
362
+ exitInsertPointAfter = localLoadOp;
363
+ }
364
+ };
365
+
330
366
for (auto result : results) {
331
- if (auto memDescType = dyn_cast<MemDescType>(result.getType ())) {
367
+ if (auto localAlloc = result.getDefiningOp <LocalAllocOp>()) {
368
+ auto memDescType = cast<MemDescType>(result.getType ());
332
369
auto callback = [&](Operation *oldOp, Operation *newOp) {
333
370
assert (schedule.getPartition (oldOp) == consumerPartition);
334
371
schedule.insert (consumerPartition, newOp);
335
372
};
336
- replaceUsesAndPropagateType (builder, result.getDefiningOp (), dataBuf,
337
- callback);
338
- } else if (auto tensorType = dyn_cast<RankedTensorType>(result.getType ())) {
339
- auto localLoadOp = builder.createInto <LocalLoadOp>(
340
- *consumerPartition, stageClusterEnter, tensorType, dataBuf);
341
- result.replaceAllUsesWith (localLoadOp.getResult ());
342
- schedule.insert (consumerPartition, localLoadOp);
343
- if (consumers.size () == 1 ) {
344
- // If there is only one consumer and we hit this code path, the empty
345
- // barrier can be released after local load.
346
- exitInsertPointAfter = localLoadOp;
347
- }
373
+ replaceUsesAndPropagateType (builder, localAlloc, dataBuf, callback);
374
+ } else if (auto tmemAlloc = result.getDefiningOp <TMEMAllocOp>()) {
375
+ builder.setInsertionPoint (tmemAlloc);
376
+ replaceUsesWithLocalLoad (tmemAlloc.getSrc (), stageClusterEnter);
377
+ } else if (isa<RankedTensorType>(result.getType ())) {
378
+ replaceUsesWithLocalLoad (result, stageClusterEnter);
348
379
} else {
349
- std::string msg = " unsupported produced value type: " +
380
+ std::string msg = " createArefGet: unsupported produced value type: " +
350
381
mlir::debugString (result.getType ());
351
382
llvm::report_fatal_error (msg.c_str ());
352
383
}
@@ -384,9 +415,12 @@ bool insertArefs(PartitionBuilder &builder, scf::ForOp loop,
384
415
385
416
processResultUses (producedValue.result );
386
417
387
- if (isDescLoadAndAlloc (producedValue.result )) {
418
+ if (auto opt = isDescLoadAndAlloc<LocalAllocOp> (producedValue.result )) {
388
419
// Process the register use as well
389
- auto alloc = producedValue.result .getDefiningOp <LocalAllocOp>();
420
+ auto alloc = opt->first ;
421
+ processResultUses (alloc.getSrc ());
422
+ } else if (auto opt = isDescLoadAndAlloc<TMEMAllocOp>(producedValue.result )) {
423
+ auto alloc = opt->first ;
390
424
processResultUses (alloc.getSrc ());
391
425
}
392
426
@@ -446,7 +480,8 @@ class NVWSArefInsertion
446
480
return WalkResult::advance ();
447
481
}
448
482
// Only handles load ops for now.
449
- if (isDescLoadAndAlloc (op->getResult (0 )) ||
483
+ if (isDescLoadAndAlloc<LocalAllocOp>(op->getResult (0 )) ||
484
+ isDescLoadAndAlloc<TMEMAllocOp>(op->getResult (0 )) ||
450
485
(allowDescLoadRegUse &&
451
486
(isa<triton::DescriptorOpInterface>(op)))) {
452
487
ops.push_back (op);
@@ -459,7 +494,7 @@ class NVWSArefInsertion
459
494
getProducedValues (op, loop.getBody (), *schedule);
460
495
for (auto producedValue : producedValues) {
461
496
PartitionBuilder builder (op->getLoc (), op);
462
- builder.setInsertionPointAfter (op);
497
+ builder.setInsertionPoint (op);
463
498
if (insertArefs (builder, loop, *schedule, producedValue, arefTag))
464
499
arefTag++;
465
500
}
0 commit comments