Skip to content

Commit 5bba40f

Browse files
oschulzcscherrer
andauthored
Update to Static.jl v0.8 (#115)
* Update CI configuration * Update to Static.jl v0.8 * Increase package version to v0.14.6 * format * Fix docstring for PowerMeasure Co-authored-by: Chad Scherrer <[email protected]> * Rename dslength and dssize --------- Co-authored-by: Chad Scherrer <[email protected]>
1 parent 99a603b commit 5bba40f

File tree

9 files changed

+127
-70
lines changed

9 files changed

+127
-70
lines changed

.github/workflows/CI.yml

Lines changed: 22 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ on:
66
push:
77
branches:
88
- master
9+
tags: '*'
910
pull_request:
1011

11-
1212
concurrency:
1313
# Skip intermediate builds: always.
1414
# Cancel intermediate builds: only if it is a pull request build.
@@ -19,59 +19,43 @@ jobs:
1919
test:
2020
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
2121
runs-on: ${{ matrix.os }}
22+
continue-on-error: ${{ matrix.version == 'nightly' }}
2223
strategy:
2324
fail-fast: false
2425
matrix:
2526
version:
2627
- '1.6'
27-
- '1.7'
28-
- '1.8'
28+
- '1'
29+
- 'nightly'
2930
os:
3031
- ubuntu-latest
3132
arch:
3233
- x64
34+
include:
35+
- version: 1
36+
os: ubuntu-latest
37+
arch: x86
38+
- version: 1
39+
os: macOS-latest
40+
arch: x64
41+
- version: 1
42+
os: windows-latest
43+
arch: x64
3344
steps:
34-
- uses: actions/checkout@v2
45+
- uses: actions/checkout@v3
3546
- uses: julia-actions/setup-julia@v1
3647
with:
3748
version: ${{ matrix.version }}
3849
arch: ${{ matrix.arch }}
39-
- uses: actions/cache@v1
40-
env:
41-
cache-name: cache-artifacts
42-
with:
43-
path: ~/.julia/artifacts
44-
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
45-
restore-keys: |
46-
${{ runner.os }}-test-${{ env.cache-name }}-
47-
${{ runner.os }}-test-
48-
${{ runner.os }}-
50+
- uses: julia-actions/cache@v1
4951
- uses: julia-actions/julia-buildpkg@v1
5052
- uses: julia-actions/julia-runtest@v1
53+
with:
54+
coverage: ${{ matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.arch == 'x64' }}
5155
- uses: julia-actions/julia-processcoverage@v1
52-
- uses: codecov/codecov-action@v1
56+
if: matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.arch == 'x64'
57+
- uses: codecov/codecov-action@v3
58+
if: matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.arch == 'x64'
5359
with:
5460
file: lcov.info
55-
# docs:
56-
# name: Documentation
57-
# runs-on: ubuntu-latest
58-
# steps:
59-
# - uses: actions/checkout@v2
60-
# - uses: julia-actions/setup-julia@v1
61-
# with:
62-
# version: '1'
63-
# - run: |
64-
# julia --project=docs -e '
65-
# using Pkg
66-
# Pkg.develop(PackageSpec(path=pwd()))
67-
# Pkg.instantiate()'
68-
# - run: |
69-
# julia --project=docs -e '
70-
# using Documenter: DocMeta, doctest
71-
# using MeasureBase
72-
# DocMeta.setdocmeta!(MeasureBase, :DocTestSetup, :(using MeasureBase); recursive=true)
73-
# doctest(MeasureBase)'
74-
# - run: julia --project=docs docs/make.jl
75-
# env:
76-
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
77-
# DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
61+

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MeasureBase"
22
uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.14.5"
4+
version = "0.14.6"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -48,6 +48,6 @@ NaNMath = "0.3, 1"
4848
PrettyPrinting = "0.3, 0.4"
4949
Reexport = "1"
5050
SpecialFunctions = "2"
51-
Static = "0.5, 0.6"
51+
Static = "0.8"
5252
Tricks = "0.1"
5353
julia = "1.3"

src/MeasureBase.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ using PrettyPrinting
3232
const Pretty = PrettyPrinting
3333

3434
using ChainRulesCore
35-
using FillArrays
35+
import FillArrays
3636
using Static
3737
using FunctionChains
3838

@@ -106,6 +106,7 @@ using Compat
106106

107107
using IrrationalConstants
108108

109+
include("static.jl")
109110
include("smf.jl")
110111
include("getdof.jl")
111112
include("transport.jl")

src/combinators/power.jl

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,68 @@
11
import Base
2-
using FillArrays: Fill
3-
# """
4-
# A power measure is a product of a measure with itself. The number of elements in
5-
# the product determines the dimensionality of the resulting support.
62

7-
# Note that power measures are only well-defined for integer powers.
3+
export PowerMeasure
84

9-
# The nth power of a measure μ can be written μ^x.
10-
# """
11-
# PowerMeasure{M,N,D} = ProductMeasure{Fill{M,N,D}}
5+
"""
6+
struct PowerMeasure{M,...} <: AbstractProductMeasure
127
13-
export PowerMeasure
8+
A power measure is a product of a measure with itself. The number of elements in
9+
the product determines the dimensionality of the resulting support.
10+
11+
Note that power measures are only well-defined for integer powers.
1412
13+
The nth power of a measure μ can be written μ^n.
14+
"""
1515
struct PowerMeasure{M,A} <: AbstractProductMeasure
1616
parent::M
1717
axes::A
1818
end
1919

20+
maybestatic_length::PowerMeasure) = prod(maybestatic_size(μ))
21+
maybestatic_size::PowerMeasure) = map(maybestatic_length, μ.axes)
22+
2023
function Pretty.tile::PowerMeasure)
2124
sz = length.(μ.axes)
2225
arg1 = Pretty.tile.parent)
2326
arg2 = Pretty.tile(length(sz) == 1 ? only(sz) : sz)
2427
return Pretty.pair_layout(arg1, arg2; sep = " ^ ")
2528
end
2629

30+
# ToDo: Make rand return static arrays for statically-sized power measures.
31+
32+
function _cartidxs(axs::Tuple{Vararg{<:AbstractUnitRange,N}}) where {N}
33+
CartesianIndices(map(_dynamic, axs))
34+
end
35+
2736
function Base.rand(
2837
rng::AbstractRNG,
2938
::Type{T},
3039
d::PowerMeasure{M},
3140
) where {T,M<:AbstractMeasure}
32-
map(CartesianIndices(d.axes)) do _
41+
map(_cartidxs(d.axes)) do _
3342
rand(rng, T, d.parent)
3443
end
3544
end
3645

3746
function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure) where {T}
38-
map(CartesianIndices(d.axes)) do _
47+
map(_cartidxs(d.axes)) do _
3948
rand(rng, d.parent)
4049
end
4150
end
4251

52+
@inline _pm_axes(sz::Tuple{Vararg{<:IntegerLike,N}}) where {N} = map(one_to, sz)
53+
@inline _pm_axes(axs::Tuple{Vararg{<:AbstractUnitRange,N}}) where {N} = axs
54+
4355
@inline function powermeasure(x::T, sz::Tuple{Vararg{<:Any,N}}) where {T,N}
44-
a = axes(Fill{T,N}(x, sz))
45-
A = typeof(a)
46-
PowerMeasure{T,A}(x, a)
56+
PowerMeasure(x, _pm_axes(sz))
4757
end
4858

49-
marginals(d::PowerMeasure) = Fill(d.parent, d.axes)
59+
marginals(d::PowerMeasure) = fill_with(d.parent, d.axes)
5060

5161
function Base.:^::AbstractMeasure, dims::Tuple{Vararg{<:AbstractArray,N}}) where {N}
5262
powermeasure(μ, dims)
5363
end
5464

55-
Base.:^::AbstractMeasure, dims::Tuple) = powermeasure(μ, Base.OneTo.(dims))
65+
Base.:^::AbstractMeasure, dims::Tuple) = powermeasure(μ, one_to.(dims))
5666
Base.:^::AbstractMeasure, n) = powermeasure(μ, (n,))
5767

5868
# Base.show(io::IO, d::PowerMeasure) = print(io, d.parent, " ^ ", size(d.xs))
@@ -75,18 +85,15 @@ end
7585
end
7686
end
7787

78-
@inline function logdensity_def(
79-
d::PowerMeasure{M,Tuple{Base.OneTo{StaticInt{N}}}},
80-
x,
81-
) where {M,N}
88+
@inline function logdensity_def(d::PowerMeasure{M,Tuple{Static.SOneTo{N}}}, x) where {M,N}
8289
parent = d.parent
8390
sum(1:N) do j
8491
@inbounds logdensity_def(parent, x[j])
8592
end
8693
end
8794

8895
@inline function logdensity_def(
89-
d::PowerMeasure{M,NTuple{N,Base.OneTo{StaticInt{0}}}},
96+
d::PowerMeasure{M,NTuple{N,Static.SOneTo{0}}},
9097
x,
9198
) where {M,N}
9299
static(0.0)
@@ -110,7 +117,7 @@ end
110117

111118
@inline getdof::PowerMeasure) = getdof.parent) * prod(map(length, μ.axes))
112119

113-
@inline function getdof(::PowerMeasure{<:Any,NTuple{N,Base.OneTo{StaticInt{0}}}}) where {N}
120+
@inline function getdof(::PowerMeasure{<:Any,NTuple{N,Static.SOneTo{0}}}) where {N}
114121
static(0)
115122
end
116123

@@ -135,7 +142,7 @@ logdensity_def(::PowerMeasure{P}, x) where {P<:PrimitiveMeasure} = static(0.0)
135142

136143
# To avoid ambiguities
137144
function logdensity_def(
138-
::PowerMeasure{P,Tuple{Vararg{Base.OneTo{Static.StaticInt{0}},N}}},
145+
::PowerMeasure{P,Tuple{Vararg{Static.SOneTo{0},N}}},
139146
x,
140147
) where {P<:PrimitiveMeasure,N}
141148
static(0.0)

src/combinators/smart-constructors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ end
3737
###############################################################################
3838
# ProductMeasure
3939

40-
productmeasure(mar::Fill) = powermeasure(mar.value, mar.axes)
40+
productmeasure(mar::FillArrays.Fill) = powermeasure(mar.value, mar.axes)
4141

4242
function productmeasure(mar::ReadonlyMappedArray{T,N,A,Returns{M}}) where {T,N,A,M}
4343
return powermeasure(mar.f.value, axes(mar.data))

src/domains.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ struct Simplex <: CodimOne end
116116

117117
function zeroset(::Simplex)
118118
f(x::AbstractArray{T}) where {T} = sum(x) - one(T)
119-
∇f(x::AbstractArray{T}) where {T} = Fill(one(T), size(x))
119+
∇f(x::AbstractArray{T}) where {T} = fill_with(one(T), size(x))
120120
ZeroSet(f, ∇f)
121121
end
122122

src/standard/stdmeasure.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function transport_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure}, x)
1313
end
1414

1515
function transport_def::PowerMeasure{<:StdMeasure}, μ::StdMeasure, x)
16-
return Fill(transport_def.parent, μ, only(x)), map(length, ν.axes)...)
16+
return fill_with(transport_def.parent, μ, only(x)), map(length, ν.axes))
1717
end
1818

1919
function transport_def(
@@ -35,7 +35,7 @@ end
3535
# Implement transport_to(NU::Type{<:StdMeasure}, μ) and transport_to(ν, MU::Type{<:StdMeasure}):
3636

3737
_std_measure(::Type{M}, ::StaticInt{1}) where {M<:StdMeasure} = M()
38-
_std_measure(::Type{M}, dof::Integer) where {M<:StdMeasure} = M()^dof
38+
_std_measure(::Type{M}, dof::IntegerLike) where {M<:StdMeasure} = M()^dof
3939
_std_measure_for(::Type{M}, μ::Any) where {M<:StdMeasure} = _std_measure(M, getdof(μ))
4040

4141
function transport_to(::Type{NU}, μ) where {NU<:StdMeasure}
@@ -90,7 +90,7 @@ end
9090
@inline _offset_cumsum(s, x) = (s,)
9191
@inline _offset_cumsum(s) = ()
9292

93-
function _stdvar_viewranges(μs::Tuple, startidx::Integer)
93+
function _stdvar_viewranges(μs::Tuple, startidx::IntegerLike)
9494
N = map(getdof, μs)
9595
offs = _offset_cumsum(startidx, N...)
9696
map((o, n) -> o:o+n-1, offs, N)

src/static.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""
2+
MeasureBase.IntegerLike
3+
4+
Equivalent to `Union{Integer,Static.StaticInt}`.
5+
"""
6+
const IntegerLike = Union{Integer,Static.StaticInt}
7+
8+
"""
9+
MeasureBase.one_to(n::IntegerLike)
10+
11+
Creates a range from one to n.
12+
13+
Returns an instance of `Base.OneTo` or `Static.SOneTo`, depending
14+
on the type of `n`.
15+
"""
16+
@inline one_to(n::Integer) = Base.OneTo(n)
17+
@inline one_to(::Static.StaticInt{N}) where {N} = Static.SOneTo{N}()
18+
19+
_dynamic(x::Number) = dynamic(x)
20+
_dynamic(::Static.SOneTo{N}) where {N} = Base.OneTo(N)
21+
_dynamic(r::AbstractUnitRange) = minimum(r):maximum(r)
22+
23+
"""
24+
MeasureBase.fill_with(x, sz::NTuple{N,<:IntegerLike}) where N
25+
26+
Creates an array of size `sz` filled with `x`.
27+
28+
Returns an instance of `FillArrays.Fill`.
29+
"""
30+
function fill_with end
31+
32+
@inline function fill_with(x::T, sz::Tuple{Vararg{<:IntegerLike,N}}) where {T,N}
33+
fill_with(x, map(one_to, sz))
34+
end
35+
36+
@inline function fill_with(x::T, axs::Tuple{Vararg{<:AbstractUnitRange,N}}) where {T,N}
37+
# While `FillArrays.Fill` (mostly?) works with axes that are static unit
38+
# ranges, some operations that automatic differentiation requires do fail
39+
# on such instances of `Fill` (e.g. `reshape` from dynamic to static size).
40+
# So need to use standard ranges for the axes for now:
41+
dyn_axs = map(_dynamic, axs)
42+
FillArrays.Fill(x, dyn_axs)
43+
end
44+
45+
"""
46+
MeasureBase.maybestatic_length(x)::IntegerLike
47+
48+
Returns the length of `x` as a dynamic or static integer.
49+
"""
50+
maybestatic_length(x) = length(x)
51+
maybestatic_length(x::AbstractUnitRange) = length(x)
52+
function maybestatic_length(::Static.OptionallyStaticUnitRange{StaticInt{A},StaticInt{B}}) where {A,B}
53+
StaticInt{B - A + 1}()
54+
end
55+
56+
"""
57+
MeasureBase.maybestatic_size(x)::Tuple{Vararg{IntegerLike}}
58+
59+
Returns the size of `x` as a tuple of dynamic or static integers.
60+
"""
61+
maybestatic_size(x) = size(x)

test/transport.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ using LogExpFunctions: logit
88
using ChainRulesTestUtils
99

1010
@testset "transport_to" begin
11-
test_rrule(MeasureBase._origin_depth, pushfwd(exp, StdUniform()))
11+
test_rrule(
12+
MeasureBase._origin_depth,
13+
pushfwd(exp, StdUniform()),
14+
output_tangent = static(0),
15+
)
1216

1317
for (f, μ) in [
1418
(logit, StdUniform())

0 commit comments

Comments
 (0)