Commit 6eaf26d

Merge pull request #144 from mcabbott/bc_norm
Make ClipNorm work on GPU Broadcasted
2 parents 8a37946 + d73a0ee
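In short: `ClipNorm`'s `apply!` previously called `LinearAlgebra.norm(dx, p)` directly. Inside an `OptimiserChain`, the gradient handed to a rule can be a lazy `Broadcast.Broadcasted` (rules in this package return their output via `@lazy`, as the `return state, @lazy dx * λ` line below shows), and the generic `norm` fallback iterates such an object element by element, which can trigger scalar indexing on GPU arrays. The new internal `_norm` helper instead computes each norm from `sum`/`maximum`/`count` reductions, which broadcasting machinery can run on-device.

A minimal CPU sketch of the case this fixes; the lazy gradient here is built by hand, whereas in practice it would come from another rule's `@lazy` output:

```julia
using Optimisers, LinearAlgebra
using Base.Broadcast: broadcasted, instantiate

# A lazy, unmaterialised gradient, like one a chained rule would pass along:
bc = instantiate(broadcasted(*, Float32[1, 2, 3], 0.5f0))

# Generic `norm` would iterate `bc` scalar-by-scalar; `_norm` reduces it
# in one pass, so nothing needs to be collected first:
Optimisers._norm(bc, 2) ≈ norm(collect(bc), 2)  # true
```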

File tree

3 files changed: +41, -3 lines

Project.toml

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 name = "Optimisers"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
 authors = ["Mike J Innes <[email protected]>"]
-version = "0.2.18"
+version = "0.2.19"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rules.jl

Lines changed: 24 additions & 1 deletion

@@ -611,7 +611,7 @@ ClipNorm(ω = 10f0, p = 2; throw::Bool = true) = ClipNorm{float(typeof(ω))}(ω,
 init(o::ClipNorm, x::AbstractArray) = nothing
 
 function apply!(o::ClipNorm, state, x, dx)
-  nrm = norm(dx, o.p)
+  nrm = _norm(dx, o.p)
   if o.throw && !isfinite(nrm)
     throw(DomainError("gradient has $(o.p)-norm $nrm, for array $(summary(x))"))
   end
@@ -620,6 +620,29 @@ function apply!(o::ClipNorm, state, x, dx)
   return state, @lazy dx * λ
 end
 
+_norm(dx::AbstractArray, p::Real) = norm(dx, p) # LinearAlgebra, CUDA
+function _norm(dx::Broadcast.Broadcasted, p::Real)
+  if p == 2
+    # This lacks the underflow/overflow tests of LinearAlgebra's version
+    sqrt(sum(abs2, dx))
+  elseif p == 1
+    float(sum(abs, dx))
+  elseif p == Inf
+    float(maximum(abs, dx))
+  elseif p == 0
+    cnt = count(!iszero, dx)
+    T = Base.@default_eltype dx
+    T <: Number ? convert(float(T), cnt) : cnt
+  elseif p == -Inf
+    float(minimum(abs, dx))
+  else
+    # This isn't optimally fast but does ensure p::Float64 doesn't promote
+    tmp = abs.(dx)
+    q = convert(float(eltype(tmp)), p)
+    sum(tmp .^ q) ^ (1/q)
+  end
+end
+
 """
     OptimiserChain(opts...)
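One design point worth noting: the fallback branch converts `p` to the float eltype of the data (`q = convert(float(eltype(tmp)), p)`) before exponentiating, so passing a `Float64` power does not promote `Float32` gradients. A small sketch mirroring what the new testset checks:

```julia
using Optimisers
using Base.Broadcast: broadcasted, instantiate

# A lazy Float32 broadcast:
bc = instantiate(broadcasted(+, randn(Float32, 4), 0f0))

# `p = 1.5` is a Float64, but the conversion of `q` keeps both the
# computation and the result in Float32:
Optimisers._norm(bc, 1.5) isa Float32  # true
```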

test/runtests.jl

Lines changed: 16 additions & 1 deletion

@@ -2,6 +2,7 @@ using Optimisers
 using ChainRulesCore, Functors, StaticArrays, Zygote, Yota
 using LinearAlgebra, Statistics, Test, Random
 using Optimisers: @.., @lazy
+using Base.Broadcast: broadcasted, instantiate, Broadcasted
 
 Random.seed!(1)
 
@@ -89,7 +90,8 @@ y2z(x) = x
    _, m2 = Optimisers.update(s2, m, (α = ([0.1], nothing), γ = [1,10,100],))
    @test only(m.α[1] .- m2.α[1]) ≈ 0.1
    @test norm(m.γ .- m2.γ) ≈ 10
-    @test_throws DomainError Optimisers.update(s2, m, (α = [0.1], γ = [1,10,NaN],))
+    # This error is thrown by apply! due to NaN input.
+    @test_throws DomainError Optimisers.update(s2, m, (α = ([0.1], nothing), γ = [1,10,NaN],))
 
    s3 = Optimisers.setup(ClipNorm(5, 1; throw=false), m)
    _, m3 = Optimisers.update(s3, m, (α = ([0.1], nothing), γ = [1,10,100],))
@@ -506,6 +508,19 @@ y2z(x) = x
      y = Optimisers.subtract!(x, nothing)
      @test y === x
    end
+
+  @testset "_norm(dx, p) works" begin
+    bc = instantiate(broadcasted(+, randn(Float32, 10), randn(Float32, 10)'));
+    arr = collect(bc)
+    bc2 = instantiate(broadcasted(+, [1, 0, -3, 4], 0))
+    arr2 = collect(bc2)
+    for p in (-Inf, -3, -1, 0, 0.5, 1, 1.5, 2, 3f0, Inf32)
+      @test Optimisers._norm(bc, p) ≈ norm(arr, p)
+      @test Optimisers._norm(bc, p) isa Float32
+      @test Optimisers._norm(bc2, p) ≈ norm(arr2, p)
+      @test Optimisers._norm(bc2, p) isa Float64
+    end
+  end
 end
 @testset verbose=true "Destructure" begin
   include("destructure.jl")
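For context, a hedged end-to-end sketch of where a `Broadcasted` gradient actually reaches `ClipNorm`: when it follows another rule in an `OptimiserChain`. This assumes, as with `ClipNorm`'s own `@lazy dx * λ` above, that `WeightDecay`'s `apply!` also returns its result via `@lazy`, so the next rule in the chain receives a lazy broadcast rather than an array:

```julia
using Optimisers

# ClipNorm placed after another rule in a chain receives that rule's
# output as a lazy Broadcast.Broadcasted — the case `_norm` now handles.
rule  = OptimiserChain(WeightDecay(1f-4), ClipNorm(10f0))
model = (w = randn(Float32, 3),)
state = Optimisers.setup(rule, model)
state, model = Optimisers.update(state, model, (w = Float32[1, 2, 3],))
```

On GPU the same path matters more: the lazy gradient never has to be materialised or iterated element by element just to measure its norm.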
