Skip to content

Commit 73e8fd1

Browse files
test: add preliminary JET tests
1 parent 99c6ee4 commit 73e8fd1

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ julia = "1.10"
100100
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
101101
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
102102
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
103+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
103104
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
104105
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
105106
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -114,4 +115,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
114115
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
115116

116117
[targets]
117-
test = ["Aqua", "ForwardDiff", "MLStyle", "PartialFunctions", "Pkg", "SafeTestsets", "Serialization", "StableRNGs", "StaticArrays", "Tables", "Test", "UnicodePlots", "Zygote"]
118+
test = ["Aqua", "ForwardDiff", "MLStyle", "PartialFunctions", "Pkg", "SafeTestsets", "Serialization", "StableRNGs", "StaticArrays", "Tables", "Test", "UnicodePlots", "Zygote", "JET"]

test/JET.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using NonlinearSolve
2+
using LinearAlgebra
3+
using ADTypes
4+
using JET
5+
6+
function f(u, p)
7+
L, U = cholesky(p.Σ)
8+
return L \ (u .* u .- p.λ)
9+
end
10+
11+
function minimize=1.0)
12+
ps = (; λ, Σ=hermitianpart(rand(2,2) + 2*I))
13+
u₀ = rand(2)
14+
prob = NonlinearLeastSquaresProblem{false}(f, u₀, ps)
15+
autodiff = AutoForwardDiff(; chunksize=1)
16+
sol = solve(prob, SimpleTrustRegion(; autodiff))
17+
return sol.u
18+
end
19+
20+
@test_opt minimize()

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ end
6060
@time @safetestset "Clocks" begin
6161
include("clock.jl")
6262
end
63+
@time @safetestset "JET" begin
64+
include("JET.jl")
65+
end
6366
end
6467

6568
if !is_APPVEYOR &&

0 commit comments

Comments
 (0)