Skip to content

Commit 9a08ccb

Browse files
S-D-Rmaleadt
authored andcommitted
Constructors for ConstantStruct and multidimensional ConstantArrays
Nested structs are currently not supported via the `ConstantStruct(value, packed::Core.Bool=false)`, `ConstantStruct(value, ctx::Context=GlobalContext(), packed::Core.Bool=false)` and `ConstantStruct(value, typ::LLVMType)` constructors. After some testing and looking at the generated LLVM IR of some C code using nested structs, it appears like nested structs need to be named. This is an issue because we can't know the LLVM name of a nested struct type via a Julia value. Of course we could define our own struct types for nested structs, but this leads to other issues like making sure the name does not clash, and duplicated definitions of the same struct (e.g. when calling `ConstStruct` multiple times with the same struct value). Nested structs can still be created via the "lower-level" constructors if necessary. As for multidimensional arrays: it doesn't seem to be possible to read from them via `LLVMGetElementAsConstant`, so this is not supported.
1 parent f77a992 commit 9a08ccb

File tree

2 files changed

+115
-6
lines changed

2 files changed

+115
-6
lines changed

src/core/value/constant.jl

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,22 @@ end
123123
identify(::Type{Value}, ::Val{API.LLVMConstantArrayValueKind}) = ConstantArray
124124
identify(::Type{Value}, ::Val{API.LLVMConstantDataArrayValueKind}) = ConstantArray
125125

126-
ConstantArray(typ::LLVMType, data::Vector{T}) where {T<:Constant} =
127-
ConstantArray(API.LLVMConstArray(typ, data, length(data)))
128-
ConstantArray(typ::IntegerType, data::Vector{T}) where {T<:Integer} =
129-
ConstantArray(typ, map(x->ConstantInt(convert(T,x),context(typ)), data))
130-
ConstantArray(typ::FloatingPointType, data::Vector{T}) where {T<:AbstractFloat} =
131-
ConstantArray(typ, map(x->ConstantFP(convert(T,x),context(typ)), data))
126+
function ConstantArray(typ::LLVMType, data::AbstractArray{T,N}) where {T<:Constant,N}
127+
if N == 1
128+
return ConstantArray(API.LLVMConstArray(typ, Array(data), length(data)))
129+
end
130+
131+
ca_vec = map(x->ConstantArray(typ, x), eachslice(data, dims=1))
132+
ca_typ = llvmtype(first(ca_vec))
133+
134+
return ConstantArray(API.LLVMConstArray(ca_typ, ca_vec, length(ca_vec)))
135+
end
136+
ConstantArray(typ::IntegerType, data::AbstractArray{T,N}) where {T<:Integer,N} =
137+
ConstantArray(typ, map(x->ConstantInt(typ, x), data))
138+
ConstantArray(typ::FloatingPointType, data::AbstractArray{T,N}) where {T<:AbstractFloat,N} =
139+
ConstantArray(typ, map(x->ConstantFP(typ, x), data))
132140

141+
# NOTE: getindex is not supported for multidimensionsal constant arrays
133142
Base.getindex(ca::ConstantArray, idx::Integer) =
134143
API.LLVMGetElementAsConstant(ca, idx-1)
135144
Base.length(ca::ConstantArray) = length(llvmtype(ca))
@@ -144,6 +153,52 @@ Base.convert(::Type{Array{T,1}}, ca::ConstantArray) where {T<:AbstractFloat} =
144153
end
145154
identify(::Type{Value}, ::Val{API.LLVMConstantStructValueKind}) = ConstantStruct
146155

156+
ConstantStruct(constant_vals::Vector{T}, packed::Core.Bool=false) where {T<:Constant} =
157+
ConstantStruct(API.LLVMConstStruct(constant_vals, length(constant_vals), convert(Bool, packed)))
158+
ConstantStruct(constant_vals::Vector{T}, ctx::Context, packed::Core.Bool=false) where {T<:Constant} =
159+
ConstantStruct(API.LLVMConstStructInContext(ctx, constant_vals, length(constant_vals), convert(Bool, packed)))
160+
ConstantStruct(constant_vals::Vector{T}, typ::LLVMType) where {T<:Constant} =
161+
ConstantStruct(API.LLVMConstNamedStruct(typ, constant_vals, length(constant_vals)))
162+
163+
function struct_to_constants(value, ctx::Context)
164+
constants = Vector{Constant}()
165+
166+
for fieldname in fieldnames(typeof(value))
167+
field = getfield(value, fieldname)
168+
169+
if isa(field, Core.Bool)
170+
typ = LLVM.Int1Type(ctx)
171+
push!(constants, ConstantInt(typ, Int(field)))
172+
elseif isa(field, Integer)
173+
push!(constants, ConstantInt(field, ctx))
174+
elseif isa(field, AbstractFloat)
175+
push!(constants, ConstantFP(field, ctx))
176+
else # TODO: nested structs?
177+
throw(ArgumentError("only structs with boolean, integer and floating point fields are allowed"))
178+
end
179+
end
180+
181+
return constants
182+
end
183+
184+
function ConstantStruct(value, packed::Core.Bool=false)
185+
isbits(value) || throw(ArgumentError("`value` must be isbits"))
186+
constants = struct_to_constants(value, GlobalContext())
187+
return ConstantStruct(constants, packed)
188+
end
189+
190+
function ConstantStruct(value, ctx::Context, packed::Core.Bool=false)
191+
isbits(value) || throw(ArgumentError("`value` must be isbits"))
192+
constants = struct_to_constants(value, ctx)
193+
return ConstantStruct(constants, ctx, packed)
194+
end
195+
196+
function ConstantStruct(value, typ::LLVMType)
197+
isbits(value) || throw(ArgumentError("`value` must be isbits"))
198+
constants = struct_to_constants(value, context(typ))
199+
return ConstantStruct(constants, typ)
200+
end
201+
147202
@checked struct ConstantVector <: ConstantAggregate
148203
ref::API.LLVMValueRef
149204
end

test/core.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
struct TestStruct
2+
x::Bool
3+
y::Int64
4+
z::Float16
5+
end
6+
17
@testset "core" begin
28

39
@testset "context" begin
@@ -393,6 +399,54 @@ Context() do ctx
393399
@test convert(Float32, ConstantFP(ca[1]))::Float32 == 1.1f0
394400
@test convert(Vector{Float32}, ca) == vec
395401
end
402+
let
403+
typ = LLVM.Int64Type(ctx)
404+
vec = fill(5, 3, 4, 5, 6)
405+
ca = ConstantArray(typ, vec)
406+
@test length(ca) == size(vec, 1)
407+
@test llvmtype(ca) == LLVM.ArrayType(LLVM.ArrayType(LLVM.ArrayType(LLVM.ArrayType(LLVM.Int64Type(ctx), 6), 5), 4), 3)
408+
# NOTE: can't test content of the array because the API does not support reading from multidimensional arrays
409+
end
410+
411+
end
412+
413+
@testset "struct constants" begin
414+
415+
let
416+
test_struct = TestStruct(true, -99, 1.5)
417+
constant_struct = ConstantStruct(test_struct, ctx)
418+
constant_struct_type = llvmtype(constant_struct)
419+
420+
@test typeof(constant_struct_type) == LLVM.StructType
421+
@test context(constant_struct_type) == ctx
422+
@test !ispacked(constant_struct_type)
423+
@test !isopaque(constant_struct_type)
424+
425+
@test collect(elements(constant_struct_type)) == [LLVM.Int1Type(ctx), LLVM.Int64Type(ctx), LLVM.HalfType(ctx)]
426+
427+
expected_operands = [
428+
ConstantInt(LLVM.Int1Type(ctx), Int(true)),
429+
ConstantInt(LLVM.Int64Type(ctx), -99),
430+
ConstantFP(LLVM.HalfType(ctx), 1.5)
431+
]
432+
@test collect(operands(constant_struct)) == expected_operands
433+
end
434+
let
435+
named_struct_type = LLVM.StructType("named_struct", ctx)
436+
elements!(named_struct_type, [LLVM.Int1Type(ctx), LLVM.Int64Type(ctx), LLVM.HalfType(ctx)], true)
437+
438+
test_struct = TestStruct(false, 52, -2.5)
439+
constant_struct = ConstantStruct(test_struct, named_struct_type)
440+
441+
@test llvmtype(constant_struct) == named_struct_type
442+
443+
expected_operands = [
444+
ConstantInt(LLVM.Int1Type(ctx), Int(false)),
445+
ConstantInt(LLVM.Int64Type(ctx), 52),
446+
ConstantFP(LLVM.HalfType(ctx), -2.5)
447+
]
448+
@test collect(operands(constant_struct)) == expected_operands
449+
end
396450

397451
end
398452
end

0 commit comments

Comments
 (0)