Skip to content

Commit 465c8e6

Browse files
ferakocz authored and adinn committed
8349721: Add aarch64 intrinsics for ML-KEM
Reviewed-by: adinn
1 parent 1ad869f commit 465c8e6

File tree

20 files changed

+2523
-149
lines changed

20 files changed

+2523
-149
lines changed

src/hotspot/cpu/aarch64/register_aarch64.cpp

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -58,23 +58,3 @@ const char* PRegister::PRegisterImpl::name() const {
5858
};
5959
return is_valid() ? names[encoding()] : "pnoreg";
6060
}
61-
62-
// convenience methods for splitting 8-way vector register sequences
63-
// in half -- needed because vector operations can normally only be
64-
// benefit from 4-way instruction parallelism
65-
66-
VSeq<4> vs_front(const VSeq<8>& v) {
67-
return VSeq<4>(v.base(), v.delta());
68-
}
69-
70-
VSeq<4> vs_back(const VSeq<8>& v) {
71-
return VSeq<4>(v.base() + 4 * v.delta(), v.delta());
72-
}
73-
74-
VSeq<4> vs_even(const VSeq<8>& v) {
75-
return VSeq<4>(v.base(), v.delta() * 2);
76-
}
77-
78-
VSeq<4> vs_odd(const VSeq<8>& v) {
79-
return VSeq<4>(v.base() + 1, v.delta() * 2);
80-
}

src/hotspot/cpu/aarch64/register_aarch64.hpp

Lines changed: 78 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -436,19 +436,20 @@ enum RC { rc_bad, rc_int, rc_float, rc_predicate, rc_stack };
436436
// inputs into front and back halves or odd and even halves (see
437437
// convenience methods below).
438438

439+
// helper macro for computing register masks: yields an int in which
// only the bit for register index (base + delta * i) is set. Every
// argument is parenthesized so that compound argument expressions
// (e.g. `i + 1`) expand with the intended precedence.
#define VS_MASK_BIT(base, delta, i) (1 << ((base) + (delta) * (i)))
441+
439442
template<int N> class VSeq {
440443
static_assert(N >= 2, "vector sequence length must be greater than 1");
441-
static_assert(N <= 8, "vector sequence length must not exceed 8");
442-
static_assert((N & (N - 1)) == 0, "vector sequence length must be power of two");
443444
private:
444445
int _base; // index of first register in sequence
445446
int _delta; // increment to derive successive indices
446447
public:
447448
VSeq(FloatRegister base_reg, int delta = 1) : VSeq(base_reg->encoding(), delta) { }
448449
VSeq(int base, int delta = 1) : _base(base), _delta(delta) {
449-
assert (_base >= 0, "invalid base register");
450-
assert (_delta >= 0, "invalid register delta");
451-
assert ((_base + (N - 1) * _delta) < 32, "range exceeded");
450+
assert (_base >= 0 && _base <= 31, "invalid base register");
451+
assert ((_base + (N - 1) * _delta) >= 0, "register range underflow");
452+
assert ((_base + (N - 1) * _delta) < 32, "register range overflow");
452453
}
453454
// indexed access to sequence
454455
FloatRegister operator [](int i) const {
@@ -457,27 +458,89 @@ template<int N> class VSeq {
457458
}
458459
int mask() const {
459460
int m = 0;
460-
int bit = 1 << _base;
461461
for (int i = 0; i < N; i++) {
462-
m |= bit << (i * _delta);
462+
m |= VS_MASK_BIT(_base, _delta, i);
463463
}
464464
return m;
465465
}
466466
int base() const { return _base; }
467467
int delta() const { return _delta; }
468+
bool is_constant() const { return _delta == 0; }
468469
};
469470

470-
// declare convenience methods for splitting vector register sequences
471-
472-
VSeq<4> vs_front(const VSeq<8>& v);
473-
VSeq<4> vs_back(const VSeq<8>& v);
474-
VSeq<4> vs_even(const VSeq<8>& v);
475-
VSeq<4> vs_odd(const VSeq<8>& v);
476-
477-
// methods for use in asserts to check VSeq inputs and oupts are
471+
// methods for use in asserts to check VSeq inputs and outputs are
478472
// either disjoint or equal
479473

480474
// true when the two sequences have no register in common, i.e. their
// register masks share no set bit
template<int N, int M> bool vs_disjoint(const VSeq<N>& n, const VSeq<M>& m) {
  const int shared = n.mask() & m.mask();
  return shared == 0;
}
481475
// true when both sequences cover the same set of registers; only the
// register masks are compared, so element order is not considered
template<int N> bool vs_same(const VSeq<N>& n, const VSeq<N>& m) {
  const int lhs = n.mask();
  const int rhs = m.mask();
  return lhs == rhs;
}
482476

477+
// method for use in asserts to check whether registers appearing in
478+
// an output sequence will be written before they are read from an
479+
// input sequence.
480+
481+
template<int N> bool vs_write_before_read(const VSeq<N>& vout, const VSeq<N>& vin) {
482+
int b_in = vin.base();
483+
int d_in = vin.delta();
484+
int b_out = vout.base();
485+
int d_out = vout.delta();
486+
int bit_in = 1 << b_in;
487+
int bit_out = 1 << b_out;
488+
int mask_read = vin.mask(); // all pending reads
489+
int mask_write = 0; // no writes as yet
490+
491+
492+
for (int i = 0; i < N - 1; i++) {
493+
// check whether a pending read clashes with a write
494+
if ((mask_write & mask_read) != 0) {
495+
return true;
496+
}
497+
// remove the pending input (so long as this is a constant
498+
// sequence)
499+
if (d_in != 0) {
500+
mask_read ^= VS_MASK_BIT(b_in, d_in, i);
501+
}
502+
// record the next write
503+
mask_write |= VS_MASK_BIT(b_out, d_out, i);
504+
}
505+
// no write before read
506+
return false;
507+
}
508+
509+
// convenience methods for splitting 8-way or 4-way vector register
510+
// sequences in half -- needed because vector operations can normally
511+
// benefit from 4-way instruction parallelism or, occasionally, 2-way
512+
// parallelism
513+
514+
template<int N>
515+
VSeq<N/2> vs_front(const VSeq<N>& v) {
516+
static_assert(N > 0 && ((N & 1) == 0), "sequence length must be even");
517+
return VSeq<N/2>(v.base(), v.delta());
518+
}
519+
520+
template<int N>
521+
VSeq<N/2> vs_back(const VSeq<N>& v) {
522+
static_assert(N > 0 && ((N & 1) == 0), "sequence length must be even");
523+
return VSeq<N/2>(v.base() + N / 2 * v.delta(), v.delta());
524+
}
525+
526+
template<int N>
527+
VSeq<N/2> vs_even(const VSeq<N>& v) {
528+
static_assert(N > 0 && ((N & 1) == 0), "sequence length must be even");
529+
return VSeq<N/2>(v.base(), v.delta() * 2);
530+
}
531+
532+
template<int N>
533+
VSeq<N/2> vs_odd(const VSeq<N>& v) {
534+
static_assert(N > 0 && ((N & 1) == 0), "sequence length must be even");
535+
return VSeq<N/2>(v.base() + v.delta(), v.delta() * 2);
536+
}
537+
538+
// convenience method to construct a vector register sequence that
539+
// indexes its elements in reverse order to the original
540+
541+
template<int N>
542+
VSeq<N> vs_reverse(const VSeq<N>& v) {
543+
return VSeq<N>(v.base() + (N - 1) * v.delta(), -v.delta());
544+
}
545+
483546
#endif // CPU_AARCH64_REGISTER_AARCH64_HPP

src/hotspot/cpu/aarch64/stubDeclarations_aarch64.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
do_arch_blob, \
4545
do_arch_entry, \
4646
do_arch_entry_init) \
47-
do_arch_blob(compiler, 55000 ZGC_ONLY(+5000)) \
47+
do_arch_blob(compiler, 75000 ZGC_ONLY(+5000)) \
4848
do_stub(compiler, vector_iota_indices) \
4949
do_arch_entry(aarch64, compiler, vector_iota_indices, \
5050
vector_iota_indices, vector_iota_indices) \

0 commit comments

Comments
 (0)