Skip to content

Commit b8a4d44

Browse files
committed
Reworked some internals to prepare for supporting indexing with an indexed vector of indices.
1 parent 09ff723 commit b8a4d44

File tree

6 files changed

+377
-201
lines changed

6 files changed

+377
-201
lines changed

src/broadcast.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -104,22 +104,22 @@ function add_broadcast!(
104104
::Type{<:LowDimArray{D,T,N}}, elementbytes::Int = 8
105105
) where {D,T,N}
106106
fulldims = Union{Symbol,Int}[loopsyms[n] for n 1:N if D[n]]
107-
ref = ArrayReference(bcname, fulldims, Ref{Bool}(false))
108-
add_load!(ls, destname, ref, elementbytes)::Operation
107+
ref = ArrayReference(bcname, fulldims)
108+
add_simple_load!(ls, destname, ref, elementbytes )::Operation
109109
end
110110
function add_broadcast_adjoint_array!(
111111
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, ::Type{A}, elementbytes::Int = 8
112112
) where {T,N,A<:AbstractArray{T,N}}
113113
parent = gensym(:parent)
114114
pushpreamble!(ls, Expr(:(=), parent, Expr(:call, :parent, bcname)))
115-
ref = ArrayReference(parent, Union{Symbol,Int}[loopsyms[N + 1 - n] for n 1:N], Ref{Bool}(false))
116-
add_load!( ls, destname, ref, elementbytes )::Operation
115+
ref = ArrayReference(parent, Union{Symbol,Int}[loopsyms[N + 1 - n] for n 1:N])
116+
add_simple_load!( ls, destname, ref, elementbytes )::Operation
117117
end
118118
function add_broadcast_adjoint_array!(
119119
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, ::Type{<:AbstractVector}, elementbytes::Int = 8
120120
)
121-
ref = ArrayReference(bcname, Union{Symbol,Int}[loopsyms[2]], Ref{Bool}(false))
122-
add_load!( ls, destname, ref, elementbytes )
121+
ref = ArrayReference(bcname, Union{Symbol,Int}[loopsyms[2]])
122+
add_simple_load!( ls, destname, ref, elementbytes )
123123
end
124124
function add_broadcast!(
125125
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
@@ -137,7 +137,7 @@ function add_broadcast!(
137137
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
138138
::Type{<:AbstractArray{T,N}}, elementbytes::Int = 8
139139
) where {T,N}
140-
add_load!(ls, destname, ArrayReference(bcname, @view(loopsyms[1:N]), Ref{Bool}(false)), elementbytes)
140+
add_simple_load!(ls, destname, ArrayReference(bcname, @view(loopsyms[1:N])), elementbytes)
141141
end
142142
function add_broadcast!(
143143
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, ::Type{T}, elementbytes::Int = 8
@@ -153,7 +153,7 @@ function add_broadcast!(
153153
inds = Vector{Union{Int,Symbol}}(undef, N+1)
154154
inds[1] = Symbol("##DISCONTIGUOUSSUBARRAY##")
155155
inds[2:end] .= @view(loopsyms[1:N])
156-
add_load!(ls, destname, ArrayReference(bcname, inds, Ref{Bool}(false)), elementbytes)
156+
add_simple_load!(ls, destname, ArrayReference(bcname, inds), elementbytes)
157157
end
158158
function add_broadcast!(
159159
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
@@ -201,7 +201,7 @@ end
201201
pushpreamble!(ls, Expr(:(=), sizes, Expr(:call, :size, :dest)))
202202
elementbytes = sizeof(T)
203203
add_broadcast!(ls, :dest, :bc, loopsyms, BC, elementbytes)
204-
add_store!(ls, :dest, ArrayReference(:dest, loopsyms, Ref{Bool}(false)), elementbytes)
204+
add_simple_store!(ls, :dest, ArrayReference(:dest, loopsyms), elementbytes)
205205
resize!(ls.loop_order, num_loops(ls)) # num_loops may be greater than N, eg Product
206206
q = lower(ls)
207207
push!(q.args, :dest)
@@ -226,7 +226,7 @@ end
226226
pushpreamble!(ls, Expr(:(=), sizes, Expr(:call, :size, :dest′)))
227227
elementbytes = sizeof(T)
228228
add_broadcast!(ls, :dest, :bc, loopsyms, BC, elementbytes)
229-
add_store!(ls, :dest, ArrayReference(:dest, reverse(loopsyms), Ref{Bool}(false)), elementbytes)
229+
add_simple_store!(ls, :dest, ArrayReference(:dest, reverse(loopsyms)), elementbytes)
230230
resize!(ls.loop_order, num_loops(ls)) # num_loops may be greater than N, eg Product
231231
q = lower(ls)
232232
push!(q.args, :dest′)

src/determinestrategy.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
# TODO: FIXME for general case
33
# wrong for transposed matrices, and certain views/SubArrays.
4-
unitstride(op::Operation, s) = first(op.ref.ref) === s
4+
unitstride(op::Operation, s) = first(getindices(op)) === s
55

66
function cost(op::Operation, unrolled::Symbol, Wshift::Int, size_T::Int = op.elementbytes)
77
isconstant(op) && return 0.0, 0, 1
@@ -272,6 +272,7 @@ function solve_tilesize(
272272
end
273273

274274
function set_upstream_family!(adal::Vector{T}, op::Operation, val::T) where {T}
275+
adal[identifier(op)] == val && return # must already have been set
275276
adal[identifier(op)] = val
276277
for opp parents(op)
277278
set_upstream_family!(adal, opp, val)

0 commit comments

Comments
 (0)