Skip to content

Commit 47071c2

Browse files
committed
fix levin for AD
1 parent 077e0ef commit 47071c2

File tree

3 files changed

+57
-48
lines changed

3 files changed

+57
-48
lines changed

src/AiryFunctions/cairy.jl

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -398,14 +398,16 @@ end
398398
a = @fastmath inv(4*xsqrx)
399399
a2 = 4 * xsqrx
400400

401-
s = zero(typeof(x))
402-
l = @ntuple $N i -> begin
403-
s += t
401+
s_0 = zero(typeof(x))
402+
@nexprs $N i -> begin
403+
s_{i} = s_{i-1} + t
404404
t *= -a * (3 * (i - 5//6) * (i - 1//6) / i)
405405
t2 *= -a2 * (i / (3 * (i - 5//6) * (i - 1//6)))
406-
Vec{4, T}((reim(s * t2)..., reim(t2)...))
406+
w_{i} = t2
407407
end
408-
return levin_transform(l) / (T(π)^(3//2) * sqrt(xsqr))
408+
sequence = @ntuple $N i -> s_{i}
409+
weights = @ntuple $N i -> w_{i}
410+
return levin_transform(sequence, weights) / (T(π)^(3//2) * sqrt(xsqr))
409411
end
410412
)
411413
end
@@ -420,14 +422,16 @@ end
420422
a = @fastmath inv(4*xsqrx)
421423
a2 = 4 * xsqrx
422424

423-
s = zero(typeof(x))
424-
l = @ntuple $N i -> begin
425-
s += t
425+
s_0 = zero(typeof(x))
426+
@nexprs $N i -> begin
427+
s_{i} = s_{i-1} + t
426428
t *= -a * (3 * (i - 7//6) * (i + 1//6) / i)
427429
t2 *= -a2 * (i / (3 * (i - 7//6) * (i + 1//6)))
428-
Vec{4, T}((reim(s * t2)..., reim(t2)...))
430+
w_{i} = t2
429431
end
430-
return levin_transform(l) * sqrt(xsqr) / T(π)^(3//2)
432+
sequence = @ntuple $N i -> s_{i}
433+
weights = @ntuple $N i -> w_{i}
434+
return levin_transform(sequence, weights) * sqrt(xsqr) / T(π)^(3//2)
431435
end
432436
)
433437
end
@@ -444,24 +448,26 @@ end
444448
a = inv(4*xsqrx)
445449
a2 = 4 * xsqrx
446450

447-
s = zero(typeof(x))
448-
s2 = zero(typeof(x))
451+
s_0 = zero(typeof(x))
452+
p_0 = zero(typeof(x))
449453
m = 1
450454
@nexprs $N i -> begin
451-
s += t
452-
s2 += t * m
455+
s_{i} = s_{i-1} + t
456+
p_{i} = p_{i-1} + t * m
453457
t *= -a * (3 * (i - 5//6) * (i - 1//6) / i)
454458
t2 *= -a2 * (i / (3 * (i - 5//6) * (i - 1//6)))
455-
l_{i} = Vec{4, T}((reim(s * t2)..., reim(t2)...))
456-
w_{i} = Vec{4, T}((reim(s2 * t2 * m)..., reim(t2 * m)...))
459+
w_{i} = t2
457460
m *= -1
458461
end
459462

460-
l = @ntuple $N i -> l_{i}
461-
w = @ntuple $N i -> w_{i}
463+
sequence1 = @ntuple $N i -> s_{i}
464+
sequence2 = @ntuple $N i -> p_{i}
465+
weights = @ntuple $N i -> w_{i}
462466

463-
e = exp(-2/3 * x * sqrt(x))
464-
return (e*im*levin_transform(l) + 2*levin_transform(w)/e) / (sqrt(T(π)^3) * sqrt(xsqr))
467+
e = exp(-T(2/3) * x * sqrt(x))
468+
l1 = levin_transform(sequence1, weights)
469+
l2 = levin_transform(sequence2, (@ntuple $N i -> weights[i] * (iseven(i) ? 1 : -1)))
470+
return (e * im * l1 + 2 * l2 / e) / (sqrt(T(π)^3) * sqrt(xsqr))
465471
end
466472
)
467473
end
@@ -477,24 +483,26 @@ end
477483
a = inv(4*xsqrx)
478484
a2 = 4 * xsqrx
479485

480-
s = zero(typeof(x))
481-
s2 = zero(typeof(x))
486+
s_0 = zero(typeof(x))
487+
p_0 = zero(typeof(x))
482488
m = 1
483489
@nexprs $N i -> begin
484-
s += t
485-
s2 += t * m
490+
s_{i} = s_{i-1} + t
491+
p_{i} = p_{i-1} + t * m
486492
t *= -a * (3 * (i - 7//6) * (i + 1//6) / i)
487493
t2 *= -a2 * (i / (3 * (i - 7//6) * (i + 1//6)))
488-
l_{i} = Vec{4, T}((reim(s * t2)..., reim(t2)...))
489-
w_{i} = Vec{4, T}((reim(s2 * t2 * m)..., reim(t2 * m)...))
494+
w_{i} = t2
490495
m *= -1
491496
end
492497

493-
l = @ntuple $N i -> l_{i}
494-
w = @ntuple $N i -> w_{i}
498+
sequence1 = @ntuple $N i -> s_{i}
499+
sequence2 = @ntuple $N i -> p_{i}
500+
weights = @ntuple $N i -> w_{i}
495501

496-
e = exp(-2/3 * x * sqrt(x))
497-
return -(e*im*levin_transform(l) - 2*levin_transform(w)/e) * sqrt(xsqr) / (sqrt(T(π)^3))
502+
l1 = levin_transform(sequence1, weights)
503+
l2 = levin_transform(sequence2, (@ntuple $N i -> weights[i] * (iseven(i) ? 1 : -1)))
504+
e = exp(-T(2/3) * x * sqrt(x))
505+
return -(e * im * l1 - 2 * l2 / e) * sqrt(xsqr) / (sqrt(T(π)^3))
498506
end
499507
)
500508
end

src/BesselFunctions/besselk.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -578,36 +578,39 @@ end
578578
@generated function besselkx_levin(v, x::T, ::Val{N}) where {T <: FloatTypes, N}
579579
:(
580580
begin
581-
s = zero(T)
581+
s_0 = zero(T)
582582
t = one(T)
583-
l = @ntuple $N i -> begin
584-
s += t
583+
@nexprs $N i -> begin
584+
s_{i} = s_{i-1} + t
585585
t *= (4*v^2 - (2i - 1)^2) / (8 * x * i)
586-
invterm = inv(t)
587-
Vec{2, T}((s * invterm, invterm))
586+
w_{i} = 1 / t
588587
end
589-
return levin_transform(l) * sqrt/ 2x)
588+
sequence = @ntuple $N i -> s_{i}
589+
weights = @ntuple $N i -> w_{i}
590+
return levin_transform(sequence, weights) * sqrt/ 2x)
590591
end
591592
)
592593
end
593594

594595
@generated function besselkx_levin(v, x::Complex{T}, ::Val{N}) where {T <: FloatTypes, N}
595596
:(
596597
begin
597-
s = zero(T)
598+
s_0 = zero(T)
598599
t = one(typeof(x))
599600
t2 = t
600601
a = @fastmath inv(8*x)
601602
a2 = 8*x
602603

603-
l = @ntuple $N i -> begin
604-
s += t
604+
@nexprs $N i -> begin
605+
s_{i} = s_{i-1} + t
605606
b = (4*v^2 - (2i - 1)^2) / i
606607
t *= a * b
607608
t2 *= a2 / b
608-
Vec{4, T}((reim(s * t2)..., reim(t2)...))
609+
w_{i} = t2
609610
end
610-
return levin_transform(l) * sqrt/ 2x)
611+
sequence = @ntuple $N i -> s_{i}
612+
weights = @ntuple $N i -> w_{i}
613+
return levin_transform(sequence, weights) * sqrt/ 2x)
611614
end
612615
)
613616
end

src/Math/Math.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -131,28 +131,26 @@ end
131131
#@inline levin_scale(B::T, n, k) where T = -(B + n) * (B + n + k)^(k - one(T)) / (B + n + k + one(T))^k
132132
@inline levin_scale(B::T, n, k) where T = -(B + n + k) * (B + n + k - 1) / ((B + n + 2k) * (B + n + 2k - 1))
133133

134-
# implementation for real numbers
135-
@inline @generated function levin_transform(s::NTuple{N, Vec{2, T}}) where {N, T <: FloatTypes}
134+
@inline @generated function levin_transform(s::NTuple{N, T}, w::NTuple{N, T}) where {N, T <: FloatTypes}
136135
len = N - 1
137136
:(
138137
begin
139-
@nexprs $N i -> a_{i} = s[i]
138+
@nexprs $N i -> a_{i} = Vec{2, T}((s[i] * w[i], w[i]))
140139
@nexprs $len k -> (@nexprs ($len-k) i -> a_{i} = fmadd(a_{i}, levin_scale(one(T), i, k-1), a_{i+1}))
141140
return (a_1[1] / a_1[2])
142141
end
143142
)
144143
end
145144

146-
# implementation for complex numbers
147-
@inline @generated function levin_transform(s::NTuple{N, Vec{4, T}}) where {N, T <: FloatTypes}
145+
@inline @generated function levin_transform(s::NTuple{N, Complex{T}}, w::NTuple{N, Complex{T}}) where {N, T <: FloatTypes}
148146
len = N - 1
149147
:(
150148
begin
151-
@nexprs $N i -> a_{i} = s[i]
149+
@nexprs $N i -> a_{i} = Vec{4, T}((reim(s[i] * w[i])..., reim(w[i])...))
152150
@nexprs $len k -> (@nexprs ($len-k) i -> a_{i} = fmadd(a_{i}, levin_scale(one(T), i, k-1), a_{i+1}))
153151
return (complex(a_1[1], a_1[2]) / complex(a_1[3], a_1[4]))
154152
end
155153
)
156154
end
157155

158-
end
156+
end

0 commit comments

Comments
 (0)