Skip to content

Commit d4a0537

Browse files
Add WGMMAFenceOpPattern
1 parent 7efc694 commit d4a0537

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,8 @@ using ttn::OperandsAndConstraints;
2323

2424
namespace {
2525

26-
const std::string kWgmmaFenceOp = "wgmma.fence.sync.aligned;";
2726
const std::string kWgmmaCommitGroupOp = "wgmma.commit_group.sync.aligned;";
2827
const std::string kClusterWaitOp = "barrier.cluster.wait.aligned;";
29-
const std::string kFenceMbarrierInitOp = "fence.mbarrier_init.release.cluster;";
3028
const std::string kClusterCtaIdOp = "{\n"
3129
".reg .u32 a<5>; \n"
3230
"mov.u32 a0, %cluster_ctaid.x;\n" // x
@@ -210,6 +208,19 @@ class NVGPUOpGenericPattern : public OpRewritePattern<SourceOp> {
210208
Constraints inputConstraints;
211209
};
212210

211+
class WGMMAFenceOpPattern : public OpRewritePattern<ttn::WGMMAFenceOp> {
212+
public:
213+
using OpRewritePattern<ttn::WGMMAFenceOp>::OpRewritePattern;
214+
215+
LogicalResult matchAndRewrite(ttn::WGMMAFenceOp op,
216+
PatternRewriter &rewriter) const override {
217+
auto loc = op.getLoc();
218+
rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
219+
rewriter.eraseOp(op);
220+
return success();
221+
}
222+
};
223+
213224
class FenceAsyncSharedOpPattern
214225
: public OpRewritePattern<ttn::FenceAsyncSharedOp> {
215226
public:
@@ -779,12 +790,16 @@ class ConvertNVGPUToLLVM : public ConvertNVGPUToLLVMBase<ConvertNVGPUToLLVM> {
779790
#define POPULATE_NVGPU_OP(SRC_OP, ASM) \
780791
patterns.add<NVGPUOpGenericPattern<SRC_OP>>(context, ASM, Constraints(), \
781792
Constraints());
782-
POPULATE_NVGPU_OP(ttn::WGMMAFenceOp, kWgmmaFenceOp)
793+
// POPULATE_NVGPU_OP(ttn::WGMMAFenceOp, kWgmmaFenceOp)
783794
POPULATE_NVGPU_OP(ttn::WGMMACommitGroupOp, kWgmmaCommitGroupOp)
784795
POPULATE_NVGPU_OP(ttn::ClusterWaitOp, kClusterWaitOp)
785796
#undef POPULATE_NVGPU_OP
786797
patterns.add<NVGPUOpGenericPattern<ttn::ClusterCTAIdOp>>(
787798
context, kClusterCtaIdOp, Constraints({"=r"}), Constraints());
799+
// patterns.add<WGMMAFenceOpPattern,
800+
// WGMMACommitGroupOpPattern,ClusterWaitOpPattern,
801+
// ClusterCTAIdOpPattern>(context);
802+
patterns.add<WGMMAFenceOpPattern>(context);
788803

789804
patterns
790805
.add<FenceAsyncSharedOpPattern, LoadMatrixOpPattern,

0 commit comments

Comments
 (0)