-
Notifications
You must be signed in to change notification settings - Fork 75
Insert freeze between masked loads and sdiv/srem instructions #2775
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
7f50dcf
Prevent UB in div/rem instructions during optimization
alexbaden 1dd55a4
Add regression test 1/?
alexbaden c728635
Parametrize test_divide (2/?)
alexbaden e36f345
fixup format in test_divide
alexbaden e0310dd
LLVM freeze instruction between mask and div 1/?
alexbaden 2243b25
LLVM freeze instruction between mask and div 2/?
alexbaden 6c6a0f0
LLVM freeze instruction between mask and div 3/?
alexbaden 3ef26d0
LLVM freeze instruction between mask and div 4/?
alexbaden a35b0ed
LLVM freeze instruction between mask and div 5/5
alexbaden 302fd39
fixup
alexbaden 9f814ee
Remove unused variable
alexbaden e9723dc
rename processPhiNode -> processBasicBlock
alexbaden 928afd1
simplify phi node incoming values constant check expression
alexbaden 3349109
cleanup formatting in division test
alexbaden f1a6029
add lit test
alexbaden 7ff052b
support multiple phis and undef
alexbaden ed6df23
remove unused libs
alexbaden dc9c16e
address review comments
alexbaden File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| # flake8: noqa: F821, F841 | ||
| import torch | ||
| import pytest | ||
|
|
||
| import triton | ||
| import triton.language as tl | ||
|
|
||
| aten = torch.ops.aten | ||
|
|
||
|
|
||
| def patch_kernel(template, to_replace): | ||
| kernel = triton.JITFunction(template.fn) | ||
| for key, value in to_replace.items(): | ||
| kernel.src = kernel.src.replace(key, value) | ||
| return kernel | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("float_div", [True, False]) | ||
| @pytest.mark.parametrize("floor", [True, False]) | ||
| @pytest.mark.parametrize("trunc", [True, False]) | ||
| def test_divide(float_div, floor, trunc, device): | ||
| # regression test for various division cases | ||
|
|
||
| @triton.jit | ||
| def divide_kernel(a, b, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel, XBLOCK: tl.constexpr): | ||
| xoffset = tl.program_id(0) * XBLOCK | ||
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | ||
| xmask = xindex < xnumel | ||
| x0 = xindex | ||
| tmp0 = tl.load(a + (x0), xmask) | ||
| tmp2 = tl.load(b + (x0), xmask) | ||
| # custom bits | ||
| tmp1 = tmp0.to(tl.float32) | ||
| tmp3 = tmp2.to(tl.float32) | ||
| tmp4 = tmp1 / tmp3 | ||
| tmp5 = tl.where((tmp0 < 0) != (tmp2 < 0), tl.where(tmp0 % tmp2 != 0, tmp0 // tmp2 - 1, tmp0 // tmp2), | ||
| tmp0 // tmp2) | ||
| tmp6 = tmp0 // tmp2 | ||
| GENERATE_OUTPUTS_HERE | ||
|
|
||
| torch.manual_seed(0) | ||
|
|
||
| outputs_float_div = "tl.store(out_ptr0 + (x0), tmp4, xmask)\n tl.store(out_ptr3 + (x0), tmp4, xmask)" if float_div else "" | ||
| outputs_floor = " tl.store(out_ptr1 + (x0), tmp5, xmask)\n tl.store(out_ptr4 + (x0), tmp5, xmask)" if floor else "" | ||
| outputs_trunc = " tl.store(out_ptr2 + (x0), tmp6, xmask)" if trunc else "" | ||
|
|
||
| divide_kernel = patch_kernel(divide_kernel, | ||
| {"GENERATE_OUTPUTS_HERE": f"{outputs_float_div}\n{outputs_floor}\n{outputs_trunc}"}) | ||
|
|
||
| def launch_triton(a, b): | ||
| output0 = torch.zeros_like(a) | ||
| output1 = torch.zeros_like(a) | ||
| output2 = torch.zeros_like(a) | ||
| output3 = torch.zeros_like(a) | ||
| output4 = torch.zeros_like(a) | ||
|
|
||
| n_elements = output0.numel() | ||
|
|
||
| grid = lambda meta: (triton.cdiv(n_elements, meta['XBLOCK']), ) | ||
|
|
||
| divide_kernel[grid](a, b, output0, output1, output2, output3, output4, n_elements, XBLOCK=128) | ||
|
|
||
| return (output0, output1, output2, output3, output4) | ||
|
|
||
| def launch_torch(a, b): | ||
| return ( | ||
| aten.div(a, b, rounding_mode=None) if float_div is True else torch.zeros_like(a), | ||
| aten.div(a, b, rounding_mode="floor") if floor is True else torch.zeros_like(a), | ||
| aten.div(a, b, rounding_mode="trunc") if trunc is True else torch.zeros_like(a), | ||
| a / b if float_div is True else torch.zeros_like(a), | ||
| a // b if floor is True else torch.zeros_like(a), | ||
| ) | ||
|
|
||
| a = torch.randint(2**32, 2**40, [100, 100], device=device) | ||
| b = torch.randint(-10, -1, [100, 100], device=device) | ||
|
|
||
| for iter in range(100): | ||
| triton_result = launch_triton(a, b) | ||
| torch_result = launch_torch(a, b) | ||
|
|
||
| for i in range(5): | ||
| torch.testing.assert_close( | ||
| triton_result[i], torch_result[i], check_dtype=False, msg=lambda msg: | ||
| f"Float: {float_div}, Floor: {floor}, Trunc: {trunc}\nIteration {iter}, {i} failed\n{msg}") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| ; RUN: triton-llvm-opt -freeze-masked-div-rem %s | FileCheck %s | ||
|
|
||
| define void @phi_div_of_zero_okay(i8 noundef %x, i8 %i, ptr %v) { | ||
| ; CHECK-LABEL: @phi_div_of_zero_okay( | ||
| entry: | ||
| %cmp = icmp ult i8 %i, 9 | ||
| br i1 %cmp, label %if.then, label %if.end | ||
|
|
||
| if.then: | ||
| %y = load i8, ptr %v, align 8 | ||
| br label %if.end | ||
|
|
||
| if.end: | ||
| %yy = phi i8 [ %y, %if.then ], [ 0, %entry ] | ||
| ; CHECK: [[F0:%.*]] = freeze i8 %yy | ||
| ; CHECK-NEXT: %z = sdiv i8 %x, [[F0:%.*]] | ||
| %z = sdiv i8 %x, %yy | ||
| br i1 %cmp, label %if2.then, label %if2.end | ||
|
|
||
| if2.then: | ||
| store i8 %z, ptr %v, align 8 | ||
| br label %if2.end | ||
|
|
||
| if2.end: | ||
| ret void | ||
| } | ||
|
|
||
| define void @two_phi_div_of_zero_okay(i8 noundef %x, i8 %i, ptr %v) { | ||
| ; CHECK-LABEL: @two_phi_div_of_zero_okay( | ||
| entry: | ||
| %cmp = icmp ult i8 %i, 9 | ||
| br i1 %cmp, label %if.then, label %if.end | ||
|
|
||
| if.then: | ||
| %y = load i8, ptr %v, align 8 | ||
| %vv = getelementptr inbounds i64, ptr %v, i64 1 | ||
| %b = load i8, ptr %vv, align 8 | ||
| br label %if.end | ||
|
|
||
| if.end: | ||
| %bb = phi i8 [ %b, %if.then ], [ undef, %entry ] | ||
| %yy = phi i8 [ %y, %if.then ], [ 0, %entry ] | ||
| ; CHECK: [[F0:%.*]] = freeze i8 %yy | ||
| ; CHECK-NEXT: %z = sdiv i8 %x, [[F0:%.*]] | ||
| %z = sdiv i8 %x, %yy | ||
| ; CHECK: [[F1:%.*]] = freeze i8 %bb | ||
| ; CHECK-NEXT: %zz = sdiv i8 %x, [[F1:%.*]] | ||
| %zz = sdiv i8 %x, %bb | ||
| br i1 %cmp, label %if2.then, label %if2.end | ||
|
|
||
| if2.then: | ||
| store i8 %z, ptr %v, align 8 | ||
| br label %if2.end | ||
|
|
||
| if2.end: | ||
| ret void | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| add_triton_library(TritonIntelLLVMIR | ||
| LLVMIRFreezeMaskedDivRem.cpp | ||
|
|
||
| DEPENDS | ||
| LLVMIRIncGen | ||
| ) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| #include "LLVMPasses.h" | ||
| #include "llvm/Analysis/TargetTransformInfo.h" | ||
| #include "llvm/Analysis/ValueTracking.h" | ||
| #include "llvm/IR/Dominators.h" | ||
| #include "llvm/IR/Instructions.h" | ||
|
|
||
| using namespace llvm; | ||
|
|
||
| static bool processPhiNode(PHINode *PhiNode) { | ||
| if (none_of(PhiNode->incoming_values(), [](Use &U) { | ||
| Constant *C = dyn_cast<Constant>(&U); | ||
| return isa<UndefValue>(U) || C && C->isNullValue(); | ||
| })) { | ||
| return false; | ||
| } | ||
|
|
||
| bool Changed = false; | ||
| BasicBlock *BB = const_cast<BasicBlock *>(PhiNode->getParent()); | ||
| for (Instruction &I : *BB) { | ||
| if (I.getOpcode() == Instruction::SDiv || | ||
| I.getOpcode() == Instruction::SRem) { | ||
| const size_t OpIdx = 1; | ||
| if (I.getOperand(OpIdx) == PhiNode) { | ||
| auto *freezePhi = new FreezeInst( | ||
| PhiNode, PhiNode->getName() + ".frozen", I.getIterator()); | ||
| I.setOperand(OpIdx, freezePhi); | ||
| Changed = true; | ||
| } | ||
| } | ||
| } | ||
| return Changed; | ||
| } | ||
|
|
||
| static bool runOnFunction(Function &F) { | ||
| bool Changed = false; | ||
|
|
||
| for (BasicBlock &BB : F) { | ||
| for (PHINode &PhiNode : BB.phis()) { | ||
| Changed |= processPhiNode(&PhiNode); | ||
| } | ||
| } | ||
|
|
||
| return Changed; | ||
| } | ||
|
|
||
| PreservedAnalyses FreezeMaskedDivRemPass::run(Function &F, | ||
| FunctionAnalysisManager &FAM) { | ||
| const auto b = runOnFunction(F); | ||
|
|
||
| return b ? PreservedAnalyses::none() : PreservedAnalyses::all(); | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| #include "llvm/IR/PassManager.h" | ||
| #include "llvm/Pass.h" | ||
|
|
||
| namespace llvm { | ||
|
|
||
| struct FreezeMaskedDivRemPass : PassInfoMixin<FreezeMaskedDivRemPass> { | ||
| PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); | ||
| static StringRef name() { return "FreezeMaskedDivRemPass"; } | ||
| }; | ||
|
|
||
| } // namespace llvm |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.