Skip to content

Commit f99621d

Browse files
committed
Updated tests, fixed broadcasting product example.
1 parent 7d5418f commit f99621d

File tree

5 files changed

+79
-228
lines changed

5 files changed

+79
-228
lines changed

src/broadcast.jl

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,25 @@ struct Product{A,B}
22
a::A
33
b::B
44
end
5+
function Base.size(p::Product)
6+
M = size(p.a, 1)
7+
(M, Base.tail(size(p.b))...)
8+
end
9+
@inline Base.length(p::Product) = prod(size(p))
10+
@inline Base.broadcastable(p::Product) = p
11+
@inline Base.ndims(p::Type{Product{A,B}}) where {A,B} = ndims(B)
12+
13+
Base.Broadcast._broadcast_getindex_eltype(::Product{A,B}) where {T, A <: AbstractVecOrMat{T}, B <: AbstractVecOrMat{T}} = T
14+
function Base.Broadcast._broadcast_getindex_eltype(p::Product)
15+
promote_type(
16+
Base.Broadcast._broadcast_getindex_eltype(p.a),
17+
Base.Broadcast._broadcast_getindex_eltype(p.b)
18+
)
19+
end
20+
521

622
@inline (a::A, b::B) where {A,B} = Product{A,B}(a, b)
7-
@inline Base.Broadcast.Broadcasted(::typeof(), a::A, b::B) where {A, B} = Product{A,B}(a, b)
23+
@inline Base.Broadcast.broadcasted(::typeof(), a::A, b::B) where {A, B} = Product{A,B}(a, b)
824
# TODO: Need to make this handle A or B being (1 or 2)-D broadcast objects.
925
function add_broadcast!(
1026
ls::LoopSet, mC::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
@@ -19,17 +35,29 @@ function add_broadcast!(
1935

2036
k = gensym(:k)
2137
ls.loops[k] = Loop(k, K)
22-
m = loopsyms[1]; n = loopsyms[2];
38+
m = loopsyms[1];
39+
if ndims(B) == 1
40+
bloopsyms = Symbol[k]
41+
cloopsyms = Symbol[m]
42+
reductdeps = Symbol[m, k]
43+
elseif ndims(B) == 2
44+
n = loopsyms[2];
45+
bloopsyms = Symbol[k,n]
46+
cloopsyms = Symbol[m,n]
47+
reductdeps = Symbol[m, k, n]
48+
else
49+
throw("B must be a vector or matrix.")
50+
end
2351
# load A
2452
# loadA = add_load!(ls, gensym(:A), productref(A, mA, m, k), elementbytes)
25-
loadA = add_broadcast!(ls, gensym(:A), mA, [m,k], A, elementbytes)
53+
loadA = add_broadcast!(ls, gensym(:A), mA, Symbol[m,k], A, elementbytes)
2654
# load B
27-
loadB = add_broadcast!(ls, gensym(:B), mB, [k,n], B, elementbytes)
55+
loadB = add_broadcast!(ls, gensym(:B), mB, bloopsyms, B, elementbytes)
2856
# set Cₘₙ = 0
29-
setC = add_constant!(ls, 0.0, Symbol[m, k], mC, elementbytes)
57+
setC = add_constant!(ls, 0.0, cloopsyms, mC, elementbytes)
3058
# compute Cₘₙ += Aₘₖ * Bₖₙ
3159
reductop = Operation(
32-
ls, mC, elementbytes, :vmuladd, compute, Symbol[m, k, n], Symbol[k], Operation[loadA, loadB, setC]
60+
ls, mC, elementbytes, :vmuladd, compute, reductdeps, Symbol[k], Operation[loadA, loadB, setC]
3361
)
3462
pushop!(ls, reductop, mC)
3563
end
@@ -102,8 +130,8 @@ end
102130
# size of dest determines loops
103131
@generated function vmaterialize!(
104132
dest::AbstractArray{T,N}, bc::BC
105-
# ) where {T, N, BC <: Broadcasted}
106-
) where {N, T, BC <: Broadcasted}
133+
) where {T, N, BC <: Broadcasted}
134+
# ) where {N, T, BC <: Broadcasted}
107135
# we have an N dimensional loop.
108136
# need to construct the LoopSet
109137
loopsyms = [gensym(:n) for n 1:N]

src/determinestrategy.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,12 @@ function depchain_cost!(
105105
end
106106
rt, sl
107107
end
108-
108+
function parentsnotreduction(op::Operation)
109+
for opp parents(op)
110+
isreduction(opp) && return false
111+
end
112+
return true
113+
end
109114
function determine_unroll_factor(
110115
ls::LoopSet, order::Vector{Symbol}, unrolled::Symbol = first(order)
111116
)
@@ -116,7 +121,7 @@ function determine_unroll_factor(
116121
# The assumption here is that unrolling provides no real benefit, unless it is needed to enable OOO execution by breaking up these dependency chains
117122
num_reductions = 0#sum(isreduction, operations(ls))
118123
for op operations(ls)
119-
if isreduction(op) & iscompute(op)
124+
if isreduction(op) & iscompute(op) && parentsnotreduction(op)
120125
num_reductions += 1
121126
end
122127
end

src/graphs.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ end
146146

147147
function Operation(
148148
ls::LoopSet, variable, elementbytes, instruction,
149-
node_type, dependencies, reduced_deps, parents, ref
149+
node_type, dependencies, reduced_deps, parents, ref = NOTAREFERENCE
150150
)
151151
Operation(
152152
length(operations(ls)), variable, elementbytes, instruction,
@@ -283,6 +283,16 @@ function mergesetv!(s1::AbstractVector{T}, s2::AbstractVector{T}) where {T}
283283
end
284284
nothing
285285
end
286+
function mergesetdiffv!(
287+
s1::AbstractVector{T},
288+
s2::AbstractVector{T},
289+
s3::AbstractVector{T}
290+
) where {T}
291+
for s s2
292+
s s3 && addsetv!(s1, s)
293+
end
294+
nothing
295+
end
286296
function setdiffv!(s3::AbstractVector{T}, s1::AbstractVector{T}, s2::AbstractVector{T}) where {T}
287297
for s s1
288298
(s s2) || (s s3 && push!(s3, s))
@@ -311,7 +321,7 @@ function add_constant!(ls::LoopSet, var, deps::Vector{Symbol}, sym::Symbol = gen
311321
end
312322
function pushparent!(parents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, parent::Operation)
313323
push!(parents, parent)
314-
mergesetv!(deps, loopdependencies(parent))
324+
mergesetdiffv!(deps, loopdependencies(parent), reduceddependencies(parent))
315325
if !(isload(parent) || isconstant(parent))
316326
mergesetv!(reduceddeps, reduceddependencies(parent))
317327
end

src/operations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ struct Operation
9090
# numerical_metadata::Vector{Int} # stride of -1 indicates dynamic
9191
# symbolic_metadata::Vector{Symbol}
9292
function Operation(
93-
identifier,
93+
identifier::Int,
9494
variable,
9595
elementbytes,
9696
instruction,

0 commit comments

Comments
 (0)