Skip to content

Commit 296d9d6

Browse files
marc-chevalierchhagedorn
authored andcommitted
8353345: C2 asserts because maskShiftAmount modifies node without deleting the hash
Reviewed-by: chagedorn, thartmann
1 parent 3ceabf0 commit 296d9d6

File tree

2 files changed

+110
-34
lines changed

2 files changed

+110
-34
lines changed

src/hotspot/share/opto/mulnode.cpp

Lines changed: 57 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -950,63 +950,86 @@ LShiftNode* LShiftNode::make(Node* in1, Node* in2, BasicType bt) {
950950
return nullptr;
951951
}
952952

953-
//=============================================================================
954-
955-
static bool const_shift_count(PhaseGVN* phase, Node* shiftNode, int* count) {
956-
const TypeInt* tcount = phase->type(shiftNode->in(2))->isa_int();
953+
// Returns whether the shift amount is constant. If so, sets count.
954+
static bool const_shift_count(PhaseGVN* phase, const Node* shift_node, int* count) {
955+
const TypeInt* tcount = phase->type(shift_node->in(2))->isa_int();
957956
if (tcount != nullptr && tcount->is_con()) {
958957
*count = tcount->get_con();
959958
return true;
960959
}
961960
return false;
962961
}
963962

964-
static int maskShiftAmount(PhaseGVN* phase, Node* shiftNode, uint nBits) {
965-
int count = 0;
966-
if (const_shift_count(phase, shiftNode, &count)) {
967-
int maskedShift = count & (nBits - 1);
968-
if (maskedShift == 0) {
963+
// Returns whether the shift amount is constant. If so, sets real_shift and masked_shift.
964+
static bool mask_shift_amount(PhaseGVN* phase, const Node* shift_node, uint nBits, int& real_shift, int& masked_shift) {
965+
if (const_shift_count(phase, shift_node, &real_shift)) {
966+
masked_shift = real_shift & (nBits - 1);
967+
return true;
968+
}
969+
return false;
970+
}
971+
972+
// Convenience for when we don't care about the real amount
973+
static bool mask_shift_amount(PhaseGVN* phase, const Node* shift_node, uint nBits, int& masked_shift) {
974+
int real_shift;
975+
return mask_shift_amount(phase, shift_node, nBits, real_shift, masked_shift);
976+
}
977+
978+
// Use this in ::Ideal only with shiftNode == this!
979+
// Returns the masked shift amount if constant or 0 if not constant.
980+
static int mask_and_replace_shift_amount(PhaseGVN* phase, Node* shift_node, uint nBits) {
981+
int real_shift;
982+
int masked_shift;
983+
if (mask_shift_amount(phase, shift_node, nBits, real_shift, masked_shift)) {
984+
if (masked_shift == 0) {
969985
// Let Identity() handle 0 shift count.
970986
return 0;
971987
}
972988

973-
if (count != maskedShift) {
974-
shiftNode->set_req(2, phase->intcon(maskedShift)); // Replace shift count with masked value.
989+
if (real_shift != masked_shift) {
975990
PhaseIterGVN* igvn = phase->is_IterGVN();
976-
if (igvn) {
977-
igvn->rehash_node_delayed(shiftNode);
991+
if (igvn != nullptr) {
992+
igvn->_worklist.push(shift_node);
978993
}
994+
shift_node->set_req(2, phase->intcon(masked_shift)); // Replace shift count with masked value.
979995
}
980-
return maskedShift;
996+
return masked_shift;
981997
}
998+
// Not a shift by a constant.
982999
return 0;
9831000
}
9841001

9851002
// Called with
986-
// outer_shift = (_ << con0)
1003+
// outer_shift = (_ << rhs_outer)
9871004
// We are looking for the pattern:
988-
// outer_shift = ((X << con1) << con0)
989-
// we denote inner_shift the nested expression (X << con1)
990-
//
991-
// con0 and con1 are both in [0..nbits), as they are computed by maskShiftAmount.
1005+
// outer_shift = ((X << rhs_inner) << rhs_outer)
1006+
// where rhs_outer and rhs_inner are constant
1007+
// we denote inner_shift the nested expression (X << rhs_inner)
1008+
// con_inner = rhs_inner % nbits and con_outer = rhs_outer % nbits
1009+
// where nbits is the number of bits of the shifts
9921010
//
9931011
// There are 2 cases:
994-
// if con0 + con1 >= nbits => 0
995-
// if con0 + con1 < nbits => X << (con1 + con0)
996-
static Node* collapse_nested_shift_left(PhaseGVN* phase, Node* outer_shift, int con0, BasicType bt) {
1012+
// if con_outer + con_inner >= nbits => 0
1013+
// if con_outer + con_inner < nbits => X << (con_outer + con_inner)
1014+
static Node* collapse_nested_shift_left(PhaseGVN* phase, const Node* outer_shift, int con_outer, BasicType bt) {
9971015
assert(bt == T_LONG || bt == T_INT, "Unexpected type");
998-
int nbits = static_cast<int>(bits_per_java_integer(bt));
999-
Node* inner_shift = outer_shift->in(1);
1016+
const Node* inner_shift = outer_shift->in(1);
10001017
if (inner_shift->Opcode() != Op_LShift(bt)) {
10011018
return nullptr;
10021019
}
10031020

1004-
int con1 = maskShiftAmount(phase, inner_shift, nbits);
1005-
if (con1 == 0) { // Either non-const, or actually 0 (up to mask) and then delegated to Identity()
1021+
int nbits = static_cast<int>(bits_per_java_integer(bt));
1022+
int con_inner;
1023+
if (!mask_shift_amount(phase, inner_shift, nbits, con_inner)) {
1024+
return nullptr;
1025+
}
1026+
1027+
if (con_inner == 0) {
1028+
// We let the Identity() of the inner shift do its job.
10061029
return nullptr;
10071030
}
10081031

1009-
if (con0 + con1 >= nbits) {
1032+
if (con_outer + con_inner >= nbits) {
10101033
// While it might be tempting to use
10111034
// phase->zerocon(bt);
10121035
// it would be incorrect: zerocon caches nodes, while Ideal is only allowed
@@ -1015,7 +1038,7 @@ static Node* collapse_nested_shift_left(PhaseGVN* phase, Node* outer_shift, int
10151038
}
10161039

10171040
// con0 + con1 < nbits ==> actual shift happens now
1018-
Node* con0_plus_con1 = phase->intcon(con0 + con1);
1041+
Node* con0_plus_con1 = phase->intcon(con_outer + con_inner);
10191042
return LShiftNode::make(inner_shift->in(1), con0_plus_con1, bt);
10201043
}
10211044

@@ -1036,7 +1059,7 @@ Node* LShiftINode::Identity(PhaseGVN* phase) {
10361059
// Also collapse nested left-shifts with constant rhs:
10371060
// (X << con1) << con2 ==> X << (con1 + con2)
10381061
Node *LShiftINode::Ideal(PhaseGVN *phase, bool can_reshape) {
1039-
int con = maskShiftAmount(phase, this, BitsPerJavaInteger);
1062+
int con = mask_and_replace_shift_amount(phase, this, BitsPerJavaInteger);
10401063
if (con == 0) {
10411064
return nullptr;
10421065
}
@@ -1222,7 +1245,7 @@ Node* LShiftLNode::Identity(PhaseGVN* phase) {
12221245
// Also collapse nested left-shifts with constant rhs:
12231246
// (X << con1) << con2 ==> X << (con1 + con2)
12241247
Node *LShiftLNode::Ideal(PhaseGVN *phase, bool can_reshape) {
1225-
int con = maskShiftAmount(phase, this, BitsPerJavaLong);
1248+
int con = mask_and_replace_shift_amount(phase, this, BitsPerJavaLong);
12261249
if (con == 0) {
12271250
return nullptr;
12281251
}
@@ -1443,7 +1466,7 @@ Node* RShiftNode::IdealIL(PhaseGVN* phase, bool can_reshape, BasicType bt) {
14431466
if (t1 == nullptr) {
14441467
return NodeSentinel; // Left input is an integer
14451468
}
1446-
int shift = maskShiftAmount(phase, this, bits_per_java_integer(bt));
1469+
int shift = mask_and_replace_shift_amount(phase, this, bits_per_java_integer(bt));
14471470
if (shift == 0) {
14481471
return NodeSentinel;
14491472
}
@@ -1473,7 +1496,7 @@ Node* RShiftINode::Ideal(PhaseGVN* phase, bool can_reshape) {
14731496
if (progress != nullptr) {
14741497
return progress;
14751498
}
1476-
int shift = maskShiftAmount(phase, this, BitsPerJavaInteger);
1499+
int shift = mask_and_replace_shift_amount(phase, this, BitsPerJavaInteger);
14771500
assert(shift != 0, "handled by IdealIL");
14781501

14791502
// Check for "(short[i] <<16)>>16" which simply sign-extends
@@ -1660,7 +1683,7 @@ Node* URShiftINode::Identity(PhaseGVN* phase) {
16601683

16611684
//------------------------------Ideal------------------------------------------
16621685
Node *URShiftINode::Ideal(PhaseGVN *phase, bool can_reshape) {
1663-
int con = maskShiftAmount(phase, this, BitsPerJavaInteger);
1686+
int con = mask_and_replace_shift_amount(phase, this, BitsPerJavaInteger);
16641687
if (con == 0) {
16651688
return nullptr;
16661689
}
@@ -1824,7 +1847,7 @@ Node* URShiftLNode::Identity(PhaseGVN* phase) {
18241847

18251848
//------------------------------Ideal------------------------------------------
18261849
Node *URShiftLNode::Ideal(PhaseGVN *phase, bool can_reshape) {
1827-
int con = maskShiftAmount(phase, this, BitsPerJavaLong);
1850+
int con = mask_and_replace_shift_amount(phase, this, BitsPerJavaLong);
18281851
if (con == 0) {
18291852
return nullptr;
18301853
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3+
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4+
*
5+
* This code is free software; you can redistribute it and/or modify it
6+
* under the terms of the GNU General Public License version 2 only, as
7+
* published by the Free Software Foundation.
8+
*
9+
* This code is distributed in the hope that it will be useful, but WITHOUT
10+
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11+
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12+
* version 2 for more details (a copy is included in the LICENSE file that
13+
* accompanied this code).
14+
*
15+
* You should have received a copy of the GNU General Public License version
16+
* 2 along with this work; if not, write to the Free Software Foundation,
17+
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18+
*
19+
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20+
* or visit www.oracle.com if you need additional information or have any
21+
* questions.
22+
*/
23+
24+
/*
25+
* @test
26+
* @bug 8353345
27+
* @summary During the transformation (X << con1) << con2 ==> X << (con1 + con2) in IGVN,
28+
* we modified the inner shift, during the transformation of the outer shift without
29+
* removing it from the hashtable
30+
*
31+
* @run main/othervm -XX:CompileCommand=compileonly,*DoubleLShiftCrashDuringIGVN*::* -Xcomp DoubleLShiftCrashDuringIGVN
32+
*/
33+
34+
public class DoubleLShiftCrashDuringIGVN {
35+
public static long shift = 0;
36+
37+
public static int test() {
38+
int s = 1;
39+
40+
shift = 12;
41+
for (int i = 0; i < 4; i++) {
42+
for (int j = 0; j < 4; j++) {
43+
s <<= shift;
44+
}
45+
shift = 33;
46+
}
47+
return s;
48+
}
49+
50+
public static void main(String[] strArr) {
51+
test();
52+
}
53+
}

0 commit comments

Comments
 (0)