[TLX] Add alloc_warp_barrier for multi-thread barrier arrival (#1031)
@@ -254,40 +254,69 @@ struct ArriveBarrierOpConversion

```cpp
  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::ArriveBarrierOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool isPerThread = op.getPerThread();

    bool isRemoteBarrier = false;
    if (auto barType = dyn_cast<ttg::MemDescType>(op.getAlloc().getType())) {
      isRemoteBarrier =
          isa<ttng::SharedClusterMemorySpaceAttr>(barType.getMemorySpace());
    }

    // TODO: Add phase result as needed.
```
Contributor

Can we generate a lit test for the associated IR after lowering? I think we need to inspect the code lowering prepared for LLVM. If I understand correctly, the barrier is only half the issue; we also need to ensure the synchronization before it is removed. This is an example of the full IR from AutoWS GEMM. It is not 1:1, but it should contain a similar pattern that we care about (and I have it available). This is my relevant code for the section that goes TMEM_LOAD -> barrier arrive:

cc: @htyu in case you have looked at this in more detail.
Contributor

I guess the membar pass is the culprit. It's designed to protect shared memory accesses with bar.sync to make sure all threads see exactly the same data. Barriers could be an exception.
Contributor (Author)

@njriasan @htyu Thanks for reviewing! I tried removing that. Is it because I only enabled per-thread sync, not per-buffer?
Contributor

It's fine with neutral perf. Maybe other cases can benefit from this. Can you check if some bar.sync ops go away in the PTX?
Contributor (Author)

The current implementation is the sync before
```cpp
    if (isPerThread) {
      // Warp arrive: every thread arrives independently, no leader pattern.
      bool hasPred = !!op.getPred();
      std::stringstream ptxAsm;
      if (hasPred) {
        ptxAsm << "@$0 ";
      }
      ptxAsm << "mbarrier.arrive.shared::cta.b64 _, ["
             << (hasPred ? "$1" : "$0") << "]";
      if (op.getCount() > 1) {
        ptxAsm << ", " << op.getCount();
      }
      ptxAsm << ";";

      PTXBuilder ptxBuilder;
      SmallVector<PTXBuilder::Operand *, 2> operands;
      if (hasPred) {
        operands.push_back(ptxBuilder.newOperand(adaptor.getPred(), "b"));
      }
      operands.push_back(ptxBuilder.newOperand(adaptor.getAlloc(), "r"));

      auto arriveOp = *ptxBuilder.create<>(ptxAsm.str());
      arriveOp(operands, /*onlyAttachMLIRArgs=*/true);
      auto voidTy = void_ty(getContext());
      ptxBuilder.launch(rewriter, op.getLoc(), voidTy);
    } else {
      // Leader pattern: only thread 0 arrives.
      std::stringstream ptxAsm;
      ptxAsm << "@$0 mbarrier.arrive.shared::";
      if (isRemoteBarrier)
        ptxAsm << "cluster";
      else
        ptxAsm << "cta";
      ptxAsm << ".b64 _, [$1]";
      if (op.getCount() > 1) {
        ptxAsm << ", " << op.getCount();
      }
      ptxAsm << ";";

      TritonLLVMOpBuilder b(op.getLoc(), rewriter);
      Value id = getThreadId(rewriter, op.getLoc());
      Value pred = b.icmp_eq(id, b.i32_val(0));
      if (op.getPred())
        pred = b.and_(pred, adaptor.getPred());

      PTXBuilder ptxBuilder;
      SmallVector<PTXBuilder::Operand *, 2> operands = {
          ptxBuilder.newOperand(pred, "b"),
          ptxBuilder.newOperand(adaptor.getAlloc(), "r")};

      auto arriveOp = *ptxBuilder.create<>(ptxAsm.str());
      arriveOp(operands, /*onlyAttachMLIRArgs=*/true);
      auto voidTy = void_ty(getContext());
      ptxBuilder.launch(rewriter, op.getLoc(), voidTy);
    }

    rewriter.eraseOp(op);
    return success();
```
Let's make sure we run an accuracy test as well. I'm not sure if the general case needs to be conservative: with multiple warps there is a risk that the same warp can contribute multiple arrivals. However, this practically shouldn't happen, so worst case we can add a TODO to test.
I did do some accuracy testing (see my test plan), but I am not sure if that is enough. Please let me know if there is anything other than these tests to check.