@@ -455,6 +455,134 @@ function augmented_f(
455455 nothing
456456end
457457
458+ function augmented_f(
459+ u:: Any ,
460+ p:: Any ,
461+ :: Any ,
462+ icnf:: ICNF{T, <:LuxVecJacMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J} ,
463+ mode:: TrainMode ,
464+ nn:: LuxCore.AbstractLuxLayer ,
465+ st:: NamedTuple ,
466+ ϵ:: AbstractMatrix{T} ,
467+ ) where {T <: AbstractFloat , COND, AUGMENTED, STEER, NORM_Z, NORM_J}
468+ n_aug = n_augment(icnf, mode)
469+ snn = Lux. StatefulLuxLayer{true }(nn, p, st)
470+ z = u[begin : (end - n_aug - 1 ), :]
471+ ż = snn(z)
472+ ϵJ = Lux. vector_jacobian_product(snn, icnf. compute_mode. adback, z, ϵ)
473+ l̇ = - sum(ϵJ .* ϵ; dims = 1 )
474+ Ė = transpose(if NORM_Z
475+ LinearAlgebra. norm.(eachcol(ż))
476+ else
477+ zrs_Ė = similar(ż, size(ż, 2 ))
478+ ChainRulesCore. @ignore_derivatives fill!(zrs_Ė, zero(T))
479+ zrs_Ė
480+ end )
481+ ṅ = transpose(if NORM_J
482+ LinearAlgebra. norm.(eachcol(ϵJ))
483+ else
484+ zrs_ṅ = similar(ż, size(ż, 2 ))
485+ ChainRulesCore. @ignore_derivatives fill!(zrs_ṅ, zero(T))
486+ zrs_ṅ
487+ end )
488+ vcat(ż, l̇, Ė, ṅ)
489+ end
490+
491+ function augmented_f(
492+ du:: Any ,
493+ u:: Any ,
494+ p:: Any ,
495+ :: Any ,
496+ icnf:: ICNF{T, <:LuxVecJacMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J} ,
497+ mode:: TrainMode ,
498+ nn:: LuxCore.AbstractLuxLayer ,
499+ st:: NamedTuple ,
500+ ϵ:: AbstractMatrix{T} ,
501+ ) where {T <: AbstractFloat , COND, AUGMENTED, STEER, NORM_Z, NORM_J}
502+ n_aug = n_augment(icnf, mode)
503+ snn = Lux. StatefulLuxLayer{true }(nn, p, st)
504+ z = u[begin : (end - n_aug - 1 ), :]
505+ ż = snn(z)
506+ ϵJ = Lux. vector_jacobian_product(snn, icnf. compute_mode. adback, z, ϵ)
507+ du[begin : (end - n_aug - 1 ), :] .= ż
508+ du[(end - n_aug), :] .= - vec(sum(ϵJ .* ϵ; dims = 1 ))
509+ du[(end - n_aug + 1 ), :] .= if NORM_Z
510+ LinearAlgebra. norm.(eachcol(ż))
511+ else
512+ zero(T)
513+ end
514+ du[(end - n_aug + 2 ), :] .= if NORM_J
515+ LinearAlgebra. norm.(eachcol(ϵJ))
516+ else
517+ zero(T)
518+ end
519+ nothing
520+ end
521+
522+ function augmented_f(
523+ u:: Any ,
524+ p:: Any ,
525+ :: Any ,
526+ icnf:: ICNF{T, <:LuxJacVecMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J} ,
527+ mode:: TrainMode ,
528+ nn:: LuxCore.AbstractLuxLayer ,
529+ st:: NamedTuple ,
530+ ϵ:: AbstractMatrix{T} ,
531+ ) where {T <: AbstractFloat , COND, AUGMENTED, STEER, NORM_Z, NORM_J}
532+ n_aug = n_augment(icnf, mode)
533+ snn = Lux. StatefulLuxLayer{true }(nn, p, st)
534+ z = u[begin : (end - n_aug - 1 ), :]
535+ ż = snn(z)
536+ Jϵ = Lux. jacobian_vector_product(snn, icnf. compute_mode. adback, z, ϵ)
537+ l̇ = - sum(ϵ .* Jϵ; dims = 1 )
538+ Ė = transpose(if NORM_Z
539+ LinearAlgebra. norm.(eachcol(ż))
540+ else
541+ zrs_Ė = similar(ż, size(ż, 2 ))
542+ ChainRulesCore. @ignore_derivatives fill!(zrs_Ė, zero(T))
543+ zrs_Ė
544+ end )
545+ ṅ = transpose(if NORM_J
546+ LinearAlgebra. norm.(eachcol(Jϵ))
547+ else
548+ zrs_ṅ = similar(ż, size(ż, 2 ))
549+ ChainRulesCore. @ignore_derivatives fill!(zrs_ṅ, zero(T))
550+ zrs_ṅ
551+ end )
552+ vcat(ż, l̇, Ė, ṅ)
553+ end
554+
555+ function augmented_f(
556+ du:: Any ,
557+ u:: Any ,
558+ p:: Any ,
559+ :: Any ,
560+ icnf:: ICNF{T, <:LuxJacVecMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J} ,
561+ mode:: TrainMode ,
562+ nn:: LuxCore.AbstractLuxLayer ,
563+ st:: NamedTuple ,
564+ ϵ:: AbstractMatrix{T} ,
565+ ) where {T <: AbstractFloat , COND, AUGMENTED, STEER, NORM_Z, NORM_J}
566+ n_aug = n_augment(icnf, mode)
567+ snn = Lux. StatefulLuxLayer{true }(nn, p, st)
568+ z = u[begin : (end - n_aug - 1 ), :]
569+ ż = snn(z)
570+ Jϵ = Lux. jacobian_vector_product(snn, icnf. compute_mode. adback, z, ϵ)
571+ du[begin : (end - n_aug - 1 ), :] .= ż
572+ du[(end - n_aug), :] .= - vec(sum(ϵ .* Jϵ; dims = 1 ))
573+ du[(end - n_aug + 1 ), :] .= if NORM_Z
574+ LinearAlgebra. norm.(eachcol(ż))
575+ else
576+ zero(T)
577+ end
578+ du[(end - n_aug + 2 ), :] .= if NORM_J
579+ LinearAlgebra. norm.(eachcol(Jϵ))
580+ else
581+ zero(T)
582+ end
583+ nothing
584+ end
585+
458586@inline function loss(
459587 icnf:: ICNF{<:AbstractFloat, <:VectorMode} ,
460588 mode:: TrainMode ,
0 commit comments