1
+ const FillVector{F,A} = Fill{F,1 ,A}
2
+ const FillMatrix{F,A} = Fill{F,2 ,A}
3
+ const OnesVector{F,A} = Ones{F,1 ,A}
4
+ const OnesMatrix{F,A} = Ones{F,2 ,A}
5
+ const ZerosVector{F,A} = Zeros{F,1 ,A}
6
+ const ZerosMatrix{F,A} = Zeros{F,2 ,A}
7
+
1
8
# # vec
2
9
3
10
vec (a:: Ones{T} ) where T = Ones {T} (length (a))
87
94
* (a:: Zeros{<:Any,2} , b:: Diagonal ) = mult_zeros (a, b)
88
95
* (a:: Diagonal , b:: Zeros{<:Any,1} ) = mult_zeros (a, b)
89
96
* (a:: Diagonal , b:: Zeros{<:Any,2} ) = mult_zeros (a, b)
90
- function * (a:: Diagonal , b:: AbstractFill{<:Any,2} )
97
+
98
+ # Cannot unify following methods for Diagonal
99
+ # due to ambiguity with general array mult. with fill
100
+ function * (a:: Diagonal , b:: FillMatrix )
101
+ size (a,2 ) == size (b,1 ) || throw (DimensionMismatch (" A has dimensions $(size (a)) but B has dimensions $(size (b)) " ))
102
+ a. diag .* b # use special broadcast
103
+ end
104
+ function * (a:: FillMatrix , b:: Diagonal )
105
+ size (a,2 ) == size (b,1 ) || throw (DimensionMismatch (" A has dimensions $(size (a)) but B has dimensions $(size (b)) " ))
106
+ a .* permutedims (b. diag) # use special broadcast
107
+ end
108
+ function * (a:: Diagonal , b:: OnesMatrix )
91
109
size (a,2 ) == size (b,1 ) || throw (DimensionMismatch (" A has dimensions $(size (a)) but B has dimensions $(size (b)) " ))
92
110
a. diag .* b # use special broadcast
93
111
end
94
- function * (a:: AbstractFill{<:Any,2} , b:: Diagonal )
112
+ function * (a:: OnesMatrix , b:: Diagonal )
95
113
size (a,2 ) == size (b,1 ) || throw (DimensionMismatch (" A has dimensions $(size (a)) but B has dimensions $(size (b)) " ))
96
114
a .* permutedims (b. diag) # use special broadcast
97
115
end
@@ -100,23 +118,61 @@ end
100
118
* (a:: Transpose{T, <:StridedMatrix{T}} , b:: Fill{T, 1} ) where T = reshape (sum (parent (a); dims= 1 ) .* b. value, size (parent (a), 2 ))
101
119
* (a:: StridedMatrix{T} , b:: Fill{T, 1} ) where T = reshape (sum (a; dims= 2 ) .* b. value, size (a, 1 ))
102
120
103
- function * (a:: Adjoint{T, <:StridedMatrix{T}} , b:: Fill{T, 2} ) where T
104
- fB = similar (parent (a), size (b, 1 ), size (b, 2 ))
105
- fill! (fB, b. value)
106
- return a* fB
121
+ function * (x:: AbstractMatrix , f:: FillMatrix )
122
+ axes (x, 2 ) ≠ axes (f, 1 ) &&
123
+ throw (DimensionMismatch (" Incompatible matrix multiplication dimensions" ))
124
+ m = size (f, 2 )
125
+ repeat (sum (x, dims= 2 ) * f. value, 1 , m)
126
+ end
127
+
128
+ function * (f:: FillMatrix , x:: AbstractMatrix )
129
+ axes (f, 2 ) ≠ axes (x, 1 ) &&
130
+ throw (DimensionMismatch (" Incompatible matrix multiplication dimensions" ))
131
+ m = size (f, 1 )
132
+ repeat (sum (x, dims= 1 ) * f. value, m, 1 )
107
133
end
108
134
109
- function * (a:: Transpose{T, <:StridedMatrix{T}} , b:: Fill{T, 2} ) where T
110
- fB = similar (parent (a), size (b, 1 ), size (b, 2 ))
111
- fill! (fB, b. value)
112
- return a* fB
135
+ function * (x:: AbstractMatrix , f:: OnesMatrix )
136
+ axes (x, 2 ) ≠ axes (f, 1 ) &&
137
+ throw (DimensionMismatch (" Incompatible matrix multiplication dimensions" ))
138
+ m = size (f, 2 )
139
+ repeat (sum (x, dims= 2 ) * one (eltype (f)), 1 , m)
113
140
end
114
141
115
- function * (a:: StridedMatrix{T} , b:: Fill{T, 2} ) where T
116
- fB = similar (a, size (b, 1 ), size (b, 2 ))
117
- fill! (fB, b. value)
118
- return a* fB
142
+ function * (f:: OnesMatrix , x:: AbstractMatrix )
143
+ axes (f, 2 ) ≠ axes (x, 1 ) &&
144
+ throw (DimensionMismatch (" Incompatible matrix multiplication dimensions" ))
145
+ m = size (f, 1 )
146
+ repeat (sum (x, dims= 1 ) * one (eltype (f)), m, 1 )
119
147
end
148
+
149
+ * (x:: FillMatrix , y:: FillMatrix ) = mult_fill (x, y)
150
+ * (x:: FillMatrix , y:: OnesMatrix ) = mult_fill (x, y)
151
+ * (x:: OnesMatrix , y:: FillMatrix ) = mult_fill (x, y)
152
+ * (x:: OnesMatrix , y:: OnesMatrix ) = mult_fill (x, y)
153
+ * (x:: ZerosMatrix , y:: OnesMatrix ) = mult_zeros (x, y)
154
+ * (x:: ZerosMatrix , y:: FillMatrix ) = mult_zeros (x, y)
155
+ * (x:: FillMatrix , y:: ZerosMatrix ) = mult_zeros (x, y)
156
+ * (x:: OnesMatrix , y:: ZerosMatrix ) = mult_zeros (x, y)
157
+
158
+ # function *(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
159
+ # fB = similar(parent(a), size(b, 1), size(b, 2))
160
+ # fill!(fB, b.value)
161
+ # return a*fB
162
+ # end
163
+
164
+ # function *(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
165
+ # fB = similar(parent(a), size(b, 1), size(b, 2))
166
+ # fill!(fB, b.value)
167
+ # return a*fB
168
+ # end
169
+
170
+ # function *(a::StridedMatrix{T}, b::Fill{T, 2}) where T
171
+ # fB = similar(a, size(b, 1), size(b, 2))
172
+ # fill!(fB, b.value)
173
+ # return a*fB
174
+ # end
175
+
120
176
function _adjvec_mul_zeros (a:: Adjoint{T} , b:: Zeros{S, 1} ) where {T, S}
121
177
la, lb = length (a), length (b)
122
178
if la ≠ lb
0 commit comments