Skip to content

Commit 313374c

Browse files
committed
Update style
1 parent 24a1371 commit 313374c

File tree

3 files changed

+55
-57
lines changed

3 files changed

+55
-57
lines changed

src/cartesianproduct.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,28 +66,28 @@ end
6666
function CartesianProductUnitRange(a, b)
6767
return CartesianProductUnitRange(a × b)
6868
end
69-
to_range(a::AbstractUnitRange) = a
70-
to_range(i::Integer) = Base.OneTo(i)
71-
cartesianrange(a, b) = cartesianrange(to_range(a) × to_range(b))
69+
to_product_indices(a::AbstractUnitRange) = a
70+
to_product_indices(i::Integer) = Base.OneTo(i)
71+
cartesianrange(a, b) = cartesianrange(to_product_indices(a) × to_product_indices(b))
7272
function cartesianrange(p::CartesianPair)
73-
p′ = to_range(p.a) × to_range(p.b)
73+
p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p))
7474
return cartesianrange(p′)
7575
end
7676
function cartesianrange(p::CartesianProduct)
77-
p′ = to_range(p.a) × to_range(p.b)
77+
p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p))
7878
return cartesianrange(p′, Base.OneTo(length(p′)))
7979
end
8080
function cartesianrange(p::CartesianPair, range::AbstractUnitRange)
81-
p′ = to_range(p.a) × to_range(p.b)
81+
p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p))
8282
return cartesianrange(p′, range)
8383
end
8484
function cartesianrange(p::CartesianProduct, range::AbstractUnitRange)
85-
p′ = to_range(p.a) × to_range(p.b)
85+
p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p))
8686
return CartesianProductUnitRange(p′, range)
8787
end
8888

8989
function Base.axes(r::CartesianProductUnitRange)
90-
return (CartesianProductUnitRange(r.product, only(axes(r.range))),)
90+
return (CartesianProductUnitRange(cartesianproduct(r), only(axes(unproduct(r)))),)
9191
end
9292

9393
using Base.Broadcast: DefaultArrayStyle
@@ -96,12 +96,12 @@ for f in (:+, :-)
9696
function Broadcast.broadcasted(
9797
::DefaultArrayStyle{1}, ::typeof($f), r::CartesianProductUnitRange, x::Integer
9898
)
99-
return CartesianProductUnitRange(r.product, $f.(r.range, x))
99+
return CartesianProductUnitRange(cartesianproduct(r), $f.(unproduct(r), x))
100100
end
101101
function Broadcast.broadcasted(
102102
::DefaultArrayStyle{1}, ::typeof($f), x::Integer, r::CartesianProductUnitRange
103103
)
104-
return CartesianProductUnitRange(r.product, $f.(x, r.range))
104+
return CartesianProductUnitRange(cartesianproduct(r), $f.(x, unproduct(r)))
105105
end
106106
end
107107
end

src/fillarrays/matrixalgebrakit_truncate.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,23 @@ function MatrixAlgebraKit.findtruncated(
2424
values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy
2525
)
2626
I = findtruncated(Vector(values), strategy.strategy)
27-
prods = collect(only(axes(values)).product)[I]
28-
I_data = unique(map(x -> x.a, prods))
27+
prods = collect(cartesianproduct(only(axes(values))))[I]
28+
I_data = unique(map(arg1, prods))
2929
# Drop truncations that occur within the identity.
3030
I_data = filter(I_data) do i
31-
return count(x -> x.a == i, prods) == length(values.a)
31+
return count(x -> arg1(x) == i, prods) == length(arg1(values))
3232
end
3333
return (:) × I_data
3434
end
3535
function MatrixAlgebraKit.findtruncated(
3636
values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy
3737
)
3838
I = findtruncated(Vector(values), strategy.strategy)
39-
prods = collect(only(axes(values)).product)[I]
40-
I_data = unique(map(x -> x.b, prods))
39+
prods = collect(cartesianproduct(only(axes(values))))[I]
40+
I_data = unique(map(x -> arg2(x), prods))
4141
# Drop truncations that occur within the identity.
4242
I_data = filter(I_data) do i
43-
return count(x -> x.b == i, prods) == length(values.b)
43+
return count(x -> arg2(x) == i, prods) == length(arg2(values))
4444
end
4545
return I_data × (:)
4646
end

src/kroneckerarray.jl

Lines changed: 39 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,17 @@ arg2(a::KroneckerArray) = a.b
2424

2525
using Adapt: Adapt, adapt
2626
_adapt(to, a::AbstractArray) = adapt(to, a)
27-
Adapt.adapt_structure(to, a::KroneckerArray) = _adapt(to, a.a) _adapt(to, a.b)
27+
Adapt.adapt_structure(to, a::KroneckerArray) = _adapt(to, arg1(a)) _adapt(to, arg2(a))
2828

2929
# Allows extra customization, like for `FillArrays.Eye`.
3030
_copy(a::AbstractArray) = copy(a)
3131

3232
function Base.copy(a::KroneckerArray)
33-
return _copy(a.a) _copy(a.b)
33+
return _copy(arg1(a)) _copy(arg2(a))
3434
end
3535
function Base.copyto!(dest::KroneckerArray, src::KroneckerArray)
36-
copyto!(dest.a, src.a)
37-
copyto!(dest.b, src.b)
36+
copyto!(arg1(dest), arg1(src))
37+
copyto!(arg2(dest), arg2(src))
3838
return dest
3939
end
4040

@@ -53,8 +53,7 @@ function Base.similar(
5353
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
5454
},
5555
)
56-
return _similar(a, elt, map(ax -> ax.product.a, axs))
57-
_similar(a, elt, map(ax -> ax.product.b, axs))
56+
return _similar(a, elt, map(arg1, axs)) _similar(a, elt, map(arg2, axs))
5857
end
5958
function Base.similar(
6059
a::KroneckerArray,
@@ -63,26 +62,23 @@ function Base.similar(
6362
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
6463
},
6564
)
66-
return _similar(a.a, elt, map(ax -> ax.product.a, axs))
67-
_similar(a.b, elt, map(ax -> ax.product.b, axs))
65+
return _similar(arg1(a), elt, map(arg1, axs)) _similar(arg2(a), elt, map(arg2, axs))
6866
end
6967
function Base.similar(
7068
arrayt::Type{<:AbstractArray},
7169
axs::Tuple{
7270
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
7371
},
7472
)
75-
return _similar(arrayt, map(ax -> ax.product.a, axs))
76-
_similar(arrayt, map(ax -> ax.product.b, axs))
73+
return _similar(arrayt, map(arg1, axs)) _similar(arrayt, map(arg2, axs))
7774
end
7875
function Base.similar(
7976
arrayt::Type{<:KroneckerArray{<:Any,<:Any,A,B}},
8077
axs::Tuple{
8178
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
8279
},
8380
) where {A,B}
84-
return _similar(A, map(ax -> ax.product.a, axs))
85-
_similar(B, map(ax -> ax.product.b, axs))
81+
return _similar(A, map(arg1, axs)) _similar(B, map(arg2, axs))
8682
end
8783
function Base.similar(
8884
::Type{<:KroneckerArray{<:Any,<:Any,A,B}}, sz::Tuple{Int,Vararg{Int}}
@@ -115,39 +111,41 @@ kron_nd(a::AbstractMatrix, b::AbstractMatrix) = kron(a, b)
115111
kron_nd(a::AbstractVector, b::AbstractVector) = kron(a, b)
116112

117113
# Eagerly collect arguments to make more general on GPU.
118-
Base.collect(a::KroneckerArray) = kron_nd(collect(a.a), collect(a.b))
114+
Base.collect(a::KroneckerArray) = kron_nd(collect(arg1(a)), collect(arg2(a)))
119115

120116
Base.zero(a::KroneckerArray) = zero(arg1(a)) zero(arg2(a))
121117

122118
function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N}
123119
return convert(Array{T,N}, collect(a))
124120
end
125121

126-
Base.size(a::KroneckerArray) = ntuple(dim -> size(a.a, dim) * size(a.b, dim), ndims(a))
122+
function Base.size(a::KroneckerArray)
123+
return ntuple(dim -> size(arg1(a), dim) * size(arg2(a), dim), ndims(a))
124+
end
127125

128126
function Base.axes(a::KroneckerArray)
129127
return ntuple(ndims(a)) do dim
130128
return CartesianProductUnitRange(
131-
axes(a.a, dim) × axes(a.b, dim), Base.OneTo(size(a, dim))
129+
axes(arg1(a), dim) × axes(arg2(a), dim), Base.OneTo(size(a, dim))
132130
)
133131
end
134132
end
135133

136-
arguments(a::KroneckerArray) = (a.a, a.b)
134+
arguments(a::KroneckerArray) = (arg1(a), arg2(a))
137135
arguments(a::KroneckerArray, n::Int) = arguments(a)[n]
138136
argument_types(a::KroneckerArray) = argument_types(typeof(a))
139137
argument_types(::Type{<:KroneckerArray{<:Any,<:Any,A,B}}) where {A,B} = (A, B)
140138

141139
function Base.print_array(io::IO, a::KroneckerArray)
142-
Base.print_array(io, a.a)
140+
Base.print_array(io, arg1(a))
143141
println(io, "\n")
144-
Base.print_array(io, a.b)
142+
Base.print_array(io, arg2(a))
145143
return nothing
146144
end
147145
function Base.show(io::IO, a::KroneckerArray)
148-
show(io, a.a)
146+
show(io, arg1(a))
149147
print(io, "")
150-
show(io, a.b)
148+
show(io, arg2(a))
151149
return nothing
152150
end
153151

@@ -172,14 +170,14 @@ function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer)
172170
GPUArraysCore.assertscalar("getindex")
173171
# Code logic from Kronecker.jl:
174172
# https://github.com/MichielStock/Kronecker.jl/blob/v0.5.5/src/base.jl#L101-L105
175-
k, l = size(a.b)
176-
return a.a[cld(i1, k), cld(i2, l)] * a.b[(i1 - 1) % k + 1, (i2 - 1) % l + 1]
173+
k, l = size(arg2(a))
174+
return arg1(a)[cld(i1, k), cld(i2, l)] * arg2(a)[(i1 - 1) % k + 1, (i2 - 1) % l + 1]
177175
end
178176

179177
function Base.getindex(a::KroneckerVector, i::Integer)
180178
GPUArraysCore.assertscalar("getindex")
181-
k = length(a.b)
182-
return a.a[cld(i, k)] * a.b[(i - 1) % k + 1]
179+
k = length(arg2(a))
180+
return arg1(a)[cld(i, k)] * arg2(a)[(i - 1) % k + 1]
183181
end
184182

185183
# Allow customizing for `FillArrays.Eye`.
@@ -191,49 +189,49 @@ function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) w
191189
return _getindex(arg1(a), arg1.(I)...) _getindex(arg2(a), arg2.(I)...)
192190
end
193191
# Fix ambigiuity error.
194-
Base.getindex(a::KroneckerArray{<:Any,0}) = a.a[] * a.b[]
192+
Base.getindex(a::KroneckerArray{<:Any,0}) = arg1(a)[] * arg2(a)[]
195193

196194
function Base.:(==)(a::KroneckerArray, b::KroneckerArray)
197-
return a.a == b.a && a.b == b.b
195+
return arg1(a) == arg1(b) && arg2(a) == arg2(b)
198196
end
199197
function Base.isapprox(a::KroneckerArray, b::KroneckerArray; kwargs...)
200-
return isapprox(a.a, b.a; kwargs...) && isapprox(a.b, b.b; kwargs...)
198+
return isapprox(arg1(a), arg1(b); kwargs...) && isapprox(arg2(a), arg2(b); kwargs...)
201199
end
202200
function Base.iszero(a::KroneckerArray)
203-
return iszero(a.a) || iszero(a.b)
201+
return iszero(arg1(a)) || iszero(arg2(a))
204202
end
205203
function Base.isreal(a::KroneckerArray)
206-
return isreal(a.a) && isreal(a.b)
204+
return isreal(arg1(a)) && isreal(arg2(a))
207205
end
208206

209207
using DiagonalArrays: DiagonalArrays, diagonal
210208
function DiagonalArrays.diagonal(a::KroneckerArray)
211-
return diagonal(a.a) diagonal(a.b)
209+
return diagonal(arg1(a)) diagonal(arg2(a))
212210
end
213211

214212
Base.real(a::KroneckerArray{<:Real}) = a
215213
function Base.real(a::KroneckerArray)
216-
if iszero(imag(a.a)) || iszero(imag(a.b))
217-
return real(a.a) real(a.b)
218-
elseif iszero(real(a.a)) || iszero(real(a.b))
219-
return -imag(a.a) imag(a.b)
214+
if iszero(imag(arg1(a))) || iszero(imag(arg2(a)))
215+
return real(arg1(a)) real(arg2(a))
216+
elseif iszero(real(arg1(a))) || iszero(real(arg2(a)))
217+
return -imag(arg1(a)) imag(arg2(a))
220218
end
221-
return real(a.a) real(a.b) - imag(a.a) imag(a.b)
219+
return real(arg1(a)) real(arg2(a)) - imag(arg1(a)) imag(arg2(a))
222220
end
223221
Base.imag(a::KroneckerArray{<:Real}) = zero(a)
224222
function Base.imag(a::KroneckerArray)
225-
if iszero(imag(a.a)) || iszero(real(a.b))
226-
return real(a.a) imag(a.b)
227-
elseif iszero(real(a.a)) || iszero(imag(a.b))
228-
return imag(a.a) real(a.b)
223+
if iszero(imag(arg1(a))) || iszero(real(arg2(a)))
224+
return real(arg1(a)) imag(arg2(a))
225+
elseif iszero(real(arg1(a))) || iszero(imag(arg2(a)))
226+
return imag(arg1(a)) real(arg2(a))
229227
end
230-
return real(a.a) imag(a.b) + imag(a.a) real(a.b)
228+
return real(arg1(a)) imag(arg2(a)) + imag(arg1(a)) real(arg2(a))
231229
end
232230

233231
for f in [:transpose, :adjoint, :inv]
234232
@eval begin
235233
function Base.$f(a::KroneckerArray)
236-
return $f(a.a) $f(a.b)
234+
return $f(arg1(a)) $f(arg2(a))
237235
end
238236
end
239237
end

0 commit comments

Comments
 (0)