Skip to content

Commit dd4a56c

Browse files
Merge pull request #199 from vpuri3/docs
Overload Base.kron, clean up tensor prod benchmark
2 parents 24609db + e05a16c commit dd4a56c

File tree

4 files changed

+57
-41
lines changed

4 files changed

+57
-41
lines changed

benchmarks/tensor.jl

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,48 @@
11
using SciMLOperators, LinearAlgebra, BenchmarkTools
22
using SciMLOperators: IdentityOperator,
33

4-
Id = IdentityOperator{12}()
5-
A = rand(12,12)
6-
B = rand(12,12)
7-
C = rand(12,12)
4+
N = 12
5+
K = 100
6+
Id = IdentityOperator(N)
7+
A = rand(N,N)
8+
B = rand(N,N)
9+
C = rand(N,N)
810

911
println("#===============================#")
1012
println("2D Tensor Products")
1113
println("#===============================#")
1214

1315
println("⊗(A, B)")
1416

15-
u = rand(12^2, 100)
16-
v = rand(12^2, 100)
17+
u = rand(N^2, K)
18+
v = rand(N^2, K)
1719

1820
T = (A, B)
1921
T = cache_operator(T, u)
2022

23+
@btime *($T, $u)
2124
@btime mul!($v, $T, $u)
2225

2326
println("⊗(I, B)")
2427

25-
u = rand(12^2, 100)
26-
v = rand(12^2, 100)
28+
u = rand(N^2, K)
29+
v = rand(N^2, K)
2730

2831
T = (Id, B)
2932
T = cache_operator(T, u)
3033

34+
@btime *($T, $u)
3135
@btime mul!($v, $T, $u)
3236

3337
println("⊗(A, I)")
3438

35-
u = rand(12^2, 100)
36-
v = rand(12^2, 100)
39+
u = rand(N^2, K)
40+
v = rand(N^2, K)
3741

3842
T = (A, Id)
3943
T = cache_operator(T, u)
4044

45+
@btime *($T, $u)
4146
@btime mul!($v, $T, $u)
4247

4348
println("#===============================#")
@@ -46,25 +51,25 @@ println("#===============================#")
4651

4752
println("⊗(⊗(A, B), C)")
4853

49-
u = rand(12^3, 100)
50-
v = rand(12^3, 100)
54+
u = rand(N^3, K)
55+
v = rand(N^3, K)
5156

5257
T = ((A, B), C)
5358
T = cache_operator(T, u)
5459

55-
mul!(v, T, u) # dunny
60+
@btime *($T, $u)
5661
@btime mul!($v, $T, $u); #
5762

5863
println("⊗(A, ⊗(B, C))")
5964

60-
u = rand(12^3, 100)
61-
v = rand(12^3, 100)
65+
u = rand(N^3, K)
66+
v = rand(N^3, K)
6267

6368
T = (A, (B, C))
6469
T = cache_operator(T, u)
6570

66-
mul!(v, T, u) # dunny
71+
@btime *($T, $u)
6772
@btime mul!($v, $T, $u); #
6873

6974
println("#===============================#")
70-
nothing
75+
#

src/func.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ function FunctionOperator(op,
8686
FunctionOperator(op, input, output; kwargs...)
8787
end
8888

89-
# TODO: document constructor and revisit design as needed (e.g. for "accepted_kwargs")
9089
"""
9190
$(SIGNATURES)
9291
@@ -127,7 +126,6 @@ uniform across `op`, `op_adjoint`, `op_inverse`, `op_adjoint_inverse`.
127126
* `p` - Prototype of parameter struct passed to the operator during evaluation, i.e. `L(u, p, t)`. `p` is set to `nothing` if no value is provided.
128127
* `t` - Protype of scalar time variable passed to the operator during evaluation. `t` is set to `zero(T)` if no value is provided.
129128
* `accepted_kwargs` - `Tuple` of `Symbol`s corresponding to the keyword arguments accepted by `op*`, and `update_coefficients[!]`. For example, if `op` accepts kwarg `scale`, as in `op(u, p, t; scale)`, then `accepted_kwargs = (:scale,)`.
130-
131129
* `T` - `eltype` of the operator. If no value is provided, the constructor inferrs the value from types of `input`, and `output`
132130
* `isinplace` - `true` if the operator can be used is a mutating way with in-place allocations. This trait is inferred if no value is provided.
133131
* `outofplace` - `true` if the operator can be used is a non-mutating way with in-place allocations. This trait is inferred if no value is provided.
@@ -440,8 +438,6 @@ has_mul!(::FunctionOperator{iip}) where{iip} = iip
440438
has_ldiv(L::FunctionOperator{iip}) where{iip} = !(L.op_inverse isa Nothing)
441439
has_ldiv!(L::FunctionOperator{iip}) where{iip} = iip & !(L.op_inverse isa Nothing)
442440

443-
# TODO - FunctionOperator, Base.conj, transpose
444-
445441
# operator application
446442
function Base.:*(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip}
447443
L.op(u, L.p, L.t; L.traits.kwargs...)

src/tensor.jl

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#
2-
"""
3-
$SIGNATURES
4-
2+
TENSOR_PROD_DOC = """
53
Computes the lazy pairwise Kronecker product, or tensor product,
64
operator of `AbstractMatrix`, and `AbstractSciMLOperator` subtypes.
75
Calling `⊗(ops...)` is equivalent to `Base.kron(ops...)`. Fast
@@ -10,11 +8,18 @@ product operator.
108
119
```
1210
TensorProductOperator(A, B) = A ⊗ B
11+
TensorProductOperator(A, B, C) = A ⊗ B ⊗ C
1312
1413
(A ⊗ B)(u) = vec(B * reshape(u, M, N) * transpose(A))
1514
```
1615
where `M = size(B, 2)`, and `N = size(A, 2)`
1716
"""
17+
18+
"""
19+
$SIGNATURES
20+
21+
$TENSOR_PROD_DOC
22+
"""
1823
struct TensorProductOperator{T,O,C} <: AbstractSciMLOperator{T}
1924
ops::O
2025
cache::C
@@ -54,26 +59,24 @@ TensorProductOperator(ii1::IdentityOperator, ii2::IdentityOperator) = IdentityOp
5459
"""
5560
$SIGNATURES
5661
57-
Computes the lazy pairwise Kronecker product, or tensor product,
58-
operator of `AbstractMatrix`, and `AbstractSciMLOperator` subtypes.
59-
Calling `⊗(ops...)` is equivalent to `Base.kron(ops...)`. Fast
60-
operator evaluation is performed without forming the full tensor
61-
product operator.
62-
63-
```
64-
TensorProductOperator(A, B) = A ⊗ B
65-
66-
(A ⊗ B)(u) = vec(B * reshape(u, M, N) * transpose(A))
67-
```
68-
where `M = size(B, 2)`, and `N = size(A, 2)`
62+
$TENSOR_PROD_DOC
6963
"""
7064
(ops::Union{AbstractMatrix,AbstractSciMLOperator}...) = TensorProductOperator(ops...)
7165

72-
# TODO - overload Base.kron for tensor product operators
73-
#Base.kron(ops::Union{AbstractMatrix,AbstractSciMLOperator}...) = TensorProductOperator(ops...)
66+
"""
67+
$SIGNATURES
68+
69+
Construct a lazy representation of the Kronecker product `A ⊗ B`. One of the
70+
two factors can be an `AbstractMatrix`, which is then promoted to a
71+
`MatrixOperator` automatically. To avoid fallback to the generic
72+
[`Base.kron`](@ref), at least one of `A` and `B` must be an
73+
`AbstractSciMLOperator`.
74+
"""
75+
Base.kron(A::AbstractSciMLOperator, B::AbstractSciMLOperator) = TensorProductOperator(A, B)
76+
Base.kron(A::AbstractMatrix, B::AbstractSciMLOperator) = TensorProductOperator(A, B)
77+
Base.kron(A::AbstractSciMLOperator, B::AbstractMatrix) = TensorProductOperator(A, B)
7478

75-
# convert to matrix
76-
Base.kron(ops::AbstractSciMLOperator...) = kron(convert.(AbstractMatrix, ops)...)
79+
Base.kron(ops::AbstractSciMLOperator...) = TensorProductOperator(ops...)
7780

7881
function Base.convert(::Type{AbstractMatrix}, L::TensorProductOperator)
7982
kron(convert.(AbstractMatrix, L.ops)...)

test/matrix.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using SciMLOperators, LinearAlgebra
22
using Random
33

4-
using SciMLOperators: InvertibleOperator, InvertedOperator,
4+
using SciMLOperators: InvertibleOperator, InvertedOperator, , AbstractSciMLOperator
55
using FFTW
66

77
Random.seed!(0)
@@ -263,6 +263,18 @@ for square in [false, true] #for K in [1, K]
263263
AB = kron(A, B)
264264
ABC = kron(A, B, C)
265265

266+
# test Base.kron overload
267+
# ensure kron(mat, mat) is not a TensorProductOperator
268+
@test !isa(AB, AbstractSciMLOperator)
269+
@test !isa(ABC, AbstractSciMLOperator)
270+
271+
# test Base.kron overload
272+
_A = rand(N, N)
273+
@test kron(_A, MatrixOperator(_A)) isa TensorProductOperator
274+
@test kron(MatrixOperator(_A), _A) isa TensorProductOperator
275+
276+
@test kron(MatrixOperator(_A), MatrixOperator(_A)) isa TensorProductOperator
277+
266278
# Inputs
267279
u2 = rand(n1*n2, K)
268280
u3 = rand(n1*n2*n3, K)

0 commit comments

Comments
 (0)