Skip to content

Commit 32df2d1

Browse files
author
Hamlin Li
committed
8365772: RISC-V: correctly preserve NaN payload when converting from float to float16 in vector way
Reviewed-by: fyang, rehn
1 parent 19f0755 commit 32df2d1

File tree

3 files changed

+66
-20
lines changed

3 files changed

+66
-20
lines changed

src/hotspot/cpu/riscv/assembler_riscv.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1988,6 +1988,7 @@ enum VectorMask {
19881988

19891989
// Vector Narrowing Integer Right Shift Instructions
19901990
INSN(vnsra_wi, 0b1010111, 0b011, 0b101101);
1991+
INSN(vnsrl_wi, 0b1010111, 0b011, 0b101100);
19911992

19921993
#undef INSN
19931994

src/hotspot/cpu/riscv/c2_MacroAssembler_riscv.cpp

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2248,41 +2248,80 @@ static void float_to_float16_v_slow_path(C2_MacroAssembler& masm,
22482248
#define __ masm.
22492249
VectorRegister dst = stub.data<0>();
22502250
VectorRegister src = stub.data<1>();
2251-
VectorRegister tmp = stub.data<2>();
2251+
VectorRegister vtmp = stub.data<2>();
2252+
assert_different_registers(dst, src, vtmp);
2253+
22522254
__ bind(stub.entry());
22532255

2256+
// Active elements (NaNs) are marked in v0 mask register.
22542257
// mul is already set to mf2 in float_to_float16_v.
22552258

2256-
// preserve the payloads of non-canonical NaNs.
2257-
__ vnsra_wi(dst, src, 13, Assembler::v0_t);
2258-
2259-
// preserve the sign bit.
2260-
__ vnsra_wi(tmp, src, 26, Assembler::v0_t);
2261-
__ vsll_vi(tmp, tmp, 10, Assembler::v0_t);
2262-
__ mv(t0, 0x3ff);
2263-
__ vor_vx(tmp, tmp, t0, Assembler::v0_t);
2264-
2265-
// get the result by merging sign bit and payloads of preserved non-canonical NaNs.
2266-
__ vand_vv(dst, dst, tmp, Assembler::v0_t);
2259+
// Float (32 bits)
2260+
// Bit: 31 30 to 23 22 to 0
2261+
// +---+------------------+-----------------------------+
2262+
// | S | Exponent | Mantissa (Fraction) |
2263+
// +---+------------------+-----------------------------+
2264+
// 1 bit 8 bits 23 bits
2265+
//
2266+
// Float (16 bits)
2267+
// Bit: 15 14 to 10 9 to 0
2268+
// +---+----------------+------------------+
2269+
// | S | Exponent | Mantissa |
2270+
// +---+----------------+------------------+
2271+
// 1 bit 5 bits 10 bits
2272+
const int fp_sign_bits = 1;
2273+
const int fp32_bits = 32;
2274+
const int fp32_mantissa_2nd_part_bits = 9;
2275+
const int fp32_mantissa_3rd_part_bits = 4;
2276+
const int fp16_exponent_bits = 5;
2277+
const int fp16_mantissa_bits = 10;
2278+
2279+
// preserve the sign bit and exponent, clear mantissa.
2280+
__ vnsra_wi(dst, src, fp32_bits - fp_sign_bits - fp16_exponent_bits, Assembler::v0_t);
2281+
__ vsll_vi(dst, dst, fp16_mantissa_bits, Assembler::v0_t);
2282+
2283+
// Preserve high order bit of float NaN in the
2284+
// binary16 result NaN (tenth bit); OR in remaining
2285+
// bits into lower 9 bits of binary 16 significand.
2286+
// | (doppel & 0x007f_e000) >> 13 // 10 bits
2287+
// | (doppel & 0x0000_1ff0) >> 4 // 9 bits
2288+
// | (doppel & 0x0000_000f)); // 4 bits
2289+
//
2290+
// Check j.l.Float.floatToFloat16 for more information.
2291+
// 10 bits
2292+
__ vnsrl_wi(vtmp, src, fp32_mantissa_2nd_part_bits + fp32_mantissa_3rd_part_bits, Assembler::v0_t);
2293+
__ mv(t0, 0x3ff); // retain first part of mantissa in a float 32
2294+
__ vand_vx(vtmp, vtmp, t0, Assembler::v0_t);
2295+
__ vor_vv(dst, dst, vtmp, Assembler::v0_t);
2296+
// 9 bits
2297+
__ vnsrl_wi(vtmp, src, fp32_mantissa_3rd_part_bits, Assembler::v0_t);
2298+
__ mv(t0, 0x1ff); // retain second part of mantissa in a float 32
2299+
__ vand_vx(vtmp, vtmp, t0, Assembler::v0_t);
2300+
__ vor_vv(dst, dst, vtmp, Assembler::v0_t);
2301+
// 4 bits
2302+
// Narrow shift is necessary to move data from 32 bits element to 16 bits element in vector register.
2303+
__ vnsrl_wi(vtmp, src, 0, Assembler::v0_t);
2304+
__ vand_vi(vtmp, vtmp, 0xf, Assembler::v0_t);
2305+
__ vor_vv(dst, dst, vtmp, Assembler::v0_t);
22672306

22682307
__ j(stub.continuation());
22692308
#undef __
22702309
}
22712310

22722311
// j.l.Float.floatToFloat16
2273-
void C2_MacroAssembler::float_to_float16_v(VectorRegister dst, VectorRegister src, VectorRegister vtmp,
2274-
Register tmp, uint vector_length) {
2312+
void C2_MacroAssembler::float_to_float16_v(VectorRegister dst, VectorRegister src,
2313+
VectorRegister vtmp, Register tmp, uint vector_length) {
22752314
assert_different_registers(dst, src, vtmp);
22762315

22772316
auto stub = C2CodeStub::make<VectorRegister, VectorRegister, VectorRegister>
2278-
(dst, src, vtmp, 28, float_to_float16_v_slow_path);
2317+
(dst, src, vtmp, 56, float_to_float16_v_slow_path);
22792318

22802319
// On riscv, NaN needs a special process as vfncvt_f_f_w does not work in that case.
22812320

22822321
vsetvli_helper(BasicType::T_FLOAT, vector_length, Assembler::m1);
22832322

22842323
// check whether there is a NaN.
2285-
// replace v_fclass with vmseq_vv as performance optimization.
2324+
// replace v_fclass with vmfne_vv as performance optimization.
22862325
vmfne_vv(v0, src, src);
22872326
vcpop_m(t0, v0);
22882327

test/hotspot/jtreg/compiler/vectorization/TestFloatConversionsVectorNaN.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
/**
2525
* @test
26+
* @key randomness
2627
* @bug 8320646
2728
* @summary Auto-vectorize Float.floatToFloat16, Float.float16ToFloat APIs, with NaN
2829
* @requires vm.compiler2.enabled
@@ -37,9 +38,11 @@
3738
package compiler.vectorization;
3839

3940
import java.util.HexFormat;
41+
import java.util.Random;
4042

4143
import compiler.lib.ir_framework.*;
4244
import jdk.test.lib.Asserts;
45+
import jdk.test.lib.Utils;
4346

4447
public class TestFloatConversionsVectorNaN {
4548
private static final int ARRLEN = 1024;
@@ -79,14 +82,16 @@ public void test_float_float16(short[] sout, float[] finp) {
7982

8083
@Run(test = {"test_float_float16"}, mode = RunMode.STANDALONE)
8184
public void kernel_test_float_float16() {
85+
Random rand = Utils.getRandomInstance();
8286
int errno = 0;
8387
finp = new float[ARRLEN];
8488
sout = new short[ARRLEN];
8589

8690
// Setup
8791
for (int i = 0; i < ARRLEN; i++) {
88-
if (i%39 == 0) {
89-
int x = 0x7f800000 + ((i/39) << 13);
92+
if (i%3 == 0) {
93+
int shift = rand.nextInt(13+1);
94+
int x = 0x7f800000 + ((i/39) << shift);
9095
x = (i%2 == 0) ? x : (x | 0x80000000);
9196
finp[i] = Float.intBitsToFloat(x);
9297
} else {
@@ -128,7 +133,8 @@ public void kernel_test_float_float16() {
128133

129134
static int assertEquals(int idx, float f, short expected, short actual) {
130135
HexFormat hf = HexFormat.of();
131-
String msg = "floatToFloat16 wrong result: idx: " + idx + ", \t" + f +
136+
String msg = "floatToFloat16 wrong result: idx: " + idx +
137+
", \t" + f + ", hex: " + Integer.toHexString(Float.floatToRawIntBits(f)) +
132138
",\t expected: " + hf.toHexDigits(expected) +
133139
",\t actual: " + hf.toHexDigits(actual);
134140
if ((expected & 0x7c00) != 0x7c00) {
@@ -167,7 +173,7 @@ public void kernel_test_float16_float() {
167173

168174
// Setup
169175
for (int i = 0; i < ARRLEN; i++) {
170-
if (i%39 == 0) {
176+
if (i%3 == 0) {
171177
int x = 0x7c00 + i;
172178
x = (i%2 == 0) ? x : (x | 0x8000);
173179
sinp[i] = (short)x;

0 commit comments

Comments
 (0)