Skip to content

Commit f7a94fe

Browse files
author
Jatin Bhateja
committed
8352585: Add special case handling for Float16.max/min x86 backend
Reviewed-by: epeter, sviswanathan
1 parent 9c5ed23 commit f7a94fe

File tree

6 files changed

+254
-6
lines changed

6 files changed

+254
-6
lines changed

src/hotspot/cpu/x86/assembler_x86.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13810,6 +13810,16 @@ void Assembler::vcmpps(XMMRegister dst, XMMRegister nds, XMMRegister src, int co
1381013810
emit_int24((unsigned char)0xC2, (0xC0 | encode), (unsigned char)comparison);
1381113811
}
1381213812

13813+
// Emits a scalar half-precision (FP16) compare of nds against src whose
// result is written to the opmask register kdst, with `mask` supplied as the
// embedded opmask. The instruction is forced to EVEX encoding, uses the F3
// SIMD prefix with the 0F 3A opcode map, opcode byte 0xC2, and carries
// `comparison` as its trailing immediate byte.
// Requires AVX512-FP16 support (asserted).
void Assembler::evcmpsh(KRegister kdst, KRegister mask, XMMRegister nds, XMMRegister src, ComparisonPredicateFP comparison) {
13814+
assert(VM_Version::supports_avx512_fp16(), "");
13815+
InstructionAttr attributes(Assembler::AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ false, /* uses_vl */ true);
13816+
attributes.set_is_evex_instruction();
13817+
attributes.set_embedded_opmask_register_specifier(mask);
13818+
// NOTE(review): presumably selects merge-masking rather than zero-masking -- confirm against InstructionAttr.
attributes.reset_is_clear_context();
13819+
int encode = vex_prefix_and_encode(kdst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_F3, VEX_OPCODE_0F_3A, &attributes);
13820+
// Opcode byte, register-direct ModRM byte, then the comparison predicate as the immediate.
emit_int24((unsigned char)0xC2, (0xC0 | encode), comparison);
13821+
}
13822+
1381313823
void Assembler::evcmpps(KRegister kdst, KRegister mask, XMMRegister nds, XMMRegister src,
1381413824
ComparisonPredicateFP comparison, int vector_len) {
1381513825
assert(VM_Version::supports_evex(), "");

src/hotspot/cpu/x86/assembler_x86.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3195,6 +3195,9 @@ class Assembler : public AbstractAssembler {
31953195
void evcmpps(KRegister kdst, KRegister mask, XMMRegister nds, XMMRegister src,
31963196
ComparisonPredicateFP comparison, int vector_len);
31973197

3198+
// Scalar FP16 compare of nds with src under opmask `mask`; the predicate result is written to kdst.
void evcmpsh(KRegister kdst, KRegister mask, XMMRegister nds, XMMRegister src,
3199+
ComparisonPredicateFP comparison);
3200+
31983201
// Vector integer compares
31993202
void vpcmpgtd(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len);
32003203
void evpcmpd(KRegister kdst, KRegister mask, XMMRegister nds, XMMRegister src,

src/hotspot/cpu/x86/c2_MacroAssembler_x86.cpp

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6680,8 +6680,6 @@ void C2_MacroAssembler::efp16sh(int opcode, XMMRegister dst, XMMRegister src1, X
66806680
case Op_SubHF: vsubsh(dst, src1, src2); break;
66816681
case Op_MulHF: vmulsh(dst, src1, src2); break;
66826682
case Op_DivHF: vdivsh(dst, src1, src2); break;
6683-
case Op_MaxHF: vmaxsh(dst, src1, src2); break;
6684-
case Op_MinHF: vminsh(dst, src1, src2); break;
66856683
default: assert(false, "%s", NodeClassNames[opcode]); break;
66866684
}
66876685
}
@@ -7091,3 +7089,48 @@ void C2_MacroAssembler::vector_saturating_op(int ideal_opc, BasicType elem_bt, X
70917089
vector_saturating_op(ideal_opc, elem_bt, dst, src1, src2, vlen_enc);
70927090
}
70937091
}
7092+
7093+
// Emits the instruction sequence for a scalar Float16 max (Op_MaxHF) or
// min (Op_MinHF) of src1 and src2 into dst, adding special-case handling for
// signed zeros and NaN operands on top of the raw vmaxsh/vminsh instruction.
//
//   opcode        - Op_MaxHF or Op_MinHF; any other value trips the assert in
//                   the else branch.
//   dst           - receives the result.
//   src1, src2    - the two operands.
//   ktmp          - scratch opmask register, overwritten.
//   xtmp1, xtmp2  - scratch vector registers, overwritten with the
//                   (possibly swapped) operands.
//   vlen_enc      - vector length encoding forwarded to the mask-extract,
//                   blend and masked-move instructions.
void C2_MacroAssembler::scalar_max_min_fp16(int opcode, XMMRegister dst, XMMRegister src1, XMMRegister src2,
7094+
KRegister ktmp, XMMRegister xtmp1, XMMRegister xtmp2, int vlen_enc) {
7095+
if (opcode == Op_MaxHF) {
7096+
// Move sign bits of src2 to mask register.
7097+
evpmovw2m(ktmp, src2, vlen_enc);
7098+
// xtmp1 = src2 < 0 ? src2 : src1
7099+
evpblendmw(xtmp1, ktmp, src1, src2, true, vlen_enc);
7100+
// xtmp2 = src2 < 0 ? src1 : src2
7101+
evpblendmw(xtmp2, ktmp, src2, src1, true, vlen_enc);
7102+
// The idea behind the above swap is to make the second source operand a +ve value.
7103+
// As per instruction semantics, if the values being compared are both 0.0s (of either sign), the value in
7104+
// the second source operand is returned. If only one value is a NaN (SNaN or QNaN) for this instruction,
7105+
// the second source operand, either a NaN or a valid floating-point value, is returned
7106+
// dst = max(xtmp1, xtmp2)
7107+
vmaxsh(dst, xtmp1, xtmp2);
7108+
// isNaN = is_unordered_quiet(xtmp1)
7109+
evcmpsh(ktmp, k0, xtmp1, xtmp1, Assembler::UNORD_Q);
7110+
// The final result is the same as the first source if it is a NaN value;
7111+
// in case the second operand holds a NaN value then, as per the above semantics,
7112+
// the result is the same as the second operand.
7113+
Assembler::evmovdquw(dst, ktmp, xtmp1, true, vlen_enc);
7114+
} else {
7115+
assert(opcode == Op_MinHF, "");
7116+
// Move sign bits of src1 to mask register.
7117+
evpmovw2m(ktmp, src1, vlen_enc);
7118+
// xtmp1 = src1 < 0 ? src2 : src1
7119+
evpblendmw(xtmp1, ktmp, src1, src2, true, vlen_enc);
7120+
// xtmp2 = src1 < 0 ? src1 : src2
7121+
evpblendmw(xtmp2, ktmp, src2, src1, true, vlen_enc);
7122+
// The idea behind the above swap is to make the second source operand a -ve value.
7123+
// As per instruction semantics, if the values being compared are both 0.0s (of either sign), the value in
7124+
// the second source operand is returned.
7125+
// If only one value is a NaN (SNaN or QNaN) for this instruction, the second source operand, either a NaN
7126+
// or a valid floating-point value, is written to the result.
7127+
// dst = min(xtmp1, xtmp2)
7128+
vminsh(dst, xtmp1, xtmp2);
7129+
// isNaN = is_unordered_quiet(xtmp1)
7130+
evcmpsh(ktmp, k0, xtmp1, xtmp1, Assembler::UNORD_Q);
7131+
// The final result is the same as the first source if it is a NaN value;
7132+
// in case the second operand holds a NaN value then, as per the above semantics,
7133+
// the result is the same as the second operand.
7134+
Assembler::evmovdquw(dst, ktmp, xtmp1, true, vlen_enc);
7135+
}
7136+
}

src/hotspot/cpu/x86/c2_MacroAssembler_x86.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,4 +584,6 @@
584584

585585
void select_from_two_vectors_evex(BasicType elem_bt, XMMRegister dst, XMMRegister src1, XMMRegister src2, int vlen_enc);
586586

587+
// Scalar Float16 max/min (opcode is Op_MaxHF or Op_MinHF) of src1 and src2 into dst;
// ktmp, xtmp1 and xtmp2 are scratch registers.
void scalar_max_min_fp16(int opcode, XMMRegister dst, XMMRegister src1, XMMRegister src2,
588+
KRegister ktmp, XMMRegister xtmp1, XMMRegister xtmp2, int vlen_enc);
587589
#endif // CPU_X86_C2_MACROASSEMBLER_X86_HPP

src/hotspot/cpu/x86/x86.ad

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1461,11 +1461,14 @@ bool Matcher::match_rule_supported(int opcode) {
14611461
return false;
14621462
}
14631463
break;
1464+
case Op_MaxHF:
1465+
case Op_MinHF:
1466+
if (!VM_Version::supports_avx512vlbw()) {
1467+
return false;
1468+
} // fallthrough
14641469
case Op_AddHF:
14651470
case Op_DivHF:
14661471
case Op_FmaHF:
1467-
case Op_MaxHF:
1468-
case Op_MinHF:
14691472
case Op_MulHF:
14701473
case Op_ReinterpretS2HF:
14711474
case Op_ReinterpretHF2S:
@@ -10935,8 +10938,6 @@ instruct scalar_binOps_HF_reg(regF dst, regF src1, regF src2)
1093510938
%{
1093610939
match(Set dst (AddHF src1 src2));
1093710940
match(Set dst (DivHF src1 src2));
10938-
match(Set dst (MaxHF src1 src2));
10939-
match(Set dst (MinHF src1 src2));
1094010941
match(Set dst (MulHF src1 src2));
1094110942
match(Set dst (SubHF src1 src2));
1094210943
format %{ "scalar_binop_fp16 $dst, $src1, $src2" %}
@@ -10947,6 +10948,20 @@ instruct scalar_binOps_HF_reg(regF dst, regF src1, regF src2)
1094710948
ins_pipe(pipe_slow);
1094810949
%}
1094910950

10951+
// Matches scalar Float16 MaxHF and MinHF nodes and expands them through
// C2_MacroAssembler::scalar_max_min_fp16, selecting max or min from the
// node's ideal opcode. One opmask register (ktmp) and two float registers
// (xtmp1, xtmp2) are reserved as temporaries, and the helper is always
// invoked with a 128-bit vector length encoding.
// NOTE(review): dst is declared TEMP_DEF, presumably so it is not allocated
// to the same register as an input -- confirm against the ADLC effect rules.
instruct scalar_minmax_HF_reg(regF dst, regF src1, regF src2, kReg ktmp, regF xtmp1, regF xtmp2)
10952+
%{
10953+
match(Set dst (MaxHF src1 src2));
10954+
match(Set dst (MinHF src1 src2));
10955+
effect(TEMP_DEF dst, TEMP ktmp, TEMP xtmp1, TEMP xtmp2);
10956+
format %{ "scalar_min_max_fp16 $dst, $src1, $src2\t using $ktmp, $xtmp1 and $xtmp2 as TEMP" %}
10957+
ins_encode %{
10958+
int opcode = this->ideal_Opcode();
10959+
__ scalar_max_min_fp16(opcode, $dst$$XMMRegister, $src1$$XMMRegister, $src2$$XMMRegister, $ktmp$$KRegister,
10960+
$xtmp1$$XMMRegister, $xtmp2$$XMMRegister, Assembler::AVX_128bit);
10961+
%}
10962+
ins_pipe( pipe_slow );
10963+
%}
10964+
1095010965
instruct scalar_fma_HF_reg(regF dst, regF src1, regF src2)
1095110966
%{
1095210967
match(Set dst (FmaHF src2 (Binary dst src1)));
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
/*
2+
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3+
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4+
*
5+
* This code is free software; you can redistribute it and/or modify it
6+
* under the terms of the GNU General Public License version 2 only, as
7+
* published by the Free Software Foundation.
8+
*
9+
* This code is distributed in the hope that it will be useful, but WITHOUT
10+
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11+
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12+
* version 2 for more details (a copy is included in the LICENSE file that
13+
* accompanied this code).
14+
*
15+
* You should have received a copy of the GNU General Public License version
16+
* 2 along with this work; if not, write to the Free Software Foundation,
17+
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18+
*
19+
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20+
* or visit www.oracle.com if you need additional information or have any
21+
* questions.
22+
*/
23+
package compiler.intrinsics.float16;
24+
25+
import compiler.lib.ir_framework.*;
26+
import jdk.incubator.vector.*;
27+
import java.util.Random;
28+
import jdk.test.lib.*;
29+
30+
/**
31+
* @test
32+
* @bug 8352585
33+
* @library /test/lib /
34+
* @summary Add special case handling for Float16.max/min x86 backend
35+
* @modules jdk.incubator.vector
36+
* @run driver compiler.intrinsics.float16.TestFloat16MaxMinSpecialValues
37+
*/
38+
39+
40+
public class TestFloat16MaxMinSpecialValues {
41+
public static Float16 POS_ZERO = Float16.valueOf(0.0f);
42+
public static Float16 NEG_ZERO = Float16.valueOf(-0.0f);
43+
public static Float16 SRC = Float16.valueOf(Float.MAX_VALUE);
44+
public static Random rd = Utils.getRandomInstance();
45+
46+
/**
 * Builds a random half-precision NaN from the shared random source {@code rd}.
 * The bit pattern always contains 0x7E00 (all five exponent bits plus the most
 * significant significand bit); the sign bit and the low nine significand bits
 * are chosen at random.
 *
 * @return a Float16 created from the generated bit pattern
 */
public static Float16 genNaN() {
47+
// IEEE 754 Half Precision QNaN Format
48+
// S EEEEE MMMMMMMMMM
49+
// X 11111 1XXXXXXXXX
50+
// Random sign: bit 15 is either set or clear.
short sign = (short)(rd.nextBoolean() ? 1 << 15 : 0);
51+
// Random payload in [0, 511], i.e. the nine low significand bits.
short significand = (short)rd.nextInt(512);
52+
return Float16.shortBitsToFloat16((short)(sign | 0x7E00 | significand));
53+
}
54+
55+
/**
 * Reports whether a check has FAILED: returns {@code true} when {@code actual}
 * is not equal to {@code expected} according to {@code Float16.equals}, and
 * {@code false} when they are equal. Callers in this test throw an
 * AssertionError when this method returns {@code true}.
 *
 * @param actual   the value produced by the method under test
 * @param expected the value the test expects
 * @return {@code true} on mismatch, {@code false} on match
 */
public static boolean assertionCheck(Float16 actual, Float16 expected) {
56+
return !actual.equals(expected);
57+
}
58+
59+
public Float16 RES;
60+
61+
public static void main(String [] args) {
62+
TestFramework.runWithFlags("--add-modules=jdk.incubator.vector");
63+
}
64+
65+
@Test
66+
@IR(counts = {IRNode.MAX_HF, " >0 "}, applyIfCPUFeatureAnd = {"avx512_fp16", "true", "avx512bw", "true", "avx512vl", "true"})
67+
public Float16 testMaxNaNOperands(Float16 src1, Float16 src2) {
68+
return Float16.max(src1, src2);
69+
}
70+
71+
@Run(test = "testMaxNaNOperands")
72+
public void launchMaxNaNOperands() {
73+
Float16 NAN = null;
74+
for (int i = 0; i < 100; i++) {
75+
NAN = genNaN();
76+
RES = testMaxNaNOperands(SRC, NAN);
77+
if (assertionCheck(RES, NAN)) {
78+
throw new AssertionError("input1 = " + SRC.floatValue() + " input2 = NaN , expected = NaN, actual = " + RES.floatValue());
79+
}
80+
NAN = genNaN();
81+
RES = testMaxNaNOperands(NAN, SRC);
82+
if (assertionCheck(RES, NAN)) {
83+
throw new AssertionError("input1 = NaN, input2 = " + SRC.floatValue() + ", expected = NaN, actual = " + RES.floatValue());
84+
}
85+
NAN = genNaN();
86+
RES = testMaxNaNOperands(NAN, NAN);
87+
if (assertionCheck(RES, NAN)) {
88+
throw new AssertionError("input1 = NaN, input2 = NaN, expected = NaN, actual = " + RES.floatValue());
89+
}
90+
}
91+
}
92+
93+
@Test
94+
@IR(counts = {IRNode.MIN_HF, " >0 "}, applyIfCPUFeatureAnd = {"avx512_fp16", "true", "avx512bw", "true", "avx512vl", "true"})
95+
public Float16 testMinNaNOperands(Float16 src1, Float16 src2) {
96+
return Float16.min(src1, src2);
97+
}
98+
99+
@Run(test = "testMinNaNOperands")
100+
public void launchMinNaNOperands() {
101+
Float16 NAN = null;
102+
for (int i = 0; i < 100; i++) {
103+
NAN = genNaN();
104+
RES = testMinNaNOperands(SRC, NAN);
105+
if (assertionCheck(RES, NAN)) {
106+
throw new AssertionError("input1 = " + SRC.floatValue() + " input2 = NaN, expected = NaN, actual = " + RES.floatValue());
107+
}
108+
NAN = genNaN();
109+
RES = testMinNaNOperands(NAN, SRC);
110+
if (assertionCheck(RES, NAN)) {
111+
throw new AssertionError("input1 = NaN, input2 = " + SRC.floatValue() + ", expected = NaN, actual = " + RES.floatValue());
112+
}
113+
NAN = genNaN();
114+
RES = testMinNaNOperands(NAN, NAN);
115+
if (assertionCheck(RES, NAN)) {
116+
throw new AssertionError("input1 = NaN, input2 = NaN, expected = NaN, actual = " + RES.floatValue());
117+
}
118+
}
119+
}
120+
121+
@Test
122+
@IR(counts = {IRNode.MAX_HF, " >0 "}, applyIfCPUFeatureAnd = {"avx512_fp16", "true", "avx512bw", "true", "avx512vl", "true"})
123+
public Float16 testMaxZeroOperands(Float16 src1, Float16 src2) {
124+
return Float16.max(src1, src2);
125+
}
126+
127+
@Run(test = "testMaxZeroOperands")
128+
public void launchMaxZeroOperands() {
129+
RES = testMaxZeroOperands(POS_ZERO, NEG_ZERO);
130+
if (assertionCheck(RES, POS_ZERO)) {
131+
throw new AssertionError("input1 = +0.0, input2 = -0.0, expected = +0.0, actual = " + RES.floatValue());
132+
}
133+
RES = testMaxZeroOperands(NEG_ZERO, POS_ZERO);
134+
if (assertionCheck(RES, POS_ZERO)) {
135+
throw new AssertionError("input1 = -0.0, input2 = +0.0, expected = +0.0, actual = " + RES.floatValue());
136+
}
137+
RES = testMaxZeroOperands(POS_ZERO, POS_ZERO);
138+
if (assertionCheck(RES, POS_ZERO)) {
139+
throw new AssertionError("input1 = +0.0, input2 = +0.0, expected = +0.0, actual = " + RES.floatValue());
140+
}
141+
RES = testMaxZeroOperands(NEG_ZERO, NEG_ZERO);
142+
if (assertionCheck(RES, NEG_ZERO)) {
143+
throw new AssertionError("input1 = -0.0, input2 = -0.0, expected = -0.0, actual = " + RES.floatValue());
144+
}
145+
}
146+
147+
@Test
148+
@IR(counts = {IRNode.MIN_HF, " >0 "}, applyIfCPUFeatureAnd = {"avx512_fp16", "true", "avx512bw", "true", "avx512vl", "true"})
149+
public Float16 testMinZeroOperands(Float16 src1, Float16 src2) {
150+
return Float16.min(src1, src2);
151+
}
152+
153+
@Run(test = "testMinZeroOperands")
154+
public void launchMinZeroOperands() {
155+
RES = testMinZeroOperands(POS_ZERO, NEG_ZERO);
156+
if (assertionCheck(RES, NEG_ZERO)) {
157+
throw new AssertionError("input1 = +0.0, input2 = -0.0, expected = -0.0, actual = " + RES.floatValue());
158+
}
159+
160+
RES = testMinZeroOperands(NEG_ZERO, POS_ZERO);
161+
if (assertionCheck(RES, NEG_ZERO)) {
162+
throw new AssertionError("input1 = -0.0, input2 = +0.0, expected = -0.0, actual = " + RES.floatValue());
163+
}
164+
165+
RES = testMinZeroOperands(POS_ZERO, POS_ZERO);
166+
if (assertionCheck(RES, POS_ZERO)) {
167+
throw new AssertionError("input1 = +0.0, input2 = +0.0, expected = +0.0, actual = " + RES.floatValue());
168+
}
169+
170+
RES = testMinZeroOperands(NEG_ZERO, NEG_ZERO);
171+
if (assertionCheck(RES, NEG_ZERO)) {
172+
throw new AssertionError("input1 = -0.0, input2 = -0.0, expected = -0.0, actual = " + RES.floatValue());
173+
}
174+
}
175+
}

0 commit comments

Comments
 (0)