Skip to content

Commit 538c66f

Browse files
authored
Support higher order diff (#117)
* Support higher order diff * add tests * Update test_calculus.jl * Update calculus.jl * Generalise some routines
1 parent f241cf7 commit 538c66f

File tree

5 files changed

+44
-33
lines changed

5 files changed

+44
-33
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "QuasiArrays"
22
uuid = "c4ea9172-b204-11e9-377d-29865faadc5c"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "0.11.9"
4+
version = "0.12"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/QuasiArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Base, LinearAlgebra, LazyArrays, ArrayLayouts, DomainSets, FillArrays, Sta
33
import Base: getindex, size, axes, axes1, length, ==, isequal, iterate, CartesianIndices, LinearIndices,
44
Indices, IndexStyle, getindex, setindex!, parent, vec, convert, similar, copy, copyto!, zero,
55
map, eachindex, eltype, first, last, firstindex, lastindex, in, reshape, permutedims, all,
6-
isreal, iszero, isempty, empty, isapprox, fill!, getproperty, showarg
6+
isreal, iszero, isone, isempty, empty, isapprox, fill!, getproperty, showarg
77
import Base: @_inline_meta, DimOrInd, OneTo, @_propagate_inbounds_meta, @_noinline_meta,
88
DimsInteger, error_if_canonical_getindex, @propagate_inbounds, _return_type,
99
safe_tail, front, tail, _getindex, _maybe_reshape, index_ndims, _unsafe_getindex,

src/calculus.jl

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -51,40 +51,37 @@ cumsum_size(::NTuple{N,Integer}, A, dims) where N = error("Not implemented")
5151
# diff
5252
####
5353

54-
@inline diff(a::AbstractQuasiArray; dims::Integer=1) = diff_layout(MemoryLayout(a), a, dims)
55-
function diff_layout(LAY::ApplyLayout{typeof(*)}, V::AbstractQuasiVector, dims...)
54+
@inline diff(a::AbstractQuasiArray, order...; dims::Integer=1) = diff_layout(MemoryLayout(a), a, order...; dims)
55+
function diff_layout(LAY::ApplyLayout{typeof(*)}, V::AbstractQuasiVecOrMat, order...; dims=1)
5656
a = arguments(LAY, V)
57-
*(diff(a[1]), tail(a)...)
57+
dims == 1 || throw(ArgumentError("cannot differentiate a vector along dimension $dims"))
58+
*(diff(a[1], order...), tail(a)...)
5859
end
5960

60-
function diff_layout(LAY::ApplyLayout{typeof(*)}, V::AbstractQuasiMatrix, dims=1)
61-
a = arguments(LAY, V)
62-
@assert dims == 1 #for type stability, for now
63-
# if dims == 1
64-
*(diff(a[1]), tail(a)...)
65-
# else
66-
# *(front(a)..., diff(a[end]; dims=dims))
67-
# end
61+
diff_layout(::MemoryLayout, A, order...; dims...) = diff_size(size(A), A, order...; dims...)
62+
diff_size(sz, a; dims...) = error("diff not implemented for $(typeof(a))")
63+
function diff_size(sz, a, order; dims...)
64+
order < 0 && throw(ArgumentError("order must be non-negative"))
65+
order == 0 && return a
66+
isone(order) ? diff(a) : diff(diff(a), order-1)
6867
end
6968

70-
diff_layout(::MemoryLayout, A, dims...) = diff_size(size(A), A, dims...)
71-
diff_size(sz, a, dims...) = error("diff not implemented for $(typeof(a))")
72-
7369
diff(x::Inclusion; dims::Integer=1) = ones(eltype(x), diffaxes(x))
74-
diff(c::AbstractQuasiFill{<:Any,1}; dims::Integer=1) = zeros(eltype(c), diffaxes(axes(c,1)))
75-
function diff(c::AbstractQuasiFill{<:Any,2}; dims::Integer=1)
70+
diff(x::Inclusion, order::Int; dims::Integer=1) = fill(ifelse(isone(order), one(eltype(x)), zero(eltype(x))), diffaxes(x,order))
71+
diff(c::AbstractQuasiFill{<:Any,1}, order...; dims::Integer=1) = zeros(eltype(c), diffaxes(axes(c,1),order...))
72+
function diff(c::AbstractQuasiFill{<:Any,2}, order...; dims::Integer=1)
7673
a,b = axes(c)
7774
if dims == 1
78-
zeros(eltype(c), diffaxes(a), b)
75+
zeros(eltype(c), diffaxes(a, order...), b)
7976
else
80-
zeros(eltype(c), a, diffaxes(b))
77+
zeros(eltype(c), a, diffaxes(b, order...))
8178
end
8279
end
8380

8481

85-
diffaxes(a::Inclusion{<:Any,<:AbstractVector}) = Inclusion(a.domain[1:end-1])
86-
diffaxes(a::OneTo) = oneto(length(a)-1)
87-
diffaxes(a) = a # default is differentiation does not change axes
82+
diffaxes(a::Inclusion{<:Any,<:AbstractVector}, order=1) = Inclusion(a.domain[1:end-order])
83+
diffaxes(a::OneTo, order=1) = oneto(length(a)-order)
84+
diffaxes(a, order...) = a # default is differentiation does not change axes
8885

8986
diff(b::QuasiVector; dims::Integer=1) = QuasiVector(diff(b.parent) ./ diff(b.axes[1]), (diffaxes(axes(b,1)),))
9087
function diff(A::QuasiMatrix; dims::Integer=1)

src/quasibroadcast.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -186,13 +186,13 @@ LazyArrays._broadcast_mul_mul((A,b)::Tuple{AbstractQuasiMatrix,AbstractQuasiVect
186186
# support (A .* B) * y
187187
_broadcasted_mul(a::Tuple{Number,Vararg{Any}}, b::AbstractQuasiVector) = (first(a)*sum(b), _broadcasted_mul(tail(a), b)...)
188188
_broadcasted_mul(a::Tuple{Number,Vararg{Any}}, B::AbstractQuasiMatrix) = (first(a)*sum(B; dims=1), _broadcasted_mul(tail(a), B)...)
189-
_broadcasted_mul(a::Tuple{AbstractQuasiVector,Vararg{Any}}, b::AbstractQuasiVector) = (first(a)*sum(b), _broadcasted_mul(tail(a), b)...)
190-
_broadcasted_mul(a::Tuple{AbstractQuasiVector,Vararg{Any}}, B::AbstractQuasiMatrix) = (first(a)*sum(B; dims=1), _broadcasted_mul(tail(a), B)...)
191-
_broadcasted_mul(A::Tuple{AbstractQuasiMatrix,Vararg{Any}}, b::AbstractQuasiVector) = (axes(first(A),2) == Base.OneTo(1) ? first(A)*sum(b) : (first(A)*b), _broadcasted_mul(tail(A), b)...)
192-
_broadcasted_mul(A::Tuple{AbstractQuasiMatrix,Vararg{Any}}, B::AbstractQuasiMatrix) = (axes(first(A),2) == Base.OneTo(1) ? first(A)*sum(B; dims=1) : (first(A)*B), _broadcasted_mul(tail(A), B)...)
189+
_broadcasted_mul(a::Tuple{AbstractQuasiVector,Vararg{Any}}, b::AbstractQuasiOrVector) = (first(a)*sum(b), _broadcasted_mul(tail(a), b)...)
190+
_broadcasted_mul(a::Tuple{AbstractQuasiVector,Vararg{Any}}, B::AbstractQuasiOrMatrix) = (first(a)*sum(B; dims=1), _broadcasted_mul(tail(a), B)...)
191+
_broadcasted_mul(A::Tuple{AbstractQuasiMatrix,Vararg{Any}}, b::AbstractQuasiOrVector) = (axes(first(A),2) == Base.OneTo(1) ? first(A)*sum(b) : (first(A)*b), _broadcasted_mul(tail(A), b)...)
192+
_broadcasted_mul(A::Tuple{AbstractQuasiMatrix,Vararg{Any}}, B::AbstractQuasiOrMatrix) = (axes(first(A),2) == Base.OneTo(1) ? first(A)*sum(B; dims=1) : (first(A)*B), _broadcasted_mul(tail(A), B)...)
193193
_broadcasted_mul(A::AbstractQuasiMatrix, b::Tuple{Number,Vararg{Any}}) = (sum(A; dims=2)*first(b)[1], _broadcasted_mul(A, tail(b))...)
194-
_broadcasted_mul(A::AbstractQuasiMatrix, b::Tuple{Union{AbstractVector,AbstractQuasiVector},Vararg{Any}}) = (size(first(b),1) == 1 ? (sum(A; dims=2)*first(b)[1]) : (A*first(b)), _broadcasted_mul(A, tail(b))...)
195-
_broadcasted_mul(A::AbstractQuasiMatrix, B::Tuple{Union{AbstractMatrix,AbstractQuasiMatrix},Vararg{Any}}) = (size(first(B),1) == 1 ? (sum(A; dims=2) * first(B)) : (A * first(B)), _broadcasted_mul(A, tail(B))...)
194+
_broadcasted_mul(A::AbstractQuasiMatrix, b::Tuple{AbstractQuasiOrVector,Vararg{Any}}) = (size(first(b),1) == 1 ? (sum(A; dims=2)*first(b)[1]) : (A*first(b)), _broadcasted_mul(A, tail(b))...)
195+
_broadcasted_mul(A::AbstractQuasiMatrix, B::Tuple{AbstractQuasiOrMatrix,Vararg{Any}}) = (size(first(B),1) == 1 ? (sum(A; dims=2) * first(B)) : (A * first(B)), _broadcasted_mul(A, tail(B))...)
196196

197197

198198

test/test_calculus.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,25 @@ using QuasiArrays, IntervalSets, Test
3030

3131
@testset "Diff" begin
3232
x = range(0, 1; length=10_000)
33-
@test diff(Inclusion(x)) == ones(Inclusion(x[1:end-1]))
34-
@test diff(ones(Inclusion(x))) == zeros(Inclusion(x[1:end-1]))
33+
@test diff(Inclusion(x)) == diff(Inclusion(x),1) == ones(Inclusion(x[1:end-1]))
34+
@test diff(Inclusion(x),2) == diff(diff(Inclusion(x))) == zeros(Inclusion(x[1:end-2]))
35+
@test diff(ones(Inclusion(x))) == diff(ones(Inclusion(x)),1) == zeros(Inclusion(x[1:end-1]))
36+
@test diff(ones(Inclusion(x)),2) == diff(diff(ones(Inclusion(x)))) == zeros(Inclusion(x[1:end-2]))
3537

3638
@test diff(ones(Inclusion(x), Inclusion(x))) == zeros(Inclusion(x[1:end-1]), Inclusion(x))
39+
@test diff(ones(Inclusion(x), Inclusion(x)), 2) == zeros(Inclusion(x[1:end-2]), Inclusion(x))
3740
@test diff(ones(Inclusion(x), Inclusion(x)); dims=2) == zeros(Inclusion(x), Inclusion(x[1:end-1]))
41+
@test diff(ones(Inclusion(x), Inclusion(x)), 2; dims=2) == zeros(Inclusion(x), Inclusion(x[1:end-2]))
3842

3943
b = QuasiVector(exp.(x), x)
4044

4145
@test diff(b) b[Inclusion(x[1:end-1])] atol=1E-2
46+
@test diff(b,2) b[Inclusion(x[1:end-2])] atol=1E-1
4247

4348

4449
A = QuasiArray(randn(3,2), (1:0.5:2,0:0.5:0.5))
4550
@test diff(A; dims=1)[:,0] == diff(A[:,0])
51+
@test diff(A,2; dims=1)[:,0] == diff(diff(A[:,0]))
4652
@test diff(A; dims=2)[1,:] == diff(A[1,:])
4753

4854
@testset "* diff" begin
@@ -57,10 +63,18 @@ using QuasiArrays, IntervalSets, Test
5763

5864
@testset "Interval" begin
5965
@test diff(Inclusion(0.0..1)) ones(Inclusion(0.0..1))
60-
@test diff(ones(Inclusion(0.0..1))) zeros(Inclusion(0.0..1))
61-
@test diff(ones(Inclusion(0.0..1), Base.OneTo(3))) zeros(Inclusion(0.0..1), Base.OneTo(3))
66+
@test diff(Inclusion(0.0..1),1) fill(1.0,Inclusion(0.0..1))
67+
@test diff(Inclusion(0.0..1),2) fill(0.0,Inclusion(0.0..1))
68+
@test diff(ones(Inclusion(0.0..1))) diff(ones(Inclusion(0.0..1)),1) diff(ones(Inclusion(0.0..1)),2) zeros(Inclusion(0.0..1))
69+
@test diff(ones(Inclusion(0.0..1), Base.OneTo(3))) diff(ones(Inclusion(0.0..1), Base.OneTo(3)),2) zeros(Inclusion(0.0..1), Base.OneTo(3))
6270
@test diff(ones(Inclusion(0.0..1), Base.OneTo(3)); dims=2) zeros(Inclusion(0.0..1), Base.OneTo(2))
6371
@test diff(ones(Base.OneTo(3), Inclusion(0.0..1))) zeros(Base.OneTo(2), Inclusion(0.0..1))
6472
@test diff(ones(Base.OneTo(3), Inclusion(0.0..1)); dims=2) zeros(Base.OneTo(3), Inclusion(0.0..1))
6573
end
74+
75+
@testset "Incomplete" begin
76+
struct IncompleteQuasiArray <: AbstractQuasiVector{Int} end
77+
Base.axes(::IncompleteQuasiArray) = (Base.OneTo(3),)
78+
@test_throws ErrorException diff(IncompleteQuasiArray())
79+
end
6680
end

0 commit comments

Comments
 (0)