Skip to content

Commit 5d2edc1

Browse files
authored
Avoid non-lazy string interpolation and replace @assert with ArgumentError (#779)
* Avoid non-lazy string interpolation and replace `@assert` with `ArgumentError` * Add tests * Fix tests
1 parent 463e830 commit 5d2edc1

File tree

10 files changed

+49
-33
lines changed

10 files changed

+49
-33
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ jobs:
2626
- 'min'
2727
- 'lts'
2828
- '1'
29+
- 'pre'
2930
os:
3031
- ubuntu-latest
3132
- windows-latest

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ DiffResults = "1.1"
2727
DiffRules = "1.4"
2828
DiffTests = "0.1"
2929
IrrationalConstants = "0.1, 0.2"
30+
JET = "0.9, 0.10"
3031
LogExpFunctions = "0.3"
3132
NaNMath = "1"
3233
Preferences = "1"
@@ -39,9 +40,10 @@ Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
3940
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
4041
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
4142
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
43+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
4244
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
4345
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
4446
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4547

4648
[targets]
47-
test = ["Calculus", "DiffTests", "IrrationalConstants", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils"]
49+
test = ["Calculus", "DiffTests", "IrrationalConstants", "JET", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils"]

src/config.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ struct InvalidTagException{E,O} <: Exception
2929
end
3030

3131
Base.showerror(io::IO, e::InvalidTagException{E,O}) where {E,O} =
32-
print(io, "Invalid Tag object:\n Expected $E,\n Observed $O.")
32+
print(io, "Invalid Tag object:\n Expected ", E, ",\n Observed ", O, ".")
3333

3434
checktag(::Type{Tag{FT,VT}}, f::F, x::AbstractArray{V}) where {FT,VT,F,V} =
3535
throw(InvalidTagException{Tag{F,V},Tag{FT,VT}}())

src/dual.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,10 @@ struct DualMismatchError{A,B} <: Exception
3535
end
3636

3737
Base.showerror(io::IO, e::DualMismatchError{A,B}) where {A,B} =
38-
print(io, "Cannot determine ordering of Dual tags $(e.a) and $(e.b)")
38+
print(io, "Cannot determine ordering of Dual tags ", e.a, " and ", e.b)
3939

4040
@noinline function throw_cannot_dual(V::Type)
41-
throw(ArgumentError("Cannot create a dual over scalar type $V." *
42-
" If the type behaves as a scalar, define ForwardDiff.can_dual(::Type{$V}) = true."))
41+
throw(ArgumentError(lazy"Cannot create a dual over scalar type $V. If the type behaves as a scalar, define ForwardDiff.can_dual(::Type{$V}) = true."))
4342
end
4443

4544
"""

src/gradient.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,9 @@ end
113113

114114
function chunk_mode_gradient_expr(result_definition::Expr)
115115
return quote
116-
@assert structural_length(x) >= N "chunk size cannot be greater than ForwardDiff.structural_length(x) ($(N) > $(structural_length(x)))"
116+
if structural_length(x) < N
117+
throw(ArgumentError(lazy"chunk size cannot be greater than ForwardDiff.structural_length(x) ($(N) > $(structural_length(x)))"))
118+
end
117119

118120
# precalculate loop bounds
119121
xlen = structural_length(x)

src/jacobian.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,9 @@ const JACOBIAN_ERROR = DimensionMismatch("jacobian(f, x) expects that f(x) is an
169169
function jacobian_chunk_mode_expr(work_array_definition::Expr, compute_ydual::Expr,
170170
result_definition::Expr, y_definition::Expr)
171171
return quote
172-
@assert structural_length(x) >= N "chunk size cannot be greater than ForwardDiff.structural_length(x) ($(N) > $(structural_length(x)))"
172+
if structural_length(x) < N
173+
throw(ArgumentError(lazy"chunk size cannot be greater than ForwardDiff.structural_length(x) ($(N) > $(structural_length(x)))"))
174+
end
173175

174176
# precalculate loop bounds
175177
xlen = structural_length(x)

test/AllocationsTest.jl

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,28 +13,19 @@ convert_test_574() = convert(ForwardDiff.Dual{Nothing,ForwardDiff.Dual{Nothing,F
1313
seeds = cfg.seeds
1414
seed = cfg.seeds[1]
1515

16-
alloc = @allocated ForwardDiff.seed!(duals, x, seeds)
17-
alloc = @allocated ForwardDiff.seed!(duals, x, seeds)
18-
@test alloc == 0
19-
20-
alloc = @allocated ForwardDiff.seed!(duals, x, seed)
21-
alloc = @allocated ForwardDiff.seed!(duals, x, seed)
22-
@test alloc == 0
23-
24-
index = 1
25-
alloc = @allocated ForwardDiff.seed!(duals, x, index, seeds)
26-
alloc = @allocated ForwardDiff.seed!(duals, x, index, seeds)
27-
@test alloc == 0
28-
29-
index = 1
30-
alloc = @allocated ForwardDiff.seed!(duals, x, index, seed)
31-
alloc = @allocated ForwardDiff.seed!(duals, x, index, seed)
32-
@test alloc == 0
33-
34-
alloc = @allocated convert_test_574()
35-
alloc = @allocated convert_test_574()
36-
@test alloc == 0
37-
16+
allocs_seed!(args...) = @allocated ForwardDiff.seed!(args...)
17+
allocs_seed!(duals, x, seeds)
18+
@test iszero(allocs_seed!(duals, x, seeds))
19+
allocs_seed!(duals, x, seed)
20+
@test iszero(allocs_seed!(duals, x, seed))
21+
allocs_seed!(duals, x, 1, seeds)
22+
@test iszero(allocs_seed!(duals, x, 1, seeds))
23+
allocs_seed!(duals, x, 1, seed)
24+
@test iszero(allocs_seed!(duals, x, 1, seed))
25+
26+
allocs_convert_test_574() = @allocated convert_test_574()
27+
allocs_convert_test_574()
28+
@test iszero(allocs_convert_test_574())
3829
end
3930

4031
end

test/DualTest.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -507,8 +507,7 @@ ForwardDiff.:≺(::Type{OuterTestTag}, ::Type{TestTag}) = false
507507
@eval begin
508508
x = rand() + $modifier
509509
dx = @inferred $M.$f(Dual{TestTag}(x, one(x)))
510-
actualval = $M.$f(x)
511-
@assert actualval isa Real || actualval isa Complex
510+
actualval = $M.$f(x)::Union{Real,Complex}
512511
if actualval isa Real
513512
@test dx isa Dual{TestTag}
514513
@test value(dx) == actualval
@@ -536,8 +535,7 @@ ForwardDiff.:≺(::Type{OuterTestTag}, ::Type{TestTag}) = false
536535
dy = @inferred $M.$f(x, Dual{TestTag}(y, one(y)))
537536
actualdx = $(derivs[1])
538537
actualdy = $(derivs[2])
539-
actualval = $M.$f(x, y)
540-
@assert actualval isa Real || actualval isa Complex
538+
actualval = $M.$f(x, y)::Union{Real,Complex}
541539
if actualval isa Real
542540
@test dx isa Dual{TestTag}
543541
@test dy isa Dual{TestTag}

test/QATest.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
module QATest
2+
3+
using ForwardDiff
4+
using Test
5+
6+
using JET: @test_opt
7+
8+
@testset "JET" begin
9+
# issue #778
10+
@test_opt ForwardDiff.derivative(identity, 1.0)
11+
@test_opt ForwardDiff.gradient(only, [1.0], ForwardDiff.GradientConfig(only, [1.0], ForwardDiff.Chunk{1}()))
12+
@test_opt ForwardDiff.jacobian(identity, [1.0], ForwardDiff.JacobianConfig(identity, [1.0], ForwardDiff.Chunk{1}()))
13+
@test_opt ForwardDiff.hessian(only, [1.0], ForwardDiff.HessianConfig(only, [1.0], ForwardDiff.Chunk{1}()))
14+
end
15+
16+
end # module

test/runtests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,10 @@ Random.seed!(SEED)
5151
t = @elapsed include("AllocationsTest.jl")
5252
println("##### done (took $t seconds).")
5353
end
54+
@testset "QA" begin
55+
println("##### QA testing...")
56+
t = @elapsed include("QATest.jl")
57+
println("##### done (took ", t, " seconds).")
58+
end
5459
println("##### Running all ForwardDiff tests took $(time() - t0) seconds.")
5560
end

0 commit comments

Comments
 (0)