Skip to content

Commit 69c214b

Browse files
committed
Rework ConstantArray.
Add multidimensional indexing, add some convenience constructors.
1 parent 8b2e04a commit 69c214b

File tree

1 file changed

+78
-25
lines changed

1 file changed

+78
-25
lines changed

src/core/value/constant.jl

Lines changed: 78 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ Base.convert(::Type{T}, val::ConstantFP) where {T<:AbstractFloat} =
9393
convert(T, API.LLVMConstRealGetDouble(val, Ref{API.LLVMBool}()))
9494

9595

96-
## aggregate
96+
## aggregate zero
9797

9898
export ConstantAggregateZero
9999

@@ -106,51 +106,89 @@ identify(::Type{Value}, ::Val{API.LLVMConstantAggregateZeroValueKind}) = Constan
106106
# ConstantAggregateZero value directly, but values can occur through calls to LLVMConstNull
107107

108108

109-
## constant expressions
110-
111-
export ConstantExpr, ConstantAggregate, ConstantArray, ConstantStruct, ConstantVector, InlineAsm
112-
113-
@checked struct ConstantExpr <: Constant
114-
ref::API.LLVMValueRef
115-
end
116-
identify(::Type{Value}, ::Val{API.LLVMConstantExprValueKind}) = ConstantExpr
109+
## regular aggregate
117110

118111
abstract type ConstantAggregate <: Constant end
119112

113+
# arrays
114+
120115
@checked struct ConstantArray <: ConstantAggregate
121116
ref::API.LLVMValueRef
122117
end
123118
identify(::Type{Value}, ::Val{API.LLVMConstantArrayValueKind}) = ConstantArray
124119
identify(::Type{Value}, ::Val{API.LLVMConstantDataArrayValueKind}) = ConstantArray
125120

126-
function ConstantArray(typ::LLVMType, data::AbstractArray{T,N}) where {T<:Constant,N}
121+
# generic constructor taking an array of constants
122+
function ConstantArray(data::AbstractArray{<:Constant,N},
123+
typ::LLVMType=llvmtype(first(data))) where {N}
124+
@assert all(x->x==typ, llvmtype.(data))
125+
127126
if N == 1
128127
return ConstantArray(API.LLVMConstArray(typ, Array(data), length(data)))
129128
end
130129

131130
if VERSION >= v"1.1"
132-
ca_vec = map(x->ConstantArray(typ, x), eachslice(data, dims=1))
131+
ca_vec = map(x->ConstantArray(x, typ), eachslice(data, dims=1))
133132
else
134-
ca_vec = map(x->ConstantArray(typ, x), (view(data, i, ntuple(d->(:), N-1)...) for i in axes(data, 1)))
133+
ca_vec = map(x->ConstantArray(x, typ), (view(data, i, ntuple(d->(:), N-1)...) for i in axes(data, 1)))
135134
end
136135
ca_typ = llvmtype(first(ca_vec))
137136

138137
return ConstantArray(API.LLVMConstArray(ca_typ, ca_vec, length(ca_vec)))
139138
end
140-
ConstantArray(typ::IntegerType, data::AbstractArray{T,N}) where {T<:Integer,N} =
141-
ConstantArray(typ, map(x->ConstantInt(typ, x), data))
142-
ConstantArray(typ::FloatingPointType, data::AbstractArray{T,N}) where {T<:AbstractFloat,N} =
143-
ConstantArray(typ, map(x->ConstantFP(typ, x), data))
144-
145-
# NOTE: getindex is not supported for multidimensionsal constant arrays
146-
Base.getindex(ca::ConstantArray, idx::Integer) =
147-
API.LLVMGetElementAsConstant(ca, idx-1)
148-
Base.length(ca::ConstantArray) = length(llvmtype(ca))
139+
140+
# shorthands with arrays of plain Julia data
141+
ConstantArray(data::AbstractArray{T,N}, ctx::Context=GlobalContext()) where {T<:Integer,N} =
142+
ConstantArray(ConstantInt.(data, Ref(ctx)), IntType(sizeof(T)*8, ctx))
143+
ConstantArray(data::AbstractArray{Float16,N}, ctx::Context=GlobalContext()) where {N} =
144+
ConstantArray(ConstantFP.(data, Ref(ctx)), HalfType(ctx))
145+
ConstantArray(data::AbstractArray{Float32,N}, ctx::Context=GlobalContext()) where {N} =
146+
ConstantArray(ConstantFP.(data, Ref(ctx)), FloatType(ctx))
147+
ConstantArray(data::AbstractArray{Float64,N}, ctx::Context=GlobalContext()) where {N} =
148+
ConstantArray(ConstantFP.(data, Ref(ctx)), DoubleType(ctx))
149+
150+
# convert back to known array types
151+
function Base.collect(ca::ConstantArray)
152+
constants = Array{Value}(undef, size(ca))
153+
for I in CartesianIndices(size(ca))
154+
@inbounds constants[I] = ca[Tuple(I)...]
155+
end
156+
return constants
157+
end
158+
159+
160+
# array interface
149161
Base.eltype(ca::ConstantArray) = eltype(llvmtype(ca))
150-
Base.convert(::Type{Array{T,1}}, ca::ConstantArray) where {T<:Integer} =
151-
[convert(T,ConstantInt(ca[i])) for i in 1:length(ca)]
152-
Base.convert(::Type{Array{T,1}}, ca::ConstantArray) where {T<:AbstractFloat} =
153-
[convert(T,ConstantFP(ca[i])) for i in 1:length(ca)]
162+
function Base.size(ca::ConstantArray)
163+
dims = Int[]
164+
typ = llvmtype(ca)
165+
while typ isa ArrayType
166+
push!(dims, length(typ))
167+
typ = eltype(typ)
168+
end
169+
return Tuple(dims)
170+
end
171+
Base.length(ca::ConstantArray) = prod(size(ca))
172+
Base.axes(ca::ConstantArray) = Base.OneTo.(size(ca))
173+
174+
function Base.getindex(ca::ConstantArray, idx::Integer...)
175+
# multidimensional arrays are represented by arrays of arrays,
176+
# which we need to 'peel back' by looking at the operand sets.
177+
# for the final dimension, we use LLVMGetElementAsConstant
178+
@boundscheck Base.checkbounds_indices(Base.Bool, axes(ca), idx) ||
179+
throw(BoundsError(ca, idx))
180+
I = CartesianIndices(size(ca))[idx...]
181+
for i in Tuple(I)
182+
if isempty(operands(ca))
183+
ca = LLVM.Value(API.LLVMGetElementAsConstant(ca, i-1))
184+
else
185+
ca = (Base.@_propagate_inbounds_meta; operands(ca)[i])
186+
end
187+
end
188+
return ca
189+
end
190+
191+
# structs
154192

155193
@checked struct ConstantStruct <: ConstantAggregate
156194
ref::API.LLVMValueRef
@@ -203,11 +241,26 @@ function ConstantStruct(value, typ::LLVMType)
203241
return ConstantStruct(constants, typ)
204242
end
205243

244+
# vectors
245+
206246
@checked struct ConstantVector <: ConstantAggregate
207247
ref::API.LLVMValueRef
208248
end
209249
identify(::Type{Value}, ::Val{API.LLVMConstantVectorValueKind}) = ConstantVector
210250

251+
252+
## constant expressions
253+
254+
export ConstantExpr, ConstantAggregate, ConstantArray, ConstantStruct, ConstantVector, InlineAsm
255+
256+
@checked struct ConstantExpr <: Constant
257+
ref::API.LLVMValueRef
258+
end
259+
identify(::Type{Value}, ::Val{API.LLVMConstantExprValueKind}) = ConstantExpr
260+
261+
262+
## inline assembly
263+
211264
@checked struct InlineAsm <: Constant
212265
ref::API.LLVMValueRef
213266
end

0 commit comments

Comments
 (0)