Skip to content

Commit 13c1148

Browse files
Scott GibbonsSandhya Viswanathan
authored andcommitted
8321599: Data loss in AVX3 Base64 decoding
Reviewed-by: sviswanathan, kvn
1 parent 028ec7e commit 13c1148

File tree

2 files changed

+124
-3
lines changed

2 files changed

+124
-3
lines changed

src/hotspot/cpu/x86/stubGenerator_x86_64.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2003, 2023, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2003, 2024, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* This code is free software; you can redistribute it and/or modify it
@@ -2318,7 +2318,7 @@ address StubGenerator::generate_base64_decodeBlock() {
23182318
const Register isURL = c_rarg5;// Base64 or URL character set
23192319
__ movl(isMIME, Address(rbp, 2 * wordSize));
23202320
#else
2321-
const Address dp_mem(rbp, 6 * wordSize); // length is on stack on Win64
2321+
const Address dp_mem(rbp, 6 * wordSize); // length is on stack on Win64
23222322
const Address isURL_mem(rbp, 7 * wordSize);
23232323
const Register isURL = r10; // pick the volatile windows register
23242324
const Register dp = r12;
@@ -2540,10 +2540,12 @@ address StubGenerator::generate_base64_decodeBlock() {
25402540
// output_size in r13
25412541

25422542
// Strip pad characters, if any, and adjust length and mask
2543+
__ addq(length, start_offset);
25432544
__ cmpb(Address(source, length, Address::times_1, -1), '=');
25442545
__ jcc(Assembler::equal, L_padding);
25452546

25462547
__ BIND(L_donePadding);
2548+
__ subq(length, start_offset);
25472549

25482550
// Output size is (64 - output_size), output mask is (all 1s >> output_size).
25492551
__ kmovql(input_mask, rax);

test/hotspot/jtreg/compiler/intrinsics/base64/TestBase64.java

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2018, 2022, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2018, 2024, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* This code is free software; you can redistribute it and/or modify it
@@ -46,10 +46,13 @@
4646
import java.util.Base64;
4747
import java.util.Base64.Decoder;
4848
import java.util.Base64.Encoder;
49+
import java.util.HexFormat;
4950
import java.util.Objects;
5051
import java.util.Random;
5152
import java.util.Arrays;
5253

54+
import static java.lang.String.format;
55+
5356
import compiler.whitebox.CompilerWhiteBoxTest;
5457
import jdk.test.whitebox.code.Compiler;
5558
import jtreg.SkippedException;
@@ -69,6 +72,8 @@ public static void main(String[] args) throws Exception {
6972

7073
warmup();
7174

75+
length_checks();
76+
7277
test0(FileType.ASCII, Base64Type.BASIC, Base64.getEncoder(), Base64.getDecoder(),"plain.txt", "baseEncode.txt", iters);
7378
test0(FileType.ASCII, Base64Type.URLSAFE, Base64.getUrlEncoder(), Base64.getUrlDecoder(),"plain.txt", "urlEncode.txt", iters);
7479
test0(FileType.ASCII, Base64Type.MIME, Base64.getMimeEncoder(), Base64.getMimeDecoder(),"plain.txt", "mimeEncode.txt", iters);
@@ -302,4 +307,118 @@ private static final byte getBadBase64Char(Base64Type b64Type) {
302307
throw new InternalError("Internal test error: getBadBase64Char called with unknown Base64Type value");
303308
}
304309
}
310+
311+
static final int POSITIONS = 30_000;
312+
static final int BASE_LENGTH = 256;
313+
static final HexFormat HEX_FORMAT = HexFormat.of().withUpperCase().withDelimiter(" ");
314+
315+
static int[] plainOffsets = new int[POSITIONS + 1];
316+
static byte[] plainBytes;
317+
static int[] base64Offsets = new int[POSITIONS + 1];
318+
static byte[] base64Bytes;
319+
320+
static {
321+
// Set up ByteBuffer with characters to be encoded
322+
int plainLength = 0;
323+
for (int i = 0; i < plainOffsets.length; i++) {
324+
plainOffsets[i] = plainLength;
325+
int positionLength = (BASE_LENGTH + i) % 2048;
326+
plainLength += positionLength;
327+
}
328+
// Put one of each possible byte value into ByteBuffer
329+
plainBytes = new byte[plainLength];
330+
for (int i = 0; i < plainBytes.length; i++) {
331+
plainBytes[i] = (byte) i;
332+
}
333+
334+
// Grab various slices of the ByteBuffer and encode them
335+
ByteBuffer plainBuffer = ByteBuffer.wrap(plainBytes);
336+
int base64Length = 0;
337+
for (int i = 0; i < POSITIONS; i++) {
338+
base64Offsets[i] = base64Length;
339+
int offset = plainOffsets[i];
340+
int length = plainOffsets[i + 1] - offset;
341+
ByteBuffer plainSlice = plainBuffer.slice(offset, length);
342+
base64Length += Base64.getEncoder().encode(plainSlice).remaining();
343+
}
344+
345+
// Decode the slices created above and ensure lengths match
346+
base64Offsets[base64Offsets.length - 1] = base64Length;
347+
base64Bytes = new byte[base64Length];
348+
for (int i = 0; i < POSITIONS; i++) {
349+
int plainOffset = plainOffsets[i];
350+
ByteBuffer plainSlice = plainBuffer.slice(plainOffset, plainOffsets[i + 1] - plainOffset);
351+
ByteBuffer encodedBytes = Base64.getEncoder().encode(plainSlice);
352+
int base64Offset = base64Offsets[i];
353+
int expectedLength = base64Offsets[i + 1] - base64Offset;
354+
if (expectedLength != encodedBytes.remaining()) {
355+
throw new IllegalStateException(format("Unexpected length: %s <> %s", encodedBytes.remaining(), expectedLength));
356+
}
357+
encodedBytes.get(base64Bytes, base64Offset, expectedLength);
358+
}
359+
}
360+
361+
public static void length_checks() {
362+
decodeAndCheck();
363+
encodeDecode();
364+
System.out.println("Test complete, no invalid decodes detected");
365+
}
366+
367+
// Use ByteBuffer to cause decode() to use the base + offset form of decode
368+
// Checks for bug reported in JDK-8321599 where padding characters appear
369+
// within the beginning of the ByteBuffer *before* the offset. This caused
370+
// the decoded string length to be off by 1 or 2 bytes.
371+
static void decodeAndCheck() {
372+
for (int i = 0; i < POSITIONS; i++) {
373+
ByteBuffer encodedBytes = base64BytesAtPosition(i);
374+
ByteBuffer decodedBytes = Base64.getDecoder().decode(encodedBytes);
375+
376+
if (!decodedBytes.equals(plainBytesAtPosition(i))) {
377+
String base64String = base64StringAtPosition(i);
378+
String plainHexString = plainHexStringAtPosition(i);
379+
String decodedHexString = HEX_FORMAT.formatHex(decodedBytes.array(), decodedBytes.arrayOffset() + decodedBytes.position(), decodedBytes.arrayOffset() + decodedBytes.limit());
380+
throw new IllegalStateException(format("Mismatch for %s\n\nExpected:\n%s\n\nActual:\n%s", base64String, plainHexString, decodedHexString));
381+
}
382+
}
383+
}
384+
385+
// Encode strings of lengths 1-1K, decode, and ensure length and contents correct.
386+
// This checks that padding characters are properly handled by decode.
387+
static void encodeDecode() {
388+
String allAs = "A(=)".repeat(128);
389+
for (int i = 1; i <= 512; i++) {
390+
String encStr = Base64.getEncoder().encodeToString(allAs.substring(0, i).getBytes());
391+
String decStr = new String(Base64.getDecoder().decode(encStr));
392+
393+
if ((decStr.length() != allAs.substring(0, i).length()) ||
394+
(!Objects.equals(decStr, allAs.substring(0, i)))
395+
) {
396+
throw new IllegalStateException(format("Mismatch: Expected: %s\n Actual: %s\n", allAs.substring(0, i), decStr));
397+
}
398+
}
399+
}
400+
401+
static ByteBuffer plainBytesAtPosition(int position) {
402+
int offset = plainOffsets[position];
403+
int length = plainOffsets[position + 1] - offset;
404+
return ByteBuffer.wrap(plainBytes, offset, length);
405+
}
406+
407+
static String plainHexStringAtPosition(int position) {
408+
int offset = plainOffsets[position];
409+
int length = plainOffsets[position + 1] - offset;
410+
return HEX_FORMAT.formatHex(plainBytes, offset, offset + length);
411+
}
412+
413+
static String base64StringAtPosition(int position) {
414+
int offset = base64Offsets[position];
415+
int length = base64Offsets[position + 1] - offset;
416+
return new String(base64Bytes, offset, length, StandardCharsets.UTF_8);
417+
}
418+
419+
static ByteBuffer base64BytesAtPosition(int position) {
420+
int offset = base64Offsets[position];
421+
int length = base64Offsets[position + 1] - offset;
422+
return ByteBuffer.wrap(base64Bytes, offset, length);
423+
}
305424
}

0 commit comments

Comments
 (0)