Make ClipNorm work on GPU Broadcasted #144

Merged · merged 3 commits on Jul 10, 2023
Changes from all commits
Project.toml: 2 changes (1 addition & 1 deletion)
@@ -1,7 +1,7 @@
name = "Optimisers"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
authors = ["Mike J Innes <[email protected]>"]
version = "0.2.18"
version = "0.2.19"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
src/rules.jl: 25 changes (24 additions & 1 deletion)
@@ -611,7 +611,7 @@ ClipNorm(ω = 10f0, p = 2; throw::Bool = true) = ClipNorm{float(typeof(ω))}(ω,
init(o::ClipNorm, x::AbstractArray) = nothing

function apply!(o::ClipNorm, state, x, dx)
-nrm = norm(dx, o.p)
+nrm = _norm(dx, o.p)
if o.throw && !isfinite(nrm)
throw(DomainError("gradient has $(o.p)-norm $nrm, for array $(summary(x))"))
end
@@ -620,6 +620,29 @@ function apply!(o::ClipNorm, state, x, dx)
return state, @lazy dx * λ
end

_norm(dx::AbstractArray, p::Real) = norm(dx, p) # LinearAlgebra, CUDA
function _norm(dx::Broadcast.Broadcasted, p::Real)
if p == 2
# This lacks the underflow/overflow tests of LinearAlgebra's version
sqrt(sum(abs2, dx))
elseif p == 1
float(sum(abs, dx))
elseif p == Inf
float(maximum(abs, dx))
elseif p == 0
cnt = count(!iszero, dx)
T = Base.@default_eltype dx
T <: Number ? convert(float(T), cnt) : cnt
elseif p == -Inf
float(minimum(abs, dx))
else
# This isn't optimally fast but does ensure p::Float64 doesn't promote
tmp = abs.(dx)
q = convert(float(eltype(tmp)), p)
sum(tmp .^ q) ^ (1/q)
end
end

"""
OptimiserChain(opts...)

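
For illustration only (a minimal sketch, not part of the diff, mirroring the tests added below): the Broadcasted here stands in for the lazy gradients that @lazy produces, e.g. when ClipNorm follows another rule inside an OptimiserChain.

using Optimisers, LinearAlgebra
using Base.Broadcast: broadcasted, instantiate

# A lazy, unmaterialized broadcast, like the `@lazy dx * λ` objects rules return:
bc = instantiate(broadcasted(+, randn(Float32, 10), randn(Float32, 10)'))

Optimisers._norm(bc, 2) ≈ norm(collect(bc), 2)  # true, and the result stays Float32
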
test/runtests.jl: 17 changes (16 additions & 1 deletion)
@@ -2,6 +2,7 @@ using Optimisers
using ChainRulesCore, Functors, StaticArrays, Zygote, Yota
using LinearAlgebra, Statistics, Test, Random
using Optimisers: @.., @lazy
using Base.Broadcast: broadcasted, instantiate, Broadcasted

Random.seed!(1)

@@ -89,7 +90,8 @@ y2z(x) = x
_, m2 = Optimisers.update(s2, m, (α = ([0.1], nothing), γ = [1,10,100],))
@test only(m.α[1] .- m2.α[1]) ≈ 0.1
@test norm(m.γ .- m2.γ) ≈ 10
-@test_throws DomainError Optimisers.update(s2, m, (α = [0.1], γ = [1,10,NaN],))
+# This error is thrown by apply! due to NaN input.
+@test_throws DomainError Optimisers.update(s2, m, (α = ([0.1], nothing), γ = [1,10,NaN],))

s3 = Optimisers.setup(ClipNorm(5, 1; throw=false), m)
_, m3 = Optimisers.update(s3, m, (α = ([0.1], nothing), γ = [1,10,100],))
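
For reference, a standalone sketch of the clipping and throw behaviour tested above (an assumption-laden example, using setup directly on a plain array, which Optimisers treats as a trainable leaf):

using Optimisers, Test

x = [1.0, 2.0]
s = Optimisers.setup(ClipNorm(5, 1), x)
Optimisers.update(s, x, [3.0, 4.0])  # 1-norm is 7 > 5, so the gradient is rescaled by 5/7
@test_throws DomainError Optimisers.update(s, x, [1.0, NaN])  # non-finite norm throws by default
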
@@ -506,6 +508,19 @@ y2z(x) = x
y = Optimisers.subtract!(x, nothing)
@test y === x
end

@testset "_norm(dx, p) works" begin
Member: can we test the interface instead of an internal method?

Member Author: At present there are no GPU tests, and norm(::Broadcasted{..., Array}) works without error. So other tests, which didn't fail before, do call _norm.
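
For example, a rough sketch of why the CPU case passes (LinearAlgebra's generic norm accepts any iterable, and an instantiated Broadcasted over Arrays is iterable):

using LinearAlgebra
using Base.Broadcast: broadcasted, instantiate

bc = instantiate(broadcasted(+, [1.0, 2.0], [3.0, 4.0]))
norm(bc, 2)  # ≈ 7.211, same as norm(collect(bc), 2) — no error on CPU arrays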

Member Author: #71 was going to add some tests, but did not run into this failure. I haven't checked whether JLArray will in fact see it, but Metal's array type (Apple M1) does.

bc = instantiate(broadcasted(+, randn(Float32, 10), randn(Float32, 10)'));
arr = collect(bc)
bc2 = instantiate(broadcasted(+, [1, 0, -3, 4], 0))
arr2 = collect(bc2)
for p in (-Inf, -3, -1, 0, 0.5, 1, 1.5, 2, 3f0, Inf32)
@test Optimisers._norm(bc, p) ≈ norm(arr, p)
@test Optimisers._norm(bc, p) isa Float32
@test Optimisers._norm(bc2, p) ≈ norm(arr2, p)
@test Optimisers._norm(bc2, p) isa Float64
end
end
end
@testset verbose=true "Destructure" begin
include("destructure.jl")