Skip to content

Commit e567d20

Browse files
Clean up and commonize the Arm vector types
1 parent db51ef1 commit e567d20

File tree

1 file changed

+43
-90
lines changed

1 file changed

+43
-90
lines changed

src/arm/aesni_common.jl

Lines changed: 43 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -14,31 +14,52 @@ end
1414

1515
# True when the host is little-endian: `ENDIAN_BOM` holds 0x04030201 on little-endian
# machines (and 0x01020304 on big-endian ones).
const LITTLE_ENDIAN::Bool = ENDIAN_BOM == 0x04030201
1616

17-
const uint64x2_lvec = NTuple{2, VecElement{UInt64}}
18-
struct uint64x2
19-
data::uint64x2_lvec
20-
end
21-
@inline Base.convert(::Type{uint64x2}, x::UInt128) = unsafe_load(Ptr{uint64x2}(pointer_from_objref(Ref(x))))
22-
@inline Base.convert(::Type{UInt128}, x::uint64x2) = unsafe_load(Ptr{UInt128}(pointer_from_objref(Ref(x))))
23-
@inline UInt128(x::uint64x2) = convert(UInt128, x)
24-
@inline uint64x2(x::UInt128) = convert(uint64x2, x)
25-
@inline Base.convert(::Type{uint64x2}, x::Union{Signed, Unsigned}) = convert(uint64x2, UInt128(x))
26-
@inline Base.convert(::Type{T}, x::uint64x2) where T <: Union{Signed, Unsigned} = convert(T, UInt128(x))
27-
28-
@inline uint64x2(hi::UInt64, lo::UInt64) = @static if LITTLE_ENDIAN
29-
uint64x2((VecElement(lo), VecElement(hi)))
30-
else
31-
uint64x2((VecElement(hi), VecElement(lo)))
17+
"""
    ArmVec128

Abstract supertype for the Arm vector wrapper structs. The `convert` methods below
reinterpret the raw bits of a value, so every subtype is assumed to be exactly as wide
as a `UInt128`.
"""
abstract type ArmVec128 end

# Reinterpret the bits of one vector type as another vector type or as a UInt128.
@inline function Base.convert(
    ::Type{T}, x::ArmVec128
) where {T<:Union{ArmVec128, UInt128}}
    box = Ref(x)
    # `pointer_from_objref` does not protect `box` from garbage collection, so keep it
    # rooted for the duration of the load.
    return GC.@preserve box unsafe_load(Ptr{T}(pointer_from_objref(box)))
end

# Reinterpret the bits of a UInt128 as a vector type.
@inline function Base.convert(::Type{T}, x::UInt128) where {T<:ArmVec128}
    box = Ref(x)
    # Same rooting requirement as above.
    return GC.@preserve box unsafe_load(Ptr{T}(pointer_from_objref(box)))
end
22+
# Constructor-style conversions: a vector to UInt128, and any ArmVec128 subtype from
# another vector or from a UInt128. Both simply defer to `convert`.
@inline function UInt128(x::ArmVec128)
    return convert(UInt128, x)
end
@inline function (::Type{T})(x::Union{ArmVec128, UInt128}) where {T<:ArmVec128}
    return convert(T, x)
end

# Conversions between vectors and the other integer types are routed through UInt128.
@inline function Base.convert(
    ::Type{T}, x::Union{Signed, Unsigned}
) where {T<:ArmVec128}
    return convert(T, UInt128(x))
end
@inline function Base.convert(
    ::Type{T}, x::ArmVec128
) where {T<:Union{Signed, Unsigned}}
    return convert(T, UInt128(x))
end
28+
29+
# (lane count => lane width in bits) for each vector layout generated below; every
# pair multiplies out to 128 bits.
const VEC_FLAVORS = [16 => 8, 8 => 16, 4 => 32, 2 => 64]

# For each flavor, generate the wrapper struct `uint<bits>x<lanes>`, its raw tuple alias
# `uint<bits>x<lanes>_lvec`, its constructors, `zero`, and a lane-wise `xor`.
for (lanes, width) in VEC_FLAVORS
    stem = "uint$(width)x$(lanes)"
    wrapper = Symbol(stem)
    raw_tuple = Symbol(stem, "_lvec")
    lane_ty = Symbol("UInt", width)
    xor_ir =
        """%3 = xor <$(lanes) x i$(width)> %1, %0
        ret <$(lanes) x i$(width)> %3"""
    @eval begin
        const $raw_tuple = NTuple{$lanes, VecElement{$lane_ty}}
        struct $wrapper <: ArmVec128
            data::$raw_tuple
        end

        # Build the vector from a UInt128 or from another vector via `convert`.
        @inline function $wrapper(x::Union{UInt128, ArmVec128})
            return convert($wrapper, x)
        end

        # Build the vector from its individual elements. On little-endian hosts the
        # arguments are stored in reverse order.
        @inline function $wrapper(bytes::Vararg{$lane_ty, $lanes})
            ordered = $LITTLE_ENDIAN ? reverse(bytes) : bytes
            packed::$raw_tuple = map(VecElement, ordered)
            return $wrapper(packed)
        end

        @inline function Base.zero(::Type{$wrapper})
            return convert($wrapper, zero(UInt128))
        end

        # Lane-wise xor, emitted directly as LLVM IR on the raw tuples.
        @inline function Base.xor(a::$wrapper, b::$wrapper)
            raw = llvmcall(
                $xor_ir,
                $raw_tuple, Tuple{$raw_tuple, $raw_tuple},
                a.data, b.data,
            )
            return $wrapper(raw)
        end
    end
end
3361

34-
@inline Base.zero(::Type{uint64x2}) = convert(uint64x2, zero(UInt128))
3562
# `one` for uint64x2: constructed from the UInt64 arguments (0, 1).
@inline function Base.one(::Type{uint64x2})
    return uint64x2(zero(UInt64), one(UInt64))
end
36-
@inline Base.xor(a::uint64x2, b::uint64x2) = llvmcall(
37-
"""%3 = xor <2 x i64> %1, %0
38-
ret <2 x i64> %3""",
39-
uint64x2_lvec, Tuple{uint64x2_lvec, uint64x2_lvec},
40-
a.data, b.data,
41-
) |> uint64x2
4263
@inline (+)(a::uint64x2, b::uint64x2) = llvmcall(
4364
"""%3 = add <2 x i64> %1, %0
4465
ret <2 x i64> %3""",
@@ -47,74 +68,6 @@ end
4768
) |> uint64x2
4869
# Add an integer to a vector: `b` is converted to UInt128, turned into a uint64x2, and
# then added with the vector-vector `+` method.
@inline function (+)(a::uint64x2, b::Integer)
    addend = uint64x2(UInt128(b))
    return a + addend
end
4970

50-
const uint8x16_lvec = NTuple{16, VecElement{UInt8}}
51-
struct uint8x16
52-
data::uint8x16_lvec
53-
end
54-
@inline Base.convert(::Type{uint64x2}, x::uint8x16) = unsafe_load(Ptr{uint64x2}(pointer_from_objref(Ref(x))))
55-
@inline Base.convert(::Type{uint8x16}, x::uint64x2) = unsafe_load(Ptr{uint8x16}(pointer_from_objref(Ref(x))))
56-
@inline uint8x16(x::uint64x2) = convert(uint8x16, x)
57-
@inline uint64x2(x::uint8x16) = convert(uint64x2, x)
58-
@inline Base.convert(::Type{uint8x16}, x::UInt128) = unsafe_load(Ptr{uint8x16}(pointer_from_objref(Ref(x))))
59-
@inline Base.convert(::Type{UInt128}, x::uint8x16) = unsafe_load(Ptr{UInt128}(pointer_from_objref(Ref(x))))
60-
@inline UInt128(x::uint8x16) = convert(UInt128, x)
61-
@inline uint8x16(x::UInt128) = convert(uint8x16, x)
62-
@inline Base.convert(::Type{uint8x16}, x::Union{Signed, Unsigned}) = convert(uint8x16, UInt128(x))
63-
@inline Base.convert(::Type{T}, x::uint8x16) where T <: Union{Signed, Unsigned} = convert(T, UInt128(x))
64-
65-
@inline function uint8x16(bytes::Vararg{UInt8, 16})
66-
bytes_prepped = bytes
67-
@static if LITTLE_ENDIAN
68-
bytes_prepped = reverse(bytes_prepped)
69-
end
70-
bytes_vec::uint8x16_lvec = VecElement.(bytes_prepped)
71-
return uint8x16(bytes_vec)
72-
end
73-
74-
@inline Base.zero(::Type{uint8x16}) = convert(uint8x16, zero(UInt128))
75-
@inline Base.xor(a::uint8x16, b::uint8x16) = llvmcall(
76-
"""%3 = xor <16 x i8> %1, %0
77-
ret <16 x i8> %3""",
78-
uint8x16_lvec, Tuple{uint8x16_lvec, uint8x16_lvec},
79-
a.data, b.data,
80-
) |> uint8x16
81-
82-
const uint32x4_lvec = NTuple{4, VecElement{UInt32}}
83-
struct uint32x4
84-
data::uint32x4_lvec
85-
end
86-
@inline Base.convert(::Type{uint64x2}, x::uint32x4) = unsafe_load(Ptr{uint64x2}(pointer_from_objref(Ref(x))))
87-
@inline Base.convert(::Type{uint32x4}, x::uint64x2) = unsafe_load(Ptr{uint32x4}(pointer_from_objref(Ref(x))))
88-
@inline uint32x4(x::uint64x2) = convert(uint32x4, x)
89-
@inline uint64x2(x::uint32x4) = convert(uint64x2, x)
90-
@inline Base.convert(::Type{uint8x16}, x::uint32x4) = unsafe_load(Ptr{uint8x16}(pointer_from_objref(Ref(x))))
91-
@inline Base.convert(::Type{uint32x4}, x::uint8x16) = unsafe_load(Ptr{uint32x4}(pointer_from_objref(Ref(x))))
92-
@inline uint32x4(x::uint8x16) = convert(uint32x4, x)
93-
@inline uint8x16(x::uint32x4) = convert(uint8x16, x)
94-
@inline Base.convert(::Type{uint32x4}, x::UInt128) = unsafe_load(Ptr{uint32x4}(pointer_from_objref(Ref(x))))
95-
@inline Base.convert(::Type{UInt128}, x::uint32x4) = unsafe_load(Ptr{UInt128}(pointer_from_objref(Ref(x))))
96-
@inline UInt128(x::uint32x4) = convert(UInt128, x)
97-
@inline uint32x4(x::UInt128) = convert(uint32x4, x)
98-
@inline Base.convert(::Type{uint32x4}, x::Union{Signed, Unsigned}) = convert(uint32x4, UInt128(x))
99-
@inline Base.convert(::Type{T}, x::uint32x4) where T <: Union{Signed, Unsigned} = convert(T, UInt128(x))
100-
101-
@inline function uint32x4(bytes::Vararg{UInt32, 4})
102-
bytes_prepped = bytes
103-
@static if LITTLE_ENDIAN
104-
bytes_prepped = reverse(bytes_prepped)
105-
end
106-
bytes_vec::uint32x4_lvec = VecElement.(bytes_prepped)
107-
return uint32x4(bytes_vec)
108-
end
109-
110-
@inline Base.zero(::Type{uint32x4}) = convert(uint32x4, zero(UInt128))
111-
@inline Base.xor(a::uint32x4, b::uint32x4) = llvmcall(
112-
"""%3 = xor <4 x i32> %1, %0
113-
ret <4 x i32> %3""",
114-
uint32x4_lvec, Tuple{uint32x4_lvec, uint32x4_lvec},
115-
a.data, b.data,
116-
) |> uint32x4
117-
11871
# Raw NEON intrinsics, provided by FEAT_AES
11972
# LLVM intrinsic name for AESE, parameterized by the architecture string.
const ARM_AESE_LLVM_INTRINSIC = string("llvm.", LLVM_ARCH_STRING, ".crypto.aese")
12073
@inline _vaese(a::uint8x16, b::uint8x16) = ccall(

0 commit comments

Comments
 (0)