@@ -18,14 +18,27 @@ function Base.Broadcast._broadcast_getindex_eltype(p::Product)
18
18
)
19
19
end
20
20
21
+ # recursive_eltype(::Type{A}) where {T, A <: AbstractArray{T}} = T
22
+ # recursive_eltype(::Type{NTuple{N,T}}) where {N,T<:Union{Float32,Float64}} = T
23
+ # recursive_eltype(::Type{Float32}) = Float32
24
+ # recursive_eltype(::Type{Float64}) = Float64
25
+ # recursive_eltype(::Type{Tuple{T}}) where {T} = T
26
+ # recursive_eltype(::Type{Tuple{T1,T2}}) where {T1,T2} = promote_type(recursive_eltype(T1), recursive_eltype(T2))
27
+ # recursive_eltype(::Type{Tuple{T1,T2,T3}}) where {T1,T2,T3} = promote_type(recursive_eltype(T1), recursive_eltype(T2), recursive_eltype(T3))
28
+ # recursive_eltype(::Type{Tuple{T1,T2,T3,T4}}) where {T1,T2,T3,T4} = promote_type(recursive_eltype(T1), recursive_eltype(T2), recursive_eltype(T3), recursive_eltype(T4))
29
+ # recursive_eltype(::Type{Tuple{T1,T2,T3,T4,T5}}) where {T1,T2,T3,T4,T5} = promote_type(recursive_eltype(T1), recursive_eltype(T2), recursive_eltype(T3), recursive_eltype(T4), recursive_eltype(T5))
30
+
31
+ # function recursive_eltype(::Type{Broadcasted{S,A,F,ARGS}}) where {S,A,F,ARGS}
32
+ # recursive_eltype(ARGS)
33
+ # end
21
34
22
35
"""
    a ∗ b

Construct a `Product{A,B}` from the operands `a` and `b`. The concrete operand
types are recorded as the type parameters `A` and `B`.
"""
@inline ∗(a::A, b::B) where {A,B} = Product{A,B}(a, b)
23
36
"""
    Base.Broadcast.broadcasted(::typeof(∗), a, b)

Broadcasting over `∗` constructs a `Product{A,B}` from `a` and `b`, the same
value the plain call `a ∗ b` produces.
"""
@inline function Base.Broadcast.broadcasted(::typeof(∗), a::A, b::B) where {A,B}
    return Product{A,B}(a, b)
end
24
37
# TODO: Make this handle A or B being 1-D or 2-D broadcast objects.
25
38
function add_broadcast! (
26
39
ls:: LoopSet , mC:: Symbol , bcname:: Symbol , loopsyms:: Vector{Symbol} ,
27
40
:: Type{Product{A,B}} , elementbytes:: Int = 8
28
- ) where {T,A, B}
41
+ ) where {A, B}
29
42
K = gensym (:K )
30
43
mA = gensym (:Aₘₖ )
31
44
mB = gensym (:Bₖₙ )
@@ -54,7 +67,12 @@ function add_broadcast!(
54
67
# load B
55
68
loadB = add_broadcast! (ls, gensym (:B ), mB, bloopsyms, B, elementbytes)
56
69
# set Cₘₙ = 0
57
- setC = add_constant! (ls, 0.0 , cloopsyms, mC, elementbytes)
70
+ # setC = add_constant!(ls, zero(promote_type(recursive_eltype(A), recursive_eltype(B))), cloopsyms, mC, elementbytes)
71
+ setC = if elementbytes == 4
72
+ add_constant! (ls, 0f0 , cloopsyms, mC, elementbytes)
73
+ else # elementbytes != 4: use a Float64 zero
74
+ add_constant! (ls, 0.0 , cloopsyms, mC, elementbytes)
75
+ end
58
76
# compute Cₘₙ += Aₘₖ * Bₖₙ
59
77
reductop = Operation (
60
78
ls, mC, elementbytes, :vmuladd , compute, reductdeps, Symbol[k], Operation[loadA, loadB, setC]
@@ -118,7 +136,7 @@ function add_broadcast!(
118
136
argname = gensym (:arg )
119
137
pushpreamble! (ls, Expr (:(= ), argname, Expr (:ref , bcargs, i)))
120
138
# dynamic dispatch
121
- parent = add_broadcast! (ls, gensym (:temp ), argname, loopsyms, arg):: Operation
139
+ parent = add_broadcast! (ls, gensym (:temp ), argname, loopsyms, arg, elementbytes ):: Operation
122
140
pushparent! (parents, deps, reduceddeps, parent)
123
141
end
124
142
op = Operation (
130
148
# size of dest determines loops
131
149
@generated function vmaterialize! (
132
150
dest:: AbstractArray{T,N} , bc:: BC
133
- ) where {T, N, BC <: Broadcasted }
151
+ ) where {T <: Union{Float32,Float64} , N, BC <: Broadcasted }
134
152
# ) where {N, T, BC <: Broadcasted}
135
153
# we have an N dimensional loop.
136
154
# need to construct the LoopSet
143
161
push! (sizes. args, Nsym)
144
162
end
145
163
pushpreamble! (ls, Expr (:(= ), sizes, Expr (:call , :size , :dest )))
146
- add_broadcast! (ls, :dest , :bc , loopsyms, BC)
147
- add_store! (ls, :dest , ArrayReference (:dest , loopsyms, Ref {Bool} (false )))
164
+ elementbytes = sizeof (T)
165
+ add_broadcast! (ls, :dest , :bc , loopsyms, BC, elementbytes)
166
+ add_store! (ls, :dest , ArrayReference (:dest , loopsyms, Ref {Bool} (false )), elementbytes)
148
167
resize! (ls. loop_order, num_loops (ls)) # num_loops may be greater than N, eg Product
149
168
q = lower (ls)
150
169
push! (q. args, :dest )
0 commit comments