Skip to content

Commit c3a02b5

Browse files
committed
add _norm
1 parent 8a37946 commit c3a02b5

File tree

3 files changed

+58
-2
lines changed

3 files changed

+58
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Optimisers"
22
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
33
authors = ["Mike J Innes <[email protected]>"]
4-
version = "0.2.18"
4+
version = "0.2.19"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rules.jl

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,7 @@ ClipNorm(ω = 10f0, p = 2; throw::Bool = true) = ClipNorm{float(typeof(ω))}(ω,
611611
init(o::ClipNorm, x::AbstractArray) = nothing
612612

613613
function apply!(o::ClipNorm, state, x, dx)
614-
nrm = norm(dx, o.p)
614+
nrm = _norm(dx, o.p)
615615
if o.throw && !isfinite(nrm)
616616
throw(DomainError("gradient has $(o.p)-norm $nrm, for array $(summary(x))"))
617617
end
@@ -620,6 +620,48 @@ function apply!(o::ClipNorm, state, x, dx)
620620
return state, @lazy dx * λ
621621
end
622622

623+
"""
    _norm(dx::AbstractArray, p::Real)

Return the `p`-norm of the array `dx` by forwarding to `norm(dx, p)`.
"""
function _norm(dx::AbstractArray, p::Real)
    # Materialised arrays go straight to the existing `norm` methods (LinearAlgebra, CUDA)
    return norm(dx, p)
end
624+
"""
    _norm(dx::Broadcast.Broadcasted, p::Real)

Compute the `p`-norm of a lazy broadcast `dx` using reductions over the
`Broadcasted` object, so that the common cases (`p` equal to 2, 1, `Inf`, 0, `-Inf`)
do not first materialise it into an array.

An empty broadcast has norm zero for every `p`. Other values of `p` fall back to
materialising `abs.(dx)` once.
"""
function _norm(dx::Broadcast.Broadcasted, p::Real)
    # Element type obtained from inference only; `dx` is not evaluated here.
    T = Base.@default_eltype dx
    if length(dx) == 0
        # `maximum`/`minimum` throw on empty input, so return zero up front.
        return T <: Number ? zero(real(float(T))) : 0.0
    elseif p == 2
        # This lacks the underflow/overflow tests of LinearAlgebra's version
        return sqrt(sum(abs2, dx))
    elseif p == 1
        return float(sum(abs, dx))
    elseif p == Inf
        return float(maximum(abs, dx))
    elseif p == 0
        cnt = count(!iszero, dx)
        return T <: Number ? convert(float(T), cnt) : cnt
    elseif p == -Inf
        return float(minimum(abs, dx))
    else
        # This isn't optimally fast but does ensure p::Float64 doesn't promote
        tmp = abs.(dx)
        q = convert(float(eltype(tmp)), p)
        # Reduce with a mapping function rather than allocating `tmp .^ q` as well
        return sum(x -> x^q, tmp) ^ (1/q)
    end
end
645+
646+
#=
647+
648+
julia> using Metal
649+
650+
julia> using Base.Broadcast: broadcasted, instantiate
651+
652+
julia> bc = instantiate(broadcasted(+, MtlArray(rand(Float32, 3)), 1));
653+
654+
julia> norm(bc)
655+
┌ Warning: Performing scalar indexing
656+
657+
└ @ Metal ~/.julia/packages/Metal/TtPHW/src/compiler/compilation.jl:77
658+
ERROR: NSError: Undefined symbols:
659+
llvm.maximum.f32, referenced from: _Z24partial_mapreduce_device8identity3max7Float323ValILi1024EES2_I22CartesianIndices__3___ES2_I22CartesianIndices__1___ES2_ILi1EES2_ILi1EES2_ILitrueEE14MtlDeviceArrayIS1_Li2ELi1EE11BroadcastedI13MtlArrayStyleILi1EE5TupleI5OneToI5Int64EE4normS6_IS4_IS5_ILi1EES6_IS7_IS8_EE1_S6_IS3_IS1_Li1ELi1EES8_EEEE
660+
661+
julia> Metal.allowscalar(false)
662+
663+
=#
664+
623665
"""
624666
OptimiserChain(opts...)
625667

test/runtests.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using Optimisers
22
using ChainRulesCore, Functors, StaticArrays, Zygote, Yota
33
using LinearAlgebra, Statistics, Test, Random
44
using Optimisers: @.., @lazy
5+
using Base.Broadcast: broadcasted, instantiate, Broadcasted
56

67
Random.seed!(1)
78

@@ -506,6 +507,19 @@ y2z(x) = x
506507
y = Optimisers.subtract!(x, nothing)
507508
@test y === x
508509
end
510+
511+
@testset "_norm(dx, p) works" begin
    # Each lazy broadcast is paired with its materialised array, so that
    # `Optimisers._norm` on the `Broadcasted` can be compared against `norm`.
    bc = instantiate(broadcasted(+, randn(Float32, 10), randn(Float32, 10)'))
    arr = collect(bc)
    # Integer-valued broadcast containing a zero
    bc2 = instantiate(broadcasted(+, [1, 0, -3, 4], 0))
    arr2 = collect(bc2)
    for p in (-Inf, -3, -1, 0, 0.5, 1, 1.5, 2, 3f0, Inf32)
        # Values must agree with `norm`, and the type of `p` must not leak
        # into the result type.
        @test Optimisers._norm(bc, p) ≈ norm(arr, p)
        @test Optimisers._norm(bc, p) isa Float32
        @test Optimisers._norm(bc2, p) ≈ norm(arr2, p)
        @test Optimisers._norm(bc2, p) isa Float64
    end
end
509523
end
510524
@testset verbose=true "Destructure" begin
511525
include("destructure.jl")

0 commit comments

Comments
 (0)