Skip to content

Commit 5b38abc

Browse files
authored
Do not distrubute OpSNegate into OpUDiv (#5823)
We cannot apply the negate to an operand of an OpUDiv instead of it result. This is because the operands of the OpUDiv are interpreted as unsigned. We stop the optimizer from doing that. There were no tests for distributing a negate into OpIMul, OpSDiv, and OpUDiv. Tests are added for all of these. Fixes #5822
1 parent 5c8442f commit 5b38abc

File tree

2 files changed

+218
-29
lines changed

2 files changed

+218
-29
lines changed

source/opt/folding_rules.cpp

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -422,36 +422,37 @@ FoldingRule MergeNegateMulDivArithmetic() {
422422
if (width != 32 && width != 64) return false;
423423

424424
spv::Op opcode = op_inst->opcode();
425-
if (opcode == spv::Op::OpFMul || opcode == spv::Op::OpFDiv ||
426-
opcode == spv::Op::OpIMul || opcode == spv::Op::OpSDiv ||
427-
opcode == spv::Op::OpUDiv) {
428-
std::vector<const analysis::Constant*> op_constants =
429-
const_mgr->GetOperandConstants(op_inst);
430-
// Merge negate into mul or div if one operand is constant.
431-
if (op_constants[0] || op_constants[1]) {
432-
bool zero_is_variable = op_constants[0] == nullptr;
433-
const analysis::Constant* c = ConstInput(op_constants);
434-
uint32_t neg_id = NegateConstant(const_mgr, c);
435-
uint32_t non_const_id = zero_is_variable
436-
? op_inst->GetSingleWordInOperand(0u)
437-
: op_inst->GetSingleWordInOperand(1u);
438-
// Change this instruction to a mul/div.
439-
inst->SetOpcode(op_inst->opcode());
440-
if (opcode == spv::Op::OpFDiv || opcode == spv::Op::OpUDiv ||
441-
opcode == spv::Op::OpSDiv) {
442-
uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
443-
uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
444-
inst->SetInOperands(
445-
{{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
446-
} else {
447-
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
448-
{SPV_OPERAND_TYPE_ID, {neg_id}}});
449-
}
450-
return true;
451-
}
425+
if (opcode != spv::Op::OpFMul && opcode != spv::Op::OpFDiv &&
426+
opcode != spv::Op::OpIMul && opcode != spv::Op::OpSDiv) {
427+
return false;
452428
}
453429

454-
return false;
430+
std::vector<const analysis::Constant*> op_constants =
431+
const_mgr->GetOperandConstants(op_inst);
432+
// Merge negate into mul or div if one operand is constant.
433+
if (op_constants[0] == nullptr && op_constants[1] == nullptr) {
434+
return false;
435+
}
436+
437+
bool zero_is_variable = op_constants[0] == nullptr;
438+
const analysis::Constant* c = ConstInput(op_constants);
439+
uint32_t neg_id = NegateConstant(const_mgr, c);
440+
uint32_t non_const_id = zero_is_variable
441+
? op_inst->GetSingleWordInOperand(0u)
442+
: op_inst->GetSingleWordInOperand(1u);
443+
// Change this instruction to a mul/div.
444+
inst->SetOpcode(op_inst->opcode());
445+
if (opcode == spv::Op::OpFDiv || opcode == spv::Op::OpUDiv ||
446+
opcode == spv::Op::OpSDiv) {
447+
uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
448+
uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
449+
inst->SetInOperands(
450+
{{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
451+
} else {
452+
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
453+
{SPV_OPERAND_TYPE_ID, {neg_id}}});
454+
}
455+
return true;
455456
};
456457
}
457458

test/opt/fold_test.cpp

Lines changed: 189 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5940,7 +5940,195 @@ ::testing::Values(
59405940
"%2 = OpFNegate %v2double %v2double_null\n" +
59415941
"OpReturn\n" +
59425942
"OpFunctionEnd",
5943-
2, true)
5943+
2, true),
5944+
// Test case 20: fold snegate with OpIMul.
5945+
// -(x * 2) = x * -2
5946+
InstructionFoldingCase<bool>(
5947+
Header() +
5948+
"; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
5949+
"; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2\n" +
5950+
"; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
5951+
"; CHECK: %4 = OpIMul [[long]] [[ld]] [[long_n2]]\n" +
5952+
"%main = OpFunction %void None %void_func\n" +
5953+
"%main_lab = OpLabel\n" +
5954+
"%var = OpVariable %_ptr_long Function\n" +
5955+
"%2 = OpLoad %long %var\n" +
5956+
"%3 = OpIMul %long %2 %long_2\n" +
5957+
"%4 = OpSNegate %long %3\n" +
5958+
"OpReturn\n" +
5959+
"OpFunctionEnd",
5960+
4, true),
5961+
// Test case 21: fold snegate with OpIMul.
5962+
// -(x * 2) = x * -2
5963+
InstructionFoldingCase<bool>(
5964+
Header() +
5965+
"; CHECK-DAG: [[int:%\\w+]] = OpTypeInt 32 1\n" +
5966+
"; CHECK-DAG: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
5967+
"; CHECK: [[uint_n2:%\\w+]] = OpConstant [[uint]] 4294967294\n" +
5968+
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
5969+
"; CHECK: %4 = OpIMul [[int]] [[ld]] [[uint_n2]]\n" +
5970+
"%main = OpFunction %void None %void_func\n" +
5971+
"%main_lab = OpLabel\n" +
5972+
"%var = OpVariable %_ptr_int Function\n" +
5973+
"%2 = OpLoad %int %var\n" +
5974+
"%3 = OpIMul %int %2 %uint_2\n" +
5975+
"%4 = OpSNegate %int %3\n" +
5976+
"OpReturn\n" +
5977+
"OpFunctionEnd",
5978+
4, true),
5979+
// Test case 22: fold snegate with OpIMul.
5980+
// -(-24 * x) = x * 24
5981+
InstructionFoldingCase<bool>(
5982+
Header() +
5983+
"; CHECK-DAG: [[int:%\\w+]] = OpTypeInt 32 1\n" +
5984+
"; CHECK: [[int_24:%\\w+]] = OpConstant [[int]] 24\n" +
5985+
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
5986+
"; CHECK: %4 = OpIMul [[int]] [[ld]] [[int_24]]\n" +
5987+
"%main = OpFunction %void None %void_func\n" +
5988+
"%main_lab = OpLabel\n" +
5989+
"%var = OpVariable %_ptr_int Function\n" +
5990+
"%2 = OpLoad %int %var\n" +
5991+
"%3 = OpIMul %int %int_n24 %2\n" +
5992+
"%4 = OpSNegate %int %3\n" +
5993+
"OpReturn\n" +
5994+
"OpFunctionEnd",
5995+
4, true),
5996+
// Test case 23: fold snegate with OpIMul with UINT_MAX
5997+
// -(UINT_MAX * x) = x
5998+
InstructionFoldingCase<bool>(
5999+
Header() +
6000+
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
6001+
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
6002+
"; CHECK: %4 = OpCopyObject [[int]] [[ld]]\n" +
6003+
"%main = OpFunction %void None %void_func\n" +
6004+
"%main_lab = OpLabel\n" +
6005+
"%var = OpVariable %_ptr_int Function\n" +
6006+
"%2 = OpLoad %int %var\n" +
6007+
"%3 = OpIMul %int %uint_max %2\n" +
6008+
"%4 = OpSNegate %int %3\n" +
6009+
"OpReturn\n" +
6010+
"OpFunctionEnd",
6011+
4, true),
6012+
// Test case 24: fold snegate with OpIMul using -INT_MAX
6013+
// -(x * 2147483649u) = x * 2147483647u
6014+
InstructionFoldingCase<bool>(
6015+
Header() +
6016+
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
6017+
"; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
6018+
"; CHECK: [[uint_2147483647:%\\w+]] = OpConstant [[uint]] 2147483647\n" +
6019+
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
6020+
"; CHECK: %4 = OpIMul [[int]] [[ld]] [[uint_2147483647]]\n" +
6021+
"%main = OpFunction %void None %void_func\n" +
6022+
"%main_lab = OpLabel\n" +
6023+
"%var = OpVariable %_ptr_int Function\n" +
6024+
"%2 = OpLoad %int %var\n" +
6025+
"%3 = OpIMul %int %2 %uint_2147483649\n" +
6026+
"%4 = OpSNegate %int %3\n" +
6027+
"OpReturn\n" +
6028+
"OpFunctionEnd",
6029+
4, true),
6030+
// Test case 25: fold snegate with OpSDiv (long).
6031+
// -(x / 2) = x / -2
6032+
InstructionFoldingCase<bool>(
6033+
Header() +
6034+
"; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
6035+
"; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2\n" +
6036+
"; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
6037+
"; CHECK: %4 = OpSDiv [[long]] [[ld]] [[long_n2]]\n" +
6038+
"%main = OpFunction %void None %void_func\n" +
6039+
"%main_lab = OpLabel\n" +
6040+
"%var = OpVariable %_ptr_long Function\n" +
6041+
"%2 = OpLoad %long %var\n" +
6042+
"%3 = OpSDiv %long %2 %long_2\n" +
6043+
"%4 = OpSNegate %long %3\n" +
6044+
"OpReturn\n" +
6045+
"OpFunctionEnd",
6046+
4, true),
6047+
// Test case 26: fold snegate with OpSDiv (int).
6048+
// -(x / 2) = x / -2
6049+
InstructionFoldingCase<bool>(
6050+
Header() +
6051+
"; CHECK-DAG: [[int:%\\w+]] = OpTypeInt 32 1\n" +
6052+
"; CHECK-DAG: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
6053+
"; CHECK: [[uint_n2:%\\w+]] = OpConstant [[uint]] 4294967294\n" +
6054+
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
6055+
"; CHECK: %4 = OpSDiv [[int]] [[ld]] [[uint_n2]]\n" +
6056+
"%main = OpFunction %void None %void_func\n" +
6057+
"%main_lab = OpLabel\n" +
6058+
"%var = OpVariable %_ptr_int Function\n" +
6059+
"%2 = OpLoad %int %var\n" +
6060+
"%3 = OpSDiv %int %2 %uint_2\n" +
6061+
"%4 = OpSNegate %int %3\n" +
6062+
"OpReturn\n" +
6063+
"OpFunctionEnd",
6064+
4, true),
6065+
// Test case 27: fold snegate with OpSDiv.
6066+
// -(-24 / x) = 24 / x
6067+
InstructionFoldingCase<bool>(
6068+
Header() +
6069+
"; CHECK-DAG: [[int:%\\w+]] = OpTypeInt 32 1\n" +
6070+
"; CHECK: [[int_24:%\\w+]] = OpConstant [[int]] 24\n" +
6071+
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
6072+
"; CHECK: %4 = OpSDiv [[int]] [[int_24]] [[ld]]\n" +
6073+
"%main = OpFunction %void None %void_func\n" +
6074+
"%main_lab = OpLabel\n" +
6075+
"%var = OpVariable %_ptr_int Function\n" +
6076+
"%2 = OpLoad %int %var\n" +
6077+
"%3 = OpSDiv %int %int_n24 %2\n" +
6078+
"%4 = OpSNegate %int %3\n" +
6079+
"OpReturn\n" +
6080+
"OpFunctionEnd",
6081+
4, true),
6082+
// Test case 28: fold snegate with OpSDiv with UINT_MAX
6083+
// -(UINT_MAX / x) = (1 / x)
6084+
InstructionFoldingCase<bool>(
6085+
Header() +
6086+
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
6087+
"; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
6088+
"; CHECK: [[uint_1:%\\w+]] = OpConstant [[uint]] 1\n" +
6089+
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
6090+
"; CHECK: %4 = OpSDiv [[int]] [[uint_1]] [[ld]]\n" +
6091+
"%main = OpFunction %void None %void_func\n" +
6092+
"%main_lab = OpLabel\n" +
6093+
"%var = OpVariable %_ptr_int Function\n" +
6094+
"%2 = OpLoad %int %var\n" +
6095+
"%3 = OpSDiv %int %uint_max %2\n" +
6096+
"%4 = OpSNegate %int %3\n" +
6097+
"OpReturn\n" +
6098+
"OpFunctionEnd",
6099+
4, true),
6100+
// Test case 29: fold snegate with OpSDiv using -INT_MAX
6101+
// -(x / 2147483647u) = x / 2147483647
6102+
InstructionFoldingCase<bool>(
6103+
Header() +
6104+
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
6105+
"; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
6106+
"; CHECK: [[uint_2147483647:%\\w+]] = OpConstant [[uint]] 2147483647\n" +
6107+
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
6108+
"; CHECK: %4 = OpSDiv [[int]] [[ld]] [[uint_2147483647]]\n" +
6109+
"%main = OpFunction %void None %void_func\n" +
6110+
"%main_lab = OpLabel\n" +
6111+
"%var = OpVariable %_ptr_int Function\n" +
6112+
"%2 = OpLoad %int %var\n" +
6113+
"%3 = OpSDiv %int %2 %uint_2147483649\n" +
6114+
"%4 = OpSNegate %int %3\n" +
6115+
"OpReturn\n" +
6116+
"OpFunctionEnd",
6117+
4, true),
6118+
// Test case 30: Don't fold snegate int OpUDiv. The operands are interpreted
6119+
// as unsigned, so negating an operand is not the same a negating the
6120+
// result.
6121+
InstructionFoldingCase<bool>(
6122+
Header() +
6123+
"%main = OpFunction %void None %void_func\n" +
6124+
"%main_lab = OpLabel\n" +
6125+
"%var = OpVariable %_ptr_int Function\n" +
6126+
"%2 = OpLoad %int %var\n" +
6127+
"%3 = OpUDiv %int %2 %uint_1\n" +
6128+
"%4 = OpSNegate %int %3\n" +
6129+
"OpReturn\n" +
6130+
"OpFunctionEnd",
6131+
4, false)
59446132
));
59456133

59466134
INSTANTIATE_TEST_SUITE_P(ReciprocalFDivTest, MatchingInstructionFoldingTest,

0 commit comments

Comments
 (0)