@@ -611,7 +611,7 @@ ClipNorm(ω = 10f0, p = 2; throw::Bool = true) = ClipNorm{float(typeof(ω))}(ω,
611
611
init (o:: ClipNorm , x:: AbstractArray ) = nothing
612
612
613
613
function apply! (o:: ClipNorm , state, x, dx)
614
- nrm = norm (dx, o. p)
614
+ nrm = _norm (dx, o. p)
615
615
if o. throw && ! isfinite (nrm)
616
616
throw (DomainError (" gradient has $(o. p) -norm $nrm , for array $(summary (x)) " ))
617
617
end
@@ -620,6 +620,48 @@ function apply!(o::ClipNorm, state, x, dx)
620
620
return state, @lazy dx * λ
621
621
end
622
622
623
+ _norm (dx:: AbstractArray , p:: Real ) = norm (dx, p) # LinearAlgebra, CUDA
624
+ function _norm (dx:: Broadcast.Broadcasted , p:: Real )
625
+ if p == 2
626
+ # This lacks the undeflow/overflow tests of LinearAlgebra's version
627
+ sqrt (sum (abs2, dx))
628
+ elseif p == 1
629
+ float (sum (abs, dx))
630
+ elseif p == Inf
631
+ float (maximum (abs, dx))
632
+ elseif p == 0
633
+ cnt = count (! iszero, dx)
634
+ T = Base. @default_eltype dx
635
+ T <: Number ? convert (float (T), cnt) : cnt
636
+ elseif p == - Inf
637
+ float (minimum (abs, dx))
638
+ else
639
+ # This isn't optimally fast but does ensure p::Float64 doesn't promote
640
+ tmp = abs .(dx)
641
+ q = convert (float (eltype (tmp)), p)
642
+ sum (tmp .^ q) ^ (1 / q)
643
+ end
644
+ end
645
+
646
+ #=
647
+
648
+ julia> using Metal
649
+
650
+ julia> using Base.Broadcast: broadcasted, instantiate
651
+
652
+ julia> bc = instantiate(broadcasted(+, MtlArray(rand(Float32, 3)), 1));
653
+
654
+ julia> norm(bc)
655
+ ┌ Warning: Performing scalar indexing
656
+
657
+ └ @ Metal ~/.julia/packages/Metal/TtPHW/src/compiler/compilation.jl:77
658
+ ERROR: NSError: Undefined symbols:
659
+ llvm.maximum.f32, referenced from: _Z24partial_mapreduce_device8identity3max7Float323ValILi1024EES2_I22CartesianIndices__3___ES2_I22CartesianIndices__1___ES2_ILi1EES2_ILi1EES2_ILitrueEE14MtlDeviceArrayIS1_Li2ELi1EE11BroadcastedI13MtlArrayStyleILi1EE5TupleI5OneToI5Int64EE4normS6_IS4_IS5_ILi1EES6_IS7_IS8_EE1_S6_IS3_IS1_Li1ELi1EES8_EEEE
660
+
661
+ julia> Metal.allowscalar(false)
662
+
663
+ =#
664
+
623
665
"""
624
666
OptimiserChain(opts...)
625
667
0 commit comments