Skip to content

Commit a59eec9

Browse files
Merge pull request #200 from vpuri3/batch
[WIp] batch_dim kwarg
2 parents dd4a56c + 569cd50 commit a59eec9

File tree

2 files changed

+94
-4
lines changed

2 files changed

+94
-4
lines changed

src/func.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ function `adjoint_inverse`. All are assumed to have the same calling signature a
120120
below traits.
121121
122122
## Traits
123+
123124
Keyword arguments are used to set operator traits, which are assumed to be
124125
uniform across `op`, `op_adjoint`, `op_inverse`, `op_adjoint_inverse`.
125126
@@ -132,6 +133,7 @@ uniform across `op`, `op_adjoint`, `op_inverse`, `op_adjoint_inverse`.
132133
* `has_mul5` - `true` if the operator provides a five-argument `mul!` via the signature `op(v, u, p, t, α, β; <accepted_kwargs>)`. This trait is inferred if no value is provided.
133134
* `isconstant` - `true` if the operator is constant, and doesn't need to be updated via `update_coefficients[!]` during operator evaluation.
134135
* `islinear` - `true` if the operator is linear. Defaults to `false`.
136+
* `batch` - Boolean indicating if the input/output arrays comprise of batched vectors. If `true`, the last dimension of input/output arrays is considered to be the batch dimension and is not involved in size computation. For example, let `size(output), size(input) = (M, K), (N, K)`. If `batch = true`, then the second dimension is assumed to be the batch dimension, and the `size(F::FunctionOperator) = (M, N)`. If `batch = false`, then `size(F::FunctionOperator) = (M * K, M * K)`.
135137
* `ifcache` - Allocate cache arrays in constructor. Defaults to `true`. Cache can be generated afterwards by calling `cache_operator(L, input, output)`
136138
* `cache` - Pregenerated cache arrays for in-place evaluations. Expected to be of type and shape `(similar(input), similar(output),)`. The constructor generates cache if no values are provided. Cache generation by the constructor can be disabled by setting the kwarg `ifcache = false`.
137139
* `opnorm` - The norm of `op`. Can be a `Number`, or function `opnorm(p::Integer)`. Defaults to `nothing`.
@@ -159,6 +161,7 @@ function FunctionOperator(op,
159161
isconstant::Bool = false,
160162
islinear::Bool = false,
161163

164+
batch::Bool = false,
162165
ifcache::Bool = true,
163166
cache::Union{Nothing, NTuple{2}}=nothing,
164167

@@ -171,10 +174,22 @@ function FunctionOperator(op,
171174

172175
# store eltype of input/output for caching with ComposedOperator.
173176
eltypes = eltype.((input, output))
174-
sz = (size(output, 1), size(input, 1))
175177
T = isnothing(T) ? promote_type(eltypes...) : T
176178
t = isnothing(t) ? zero(real(T)) : t
177179

180+
@assert ndims(output) == ndims(input) """input/output arrays,
181+
($(typeof(input)), $(typeof(output))) provided to FunctionOperator
182+
do not have the same number of dimensions."""
183+
184+
_size = if batch
185+
# assume batches are in the last dimension
186+
sz_in = size(input)[1:end-1] |> prod
187+
sz_out = size(output)[1:end-1] |> prod
188+
(sz_out, sz_in)
189+
else
190+
(length(output), length(input))
191+
end
192+
178193
isinplace = if isnothing(isinplace)
179194
static_hasmethod(op, typeof((output, input, p, t)))
180195
else
@@ -235,7 +250,7 @@ function FunctionOperator(op,
235250
has_mul5 = has_mul5,
236251
ifcache = ifcache,
237252
T = T,
238-
size = sz,
253+
size = _size,
239254
eltypes = eltypes,
240255
accepted_kwargs = accepted_kwargs,
241256
kwargs = Dict{Symbol, Any}(),

test/func.jl

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,14 @@ K = 12
6363

6464
@test op1' === op1
6565

66-
@test size(op1) == (N,N)
66+
@test size(op1) == (N*K,N*K)
6767
@test has_adjoint(op1)
6868
@test has_mul(op1)
6969
@test !has_mul!(op1)
7070
@test has_ldiv(op1)
7171
@test !has_ldiv!(op1)
7272

73-
@test size(op2) == (N,N)
73+
@test size(op2) == (N*K,N*K)
7474
@test has_adjoint(op2)
7575
@test has_mul(op2)
7676
@test has_mul!(op2)
@@ -100,6 +100,81 @@ K = 12
100100
v = copy(u); @test A \ v ldiv!(op2, u)
101101
end
102102

103+
@testset "Batch FunctionOperator" begin
104+
u = rand(N,K)
105+
p = nothing
106+
t = 0.0
107+
α = rand()
108+
β = rand()
109+
110+
A = rand(N,N) |> Symmetric
111+
F = lu(A)
112+
Ai = inv(A)
113+
114+
f1(u, p, t) = A * u
115+
f1i(u, p, t) = A \ u
116+
117+
f2(du, u, p, t) = mul!(du, A, u)
118+
f2(du, u, p, t, α, β) = mul!(du, A, u, α, β)
119+
f2i(du, u, p, t) = ldiv!(du, F, u)
120+
f2i(du, u, p, t, α, β) = mul!(du, Ai, u, α, β)
121+
# out of place
122+
op1 = FunctionOperator(f1, u, A*u;
123+
124+
op_inverse=f1i,
125+
126+
ifcache = false,
127+
batch = true,
128+
islinear=true,
129+
opnorm=true,
130+
issymmetric=true,
131+
ishermitian=true,
132+
isposdef=true,
133+
)
134+
135+
# in place
136+
op2 = FunctionOperator(f2, u, A*u;
137+
138+
op_inverse=f2i,
139+
140+
ifcache = false,
141+
batch = true,
142+
islinear=true,
143+
opnorm=true,
144+
issymmetric=true,
145+
ishermitian=true,
146+
isposdef=true,
147+
)
148+
149+
@test issquare(op1)
150+
@test issquare(op2)
151+
152+
@test islinear(op1)
153+
@test islinear(op2)
154+
155+
@test op1' === op1
156+
157+
@test size(op1) == (N,N)
158+
@test has_adjoint(op1)
159+
@test has_mul(op1)
160+
@test !has_mul!(op1)
161+
@test has_ldiv(op1)
162+
@test !has_ldiv!(op1)
163+
164+
@test size(op2) == (N,N)
165+
@test has_adjoint(op2)
166+
@test has_mul(op2)
167+
@test has_mul!(op2)
168+
@test has_ldiv(op2)
169+
@test has_ldiv!(op2)
170+
171+
@test !iscached(op1)
172+
@test !iscached(op2)
173+
174+
@test !op1.traits.has_mul5
175+
@test op2.traits.has_mul5
176+
end
177+
103178
@testset "FunctionOperator update test" begin
104179
u = rand(N,K)
105180
p = rand(N)

0 commit comments

Comments
 (0)