@@ -2,9 +2,25 @@ struct Product{A,B}
2
2
a:: A
3
3
b:: B
4
4
end
5
"""
    Base.size(p::Product)

Return the dimensions of the lazy product `p`: the first dimension of `p.a`
followed by every dimension of `p.b` except its first.
"""
function Base.size(p::Product)
    rows = size(p.a, 1)
    # Drop the leading dimension of `p.b`; only its remaining dimensions are kept.
    trailing = Base.tail(size(p.b))
    return (rows, trailing...)
end
9
# Total number of elements of the product: the product of its `size` dimensions.
@inline function Base.length(p::Product)
    return prod(size(p))
end
10
# A `Product` takes part in broadcasting as itself; it is returned unchanged.
@inline function Base.broadcastable(p::Product)
    return p
end
11
# The dimensionality of a `Product{A,B}` type is the dimensionality of its
# second factor type `B`.
@inline function Base.ndims(p::Type{Product{A,B}}) where {A,B}
    return ndims(B)
end
12
+
13
# When both factors are vectors/matrices with the same element type `T`,
# the product's broadcast element type is that `T`.
function Base.Broadcast._broadcast_getindex_eltype(
    ::Product{A,B}
) where {T,A<:AbstractVecOrMat{T},B<:AbstractVecOrMat{T}}
    return T
end
14
# Generic fallback: query the broadcast element type of each factor and
# promote the two results to a common type.
function Base.Broadcast._broadcast_getindex_eltype(p::Product)
    eltype_a = Base.Broadcast._broadcast_getindex_eltype(p.a)
    eltype_b = Base.Broadcast._broadcast_getindex_eltype(p.b)
    return promote_type(eltype_a, eltype_b)
end
20
+
5
21
6
22
# `a ∗ b` builds a lazy `Product` holding both operands; nothing is computed here.
@inline function ∗(a::A, b::B) where {A,B}
    return Product{A,B}(a, b)
end
7
# Dot-call hook for `∗`: `a .∗ b` is lowered through `Base.Broadcast.broadcasted`,
# so this method makes the dotted form build the same lazy `Product` as `a ∗ b`.
# The method must be on the lowercase `broadcasted` function; a method added to
# the `Broadcasted` type constructor is not what the dot-call lowering invokes.
@inline function Base.Broadcast.broadcasted(::typeof(∗), a::A, b::B) where {A,B}
    return Product{A,B}(a, b)
end
8
24
# TODO: Make this handle A or B being 1-D or 2-D broadcast objects.
9
25
function add_broadcast! (
10
26
ls:: LoopSet , mC:: Symbol , bcname:: Symbol , loopsyms:: Vector{Symbol} ,
@@ -19,17 +35,29 @@ function add_broadcast!(
19
35
20
36
k = gensym (:k )
21
37
ls. loops[k] = Loop (k, K)
22
- m = loopsyms[1 ]; n = loopsyms[2 ];
38
+ m = loopsyms[1 ];
39
+ if ndims (B) == 1
40
+ bloopsyms = Symbol[k]
41
+ cloopsyms = Symbol[m]
42
+ reductdeps = Symbol[m, k]
43
+ elseif ndims (B) == 2
44
+ n = loopsyms[2 ];
45
+ bloopsyms = Symbol[k,n]
46
+ cloopsyms = Symbol[m,n]
47
+ reductdeps = Symbol[m, k, n]
48
+ else
49
+ throw (" B must be a vector or matrix." )
50
+ end
23
51
# load A
24
52
# loadA = add_load!(ls, gensym(:A), productref(A, mA, m, k), elementbytes)
25
- loadA = add_broadcast! (ls, gensym (:A ), mA, [m,k], A, elementbytes)
53
+ loadA = add_broadcast! (ls, gensym (:A ), mA, Symbol [m,k], A, elementbytes)
26
54
# load B
27
- loadB = add_broadcast! (ls, gensym (:B ), mB, [k,n] , B, elementbytes)
55
+ loadB = add_broadcast! (ls, gensym (:B ), mB, bloopsyms , B, elementbytes)
28
56
# set Cₘₙ = 0
29
- setC = add_constant! (ls, 0.0 , Symbol[m, k] , mC, elementbytes)
57
+ setC = add_constant! (ls, 0.0 , cloopsyms , mC, elementbytes)
30
58
# compute Cₘₙ += Aₘₖ * Bₖₙ
31
59
reductop = Operation (
32
- ls, mC, elementbytes, :vmuladd , compute, Symbol[m, k, n] , Symbol[k], Operation[loadA, loadB, setC]
60
+ ls, mC, elementbytes, :vmuladd , compute, reductdeps , Symbol[k], Operation[loadA, loadB, setC]
33
61
)
34
62
pushop! (ls, reductop, mC)
35
63
end
102
130
# size of dest determines loops
103
131
@generated function vmaterialize! (
104
132
dest:: AbstractArray{T,N} , bc:: BC
105
- # ) where {T, N, BC <: Broadcasted}
106
- ) where {N, T, BC <: Broadcasted }
133
+ ) where {T, N, BC <: Broadcasted }
134
+ # ) where {N, T, BC <: Broadcasted}
107
135
# we have an N dimensional loop.
108
136
# need to construct the LoopSet
109
137
loopsyms = [gensym (:n ) for n ∈ 1 : N]
0 commit comments