Skip to content

Commit 0f5a2b6

Browse files
committed
Audit Graal code for JDK-8351034: Add AVX-512 intrinsics for ML-DSA
1 parent 87eb39b commit 0f5a2b6

File tree

3 files changed

+18
-31
lines changed

3 files changed

+18
-31
lines changed

compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/hotspot/replacements/HotSpotHashCodeSnippets.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@
4545
import jdk.graal.compiler.word.Word;
4646

4747
// @formatter:off
48-
@SyncPort(from = "https://github.com/openjdk/jdk/blob/250eb743c112fbcc45bf2b3ded1c644b19893577/src/hotspot/share/opto/library_call.cpp#L4667-L4801",
48+
@SyncPort(from = "https://github.com/openjdk/jdk/blob/c447a10225576bc59e1ba9477417367d2ac28511/src/hotspot/share/opto/library_call.cpp#L4662-L4796",
4949
sha1 = "c212d1dbff26d02d4d749e085263d4104895f1ba")
5050
// @formatter:on
5151
public class HotSpotHashCodeSnippets extends IdentityHashCodeSnippets {

compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/lir/amd64/AMD64AESEncryptOp.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -108,7 +108,7 @@ static void loadKey(AMD64MacroAssembler masm, Register xmmDst, Register key, int
108108
}
109109

110110
static Register asXMMRegister(int index) {
111-
return AMD64.xmmRegistersSSE[index];
111+
return AMD64.xmmRegistersAVX512[index];
112112
}
113113

114114
@Override

compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/lir/amd64/AMD64SHA3Op.java

Lines changed: 16 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626

2727
import static jdk.graal.compiler.asm.amd64.AMD64Assembler.ConditionFlag.LessEqual;
2828
import static jdk.graal.compiler.asm.amd64.AMD64Assembler.ConditionFlag.NotEqual;
29+
import static jdk.graal.compiler.lir.amd64.AMD64AESEncryptOp.asXMMRegister;
2930
import static jdk.graal.compiler.lir.amd64.AMD64LIRHelper.pointerConstant;
3031
import static jdk.graal.compiler.lir.amd64.AMD64LIRHelper.recordExternalAddress;
3132
import static jdk.vm.ci.amd64.AMD64.k1;
@@ -78,8 +79,8 @@
7879
import jdk.vm.ci.meta.Value;
7980

8081
// @formatter:off
81-
@SyncPort(from = "https://github.com/openjdk/jdk/blob/a937f6db30ab55b98dae25d5b6d041cf4b7b7291/src/hotspot/cpu/x86/stubGenerator_x86_64_sha3.cpp#L41-L337",
82-
sha1 = "d9d050bb8e4213f750eae298d436ace9a086b233")
82+
@SyncPort(from = "https://github.com/openjdk/jdk/blob/c447a10225576bc59e1ba9477417367d2ac28511/src/hotspot/cpu/x86/stubGenerator_x86_64_sha3.cpp#L43-L320",
83+
sha1 = "85dbee8cb0c0f6d8f37d07da6cf8b2f9f4fc8ce8")
8384
// @formatter:on
8485
public final class AMD64SHA3Op extends AMD64LIRInstruction {
8586

@@ -231,29 +232,16 @@ public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
231232
masm.kshiftrw(k1, k5, 4);
232233

233234
// load the state
234-
masm.evmovdqu64(xmm0, k5, new AMD64Address(state, 0));
235-
masm.evmovdqu64(xmm1, k5, new AMD64Address(state, 40));
236-
masm.evmovdqu64(xmm2, k5, new AMD64Address(state, 80));
237-
masm.evmovdqu64(xmm3, k5, new AMD64Address(state, 120));
238-
masm.evmovdqu64(xmm4, k5, new AMD64Address(state, 160));
235+
for (int i = 0; i < 5; i++) {
236+
masm.evmovdqu64(asXMMRegister(i), k5, new AMD64Address(state, i * 40));
237+
}
239238

240239
// load the permutation and rotation constants
241-
masm.evmovdqu64(xmm17, new AMD64Address(permsAndRots, 0));
242-
masm.evmovdqu64(xmm18, new AMD64Address(permsAndRots, 64));
243-
masm.evmovdqu64(xmm19, new AMD64Address(permsAndRots, 128));
244-
masm.evmovdqu64(xmm20, new AMD64Address(permsAndRots, 192));
245-
masm.evmovdqu64(xmm21, new AMD64Address(permsAndRots, 256));
246-
masm.evmovdqu64(xmm22, new AMD64Address(permsAndRots, 320));
247-
masm.evmovdqu64(xmm23, new AMD64Address(permsAndRots, 384));
248-
masm.evmovdqu64(xmm24, new AMD64Address(permsAndRots, 448));
249-
masm.evmovdqu64(xmm25, new AMD64Address(permsAndRots, 512));
250-
masm.evmovdqu64(xmm26, new AMD64Address(permsAndRots, 576));
251-
masm.evmovdqu64(xmm27, new AMD64Address(permsAndRots, 640));
252-
masm.evmovdqu64(xmm28, new AMD64Address(permsAndRots, 704));
253-
masm.evmovdqu64(xmm29, new AMD64Address(permsAndRots, 768));
254-
masm.evmovdqu64(xmm30, new AMD64Address(permsAndRots, 832));
255-
masm.evmovdqu64(xmm31, new AMD64Address(permsAndRots, 896));
240+
for (int i = 0; i < 15; i++) {
241+
masm.evmovdqu64(asXMMRegister(i + 17), new AMD64Address(permsAndRots, i * 64));
242+
}
256243

244+
masm.align(preferredLoopAlignment(crb));
257245
masm.bind(sha3Loop);
258246

259247
// there will be 24 keccak rounds
@@ -304,6 +292,7 @@ public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
304292
// The implementation closely follows the Java version, with the state
305293
// array "rows" in the lowest 5 64-bit slots of zmm0 - zmm4, i.e.
306294
// each row of the SHA3 specification is located in one zmm register.
295+
masm.align(preferredLoopAlignment(crb));
307296
masm.bind(rounds24Loop);
308297
masm.subl(roundsLeft, 1);
309298

@@ -330,7 +319,7 @@ public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
330319

331320
// Do the cyclical permutation of the 24 moving state elements
332321
// and the required rotations within each element (the combined
333-
// rho and sigma steps).
322+
// rho and pi steps).
334323
masm.evpermt2q(xmm4, xmm17, xmm3);
335324
masm.evpermt2q(xmm3, xmm18, xmm2);
336325
masm.evpermt2q(xmm2, xmm17, xmm1);
@@ -352,7 +341,7 @@ public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
352341
masm.evpermt2q(xmm2, xmm24, xmm4);
353342
masm.evpermt2q(xmm3, xmm25, xmm4);
354343
masm.evpermt2q(xmm4, xmm26, xmm5);
355-
// The combined rho and sigma steps are done.
344+
// The combined rho and pi steps are done.
356345

357346
// Do the chi step (the same operation on all 5 rows).
358347
// vpternlogq(x, 180, y, z) does x = x ^ (y & ~z).
@@ -394,11 +383,9 @@ public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
394383
}
395384

396385
// store the state
397-
masm.evmovdqu64(new AMD64Address(state, 0), k5, xmm0);
398-
masm.evmovdqu64(new AMD64Address(state, 40), k5, xmm1);
399-
masm.evmovdqu64(new AMD64Address(state, 80), k5, xmm2);
400-
masm.evmovdqu64(new AMD64Address(state, 120), k5, xmm3);
401-
masm.evmovdqu64(new AMD64Address(state, 160), k5, xmm4);
386+
for (int i = 0; i < 5; i++) {
387+
masm.evmovdqu64(new AMD64Address(state, i * 40), k5, asXMMRegister(i));
388+
}
402389

403390
masm.pop(r14);
404391
masm.pop(r13);

0 commit comments

Comments
 (0)