|
3 | 3 |
|
"""
    isdense(::Type{<:DenseArray})

Trait query on an array *type*: every subtype of `DenseArray` reports `true`.
"""
function isdense(::Type{<:DenseArray})
    return true
end
|
5 | 5 |
|
6 |
| -# """ |
7 |
| -# ShortVector{T} simply wraps a Vector{T}, but uses a different hash function that is faster for short vectors to support using it as the keys of a Dict. |
8 |
| -# This hash function scales O(N) with length of the vectors, so it is slow for long vectors. |
9 |
| -# """ |
10 |
| -# struct ShortVector{T} <: DenseVector{T} |
11 |
| -# data::Vector{T} |
12 |
| -# end |
13 |
| -# Base.@propagate_inbounds Base.getindex(x::ShortVector, I...) = x.data[I...] |
14 |
| -# Base.@propagate_inbounds Base.setindex!(x::ShortVector, v, I...) = x.data[I...] = v |
15 |
| -# @inbounds Base.length(x::ShortVector) = length(x.data) |
16 |
| -# @inbounds Base.size(x::ShortVector) = size(x.data) |
17 |
| -# @inbounds Base.strides(x::ShortVector) = strides(x.data) |
18 |
| -# @inbounds Base.push!(x::ShortVector, v) = push!(x.data, v) |
19 |
| -# @inbounds Base.append!(x::ShortVector, v) = append!(x.data, v) |
20 |
| -# function Base.hash(x::ShortVector, h::UInt) |
21 |
| -# @inbounds for n ∈ eachindex(x) |
22 |
| -# h = hash(x[n], h) |
23 |
| -# end |
24 |
| -# h |
25 |
| -# end |
26 |
| - |
27 |
| - |
"""
    ShortVector{T} <: DenseVector{T}

`ShortVector{T}` simply wraps a `Vector{T}`, but uses a different hash function that is
faster for short vectors, to support using it as the key of a `Dict`.
The hash function folds over every element, so it scales O(N) with the length of the
vector and is slow for long vectors.

Two `ShortVector{T}`s are `isequal` when they have the same length and their elements are
pairwise `isequal`.
"""
struct ShortVector{T} <: DenseVector{T}
    data::Vector{T}
end

# Allocate an uninitialized `ShortVector{T}` of length `N`, mirroring `Vector{T}(undef, N)`.
ShortVector{T}(::UndefInitializer, N::Integer) where {T} = ShortVector{T}(Vector{T}(undef, N))

# Indexing forwards to the wrapped vector; `@propagate_inbounds` lets a caller's
# `@inbounds` reach the inner access.
Base.@propagate_inbounds Base.getindex(x::ShortVector, I...) = x.data[I...]
Base.@propagate_inbounds Base.setindex!(x::ShortVector, v, I...) = x.data[I...] = v

# Plain forwarding methods. These bodies perform no indexing, so they carry no
# `@inbounds` annotation. `push!`/`append!` return whatever the wrapped call returns.
Base.length(x::ShortVector) = length(x.data)
Base.size(x::ShortVector) = size(x.data)
Base.strides(x::ShortVector) = strides(x.data)
Base.push!(x::ShortVector, v) = push!(x.data, v)
Base.append!(x::ShortVector, v) = append!(x.data, v)

# Fold each element's hash into `h`, in order.
function Base.hash(x::ShortVector, h::UInt)
    @inbounds for n ∈ eachindex(x)
        h = hash(x[n], h)
    end
    return h
end

# Element-wise `isequal` (not `===`) keeps equality consistent with `hash` for element
# types whose equal values are not identical objects (e.g. `BigInt`, nested arrays).
function Base.isequal(a::ShortVector{T}, b::ShortVector{T}) where {T}
    length(a) == length(b) || return false
    # Lengths match, so every index of `a` is also a valid index of `b`.
    @inbounds for i ∈ eachindex(a)
        isequal(a[i], b[i]) || return false
    end
    return true
end

# Conversion to `Vector` hands back the wrapped storage itself (no copy).
Base.convert(::Type{Vector}, sv::ShortVector) = sv.data
Base.convert(::Type{Vector{T}}, sv::ShortVector{T}) where {T} = sv.data
28 | 36 |
|
29 | 37 | @enum OperationType begin
|
30 | 38 | memload
|
@@ -183,22 +191,19 @@ struct LoopSet
|
183 | 191 | loops::Dict{Symbol,Loop} # sym === loops[sym].itersymbol
|
184 | 192 | opdict::Dict{Symbol,Operation}
|
185 | 193 | operations::Vector{Operation} # Split them to make it easier to iterate over just a subset
|
186 |
| - # computeops::Vector{Operation} |
187 |
| - # storeops::Vector{Operation} |
188 |
| - outer_reductions::Set{UInt} # IDs of reduction operations that need to be reduced at end. |
| 194 | + outer_reductions::Vector{UInt} # IDs of reduction operations that need to be reduced at end. |
189 | 195 | loop_order::LoopOrder
|
| 196 | + stridesets::Dict{ShortVector{Symbol},ShortVector{Symbol}} |
190 | 197 | preamble::Expr # TODO: add preamble to lowering
|
191 | 198 | end
|
"""
    LoopSet()

Construct an empty `LoopSet`: empty `loops` and `opdict` dictionaries, no operations,
no outer reductions, a default `LoopOrder()`, an empty `stridesets` dictionary, and an
empty `:block` preamble expression.
"""
function LoopSet()
    LoopSet(
        Dict{Symbol,Loop}(),
        Dict{Symbol,Operation}(),
        Operation[],
        UInt[],
        LoopOrder(),
        # Must be an instance, not the Dict type itself.
        Dict{ShortVector{Symbol},ShortVector{Symbol}}(),
        Expr(:block,)
    )
end
|
@@ -271,8 +276,17 @@ function add_loop!(ls::LoopSet, looprange::Expr)
|
271 | 276 | end
|
"""
    add_load!(ls::LoopSet, indexed::Symbol, indices::AbstractVector)

Register the index set `indices` used to load from the array named `indexed`.

If `ls.stridesets` has no entry for this index set yet, one stride symbol is generated
for each index after the first, named `stride_<nsets>_<index>` where `nsets` is the number
of index sets already registered, and an assignment `sym = stride(indexed, i)` is appended
to `ls.preamble` for each. The resulting symbols are stored under the index set.

Returns the `ShortVector{Symbol}` of stride symbols stored for this index set (newly
created or previously cached).
"""
function add_load!(ls::LoopSet, indexed::Symbol, indices::AbstractVector)
    Ninds = length(indices)
    inds = ShortVector{Symbol}(indices)
    # Captured before insertion so the generated names are unique per index set.
    nsets = length(ls.stridesets)
    get!(ls.stridesets, inds) do
        # Clamp so an empty index list yields an empty stride list rather than
        # requesting a negative-length vector.
        stridesyms = ShortVector{Symbol}(undef, max(Ninds - 1, 0))
        # NOTE(review): dimension 1 gets no stride symbol — presumably it is treated as
        # unit-stride; confirm against the lowering code.
        for i ∈ 2:Ninds
            sᵢ = Symbol(:stride_, nsets, :_, inds[i])
            stridesyms[i-1] = sᵢ
            # `preamble` is an `Expr`; statements are appended to its `args` vector.
            push!(ls.preamble.args, Expr(:(=), sᵢ, Expr(:call, :stride, indexed, i)))
        end
        stridesyms
    end
end
|
278 | 292 | function add_load_getindex!(ls::LoopSet, ex::Expr)
|
|
0 commit comments