Skip to content

Commit e05a16c

Browse files
committed
overload Base.kron, add tests
1 parent 7168ba4 commit e05a16c

File tree

2 files changed

+23
-18
lines changed

2 files changed

+23
-18
lines changed

src/tensor.jl

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#
2-
"""
3-
$SIGNATURES
4-
2+
TENSOR_PROD_DOC = """
53
Computes the lazy pairwise Kronecker product, or tensor product,
64
operator of `AbstractMatrix`, and `AbstractSciMLOperator` subtypes.
75
Calling `⊗(ops...)` is equivalent to `Base.kron(ops...)`. Fast
@@ -10,11 +8,18 @@ product operator.
108
119
```
1210
TensorProductOperator(A, B) = A ⊗ B
11+
TensorProductOperator(A, B, C) = A ⊗ B ⊗ C
1312
1413
(A ⊗ B)(u) = vec(B * reshape(u, M, N) * transpose(A))
1514
```
1615
where `M = size(B, 2)`, and `N = size(A, 2)`
1716
"""
17+
18+
"""
19+
$SIGNATURES
20+
21+
$TENSOR_PROD_DOC
22+
"""
1823
struct TensorProductOperator{T,O,C} <: AbstractSciMLOperator{T}
1924
ops::O
2025
cache::C
@@ -54,18 +59,7 @@ TensorProductOperator(ii1::IdentityOperator, ii2::IdentityOperator) = IdentityOp
5459
"""
5560
$SIGNATURES
5661
57-
Computes the lazy pairwise Kronecker product, or tensor product,
58-
operator of `AbstractMatrix`, and `AbstractSciMLOperator` subtypes.
59-
Calling `⊗(ops...)` is equivalent to `Base.kron(ops...)`. Fast
60-
operator evaluation is performed without forming the full tensor
61-
product operator.
62-
63-
```
64-
TensorProductOperator(A, B) = A ⊗ B
65-
66-
(A ⊗ B)(u) = vec(B * reshape(u, M, N) * transpose(A))
67-
```
68-
where `M = size(B, 2)`, and `N = size(A, 2)`
62+
$TENSOR_PROD_DOC
6963
"""
7064
(ops::Union{AbstractMatrix,AbstractSciMLOperator}...) = TensorProductOperator(ops...)
7165

@@ -82,8 +76,7 @@ Base.kron(A::AbstractSciMLOperator, B::AbstractSciMLOperator) = TensorProductOpe
8276
Base.kron(A::AbstractMatrix, B::AbstractSciMLOperator) = TensorProductOperator(A, B)
8377
Base.kron(A::AbstractSciMLOperator, B::AbstractMatrix) = TensorProductOperator(A, B)
8478

85-
# convert to matrix
86-
Base.kron(ops::AbstractSciMLOperator...) = kron(convert.(AbstractMatrix, ops)...)
79+
Base.kron(ops::AbstractSciMLOperator...) = TensorProductOperator(ops...)
8780

8881
function Base.convert(::Type{AbstractMatrix}, L::TensorProductOperator)
8982
kron(convert.(AbstractMatrix, L.ops)...)

test/matrix.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using SciMLOperators, LinearAlgebra
22
using Random
33

4-
using SciMLOperators: InvertibleOperator, InvertedOperator,
4+
using SciMLOperators: InvertibleOperator, InvertedOperator, , AbstractSciMLOperator
55
using FFTW
66

77
Random.seed!(0)
@@ -263,6 +263,18 @@ for square in [false, true] #for K in [1, K]
263263
AB = kron(A, B)
264264
ABC = kron(A, B, C)
265265

266+
# test Base.kron overload
267+
# ensure kron(mat, mat) is not a TensorProductOperator
268+
@test !isa(AB, AbstractSciMLOperator)
269+
@test !isa(ABC, AbstractSciMLOperator)
270+
271+
# test Base.kron overload
272+
_A = rand(N, N)
273+
@test kron(_A, MatrixOperator(_A)) isa TensorProductOperator
274+
@test kron(MatrixOperator(_A), _A) isa TensorProductOperator
275+
276+
@test kron(MatrixOperator(_A), MatrixOperator(_A)) isa TensorProductOperator
277+
266278
# Inputs
267279
u2 = rand(n1*n2, K)
268280
u3 = rand(n1*n2*n3, K)

0 commit comments

Comments
 (0)