Skip to content

Commit d329806

Browse files
feat: support array symbolics in BatchedInterface
1 parent 0e8e037 commit d329806

File tree

5 files changed

+101
-28
lines changed

5 files changed

+101
-28
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Accessors = "0.1.36"
1515
Aqua = "0.8"
1616
ArrayInterface = "7.9"
1717
MacroTools = "0.5.13"
18+
Pkg = "1"
1819
RuntimeGeneratedFunctions = "0.5"
1920
SafeTestsets = "0.0.1"
2021
StaticArrays = "1.9"
@@ -24,9 +25,10 @@ julia = "1.10"
2425

2526
[extras]
2627
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
28+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2729
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
2830
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2931
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3032

3133
[targets]
32-
test = ["Aqua", "Test", "SafeTestsets", "StaticArrays"]
34+
test = ["Aqua", "Pkg", "Test", "SafeTestsets", "StaticArrays"]

src/batched_interface.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,18 @@ function BatchedInterface(syssyms::Tuple...)
5050
symbol_subset = Int[]
5151
symbol_indexes = []
5252
system_isstate = BitVector()
53+
allsyms = []
5354
for sym in syms
5455
if symbolic_type(sym) === NotSymbolic()
5556
error("Only symbolic variables allowed in BatchedInterface.")
5657
end
58+
if symbolic_type(sym) === ArraySymbolic()
59+
append!(allsyms, collect(sym))
60+
else
61+
push!(allsyms, sym)
62+
end
63+
end
64+
for sym in allsyms
5765
if !is_variable(sys, sym) && !is_parameter(sys, sym)
5866
error("Only variables and parameters allowed in BatchedInterface.")
5967
end

test/downstream/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[deps]
2+
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
3+
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using Symbolics
2+
using SymbolicIndexingInterface
3+
4+
@variables x[1:2] y z
5+
6+
syss = [
7+
SymbolCache([x..., y]),
8+
SymbolCache([x[1], y, z])
9+
]
10+
syms = [
11+
[x, y],
12+
[x[1], y]
13+
]
14+
probs = [
15+
ProblemState(; u = [1.0, 2.0, 3.0]),
16+
ProblemState(; u = [4.0, 5.0, 6.0])
17+
]
18+
19+
bi = BatchedInterface(zip(syss, syms)...)
20+
21+
@test all(isequal.(variable_symbols(bi), [x..., y]))
22+
@test variable_index.((bi,), [x..., y, z]) == [1, 2, 3, nothing]
23+
@test is_variable.((bi,), [x..., y, z]) == Bool[1, 1, 1, 0]
24+
@test associated_systems(bi) == [1, 1, 1]
25+
26+
getter = getu(bi)
27+
@test (@inferred getter(probs...)) == [1.0, 2.0, 3.0]
28+
buf = zeros(3)
29+
@inferred getter(buf, probs...)
30+
@test buf == [1.0, 2.0, 3.0]
31+
32+
setter! = setu(bi)
33+
buf .*= 10
34+
setter!(probs..., buf)
35+
36+
@test state_values(probs[1]) == [10.0, 20.0, 30.0]
37+
@test state_values(probs[2]) == [10.0, 30.0, 6.0]
38+
39+
buf ./= 10
40+
41+
setter!(probs[1], 1, buf)
42+
@test state_values(probs[1]) == [1.0, 2.0, 3.0]

test/runtests.jl

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,52 @@
11
using SymbolicIndexingInterface
22
using SafeTestsets
33
using Test
4+
using Pkg
45

5-
@safetestset "Quality Assurance" begin
6-
@time include("qa.jl")
7-
end
8-
@safetestset "Interface test" begin
9-
@time include("example_test.jl")
10-
end
11-
@safetestset "Trait test" begin
12-
@time include("trait_test.jl")
13-
end
14-
@safetestset "SymbolCache test" begin
15-
@time include("symbol_cache_test.jl")
16-
end
17-
@safetestset "Fallback test" begin
18-
@time include("fallback_test.jl")
19-
end
20-
@safetestset "Parameter indexing test" begin
21-
@time include("parameter_indexing_test.jl")
22-
end
23-
@safetestset "State indexing test" begin
24-
@time include("state_indexing_test.jl")
25-
end
26-
@safetestset "Remake test" begin
27-
@time include("remake_test.jl")
6+
const GROUP = get(ENV, "GROUP", "All")
7+
8+
function activate_downstream_env()
9+
Pkg.activate("downstream")
10+
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
11+
Pkg.instantiate()
2812
end
29-
@safetestset "ProblemState test" begin
30-
@time include("problem_state_test.jl")
13+
14+
if GROUP == "All" || GROUP == "Core"
15+
@safetestset "Quality Assurance" begin
16+
@time include("qa.jl")
17+
end
18+
@safetestset "Interface test" begin
19+
@time include("example_test.jl")
20+
end
21+
@safetestset "Trait test" begin
22+
@time include("trait_test.jl")
23+
end
24+
@safetestset "SymbolCache test" begin
25+
@time include("symbol_cache_test.jl")
26+
end
27+
@safetestset "Fallback test" begin
28+
@time include("fallback_test.jl")
29+
end
30+
@safetestset "Parameter indexing test" begin
31+
@time include("parameter_indexing_test.jl")
32+
end
33+
@safetestset "State indexing test" begin
34+
@time include("state_indexing_test.jl")
35+
end
36+
@safetestset "Remake test" begin
37+
@time include("remake_test.jl")
38+
end
39+
@safetestset "ProblemState test" begin
40+
@time include("problem_state_test.jl")
41+
end
42+
@safetestset "BatchedInterface test" begin
43+
@time include("batched_interface_test.jl")
44+
end
3145
end
32-
@safetestset "BatchedInterface test" begin
33-
@time include("batched_interface_test.jl")
46+
47+
if GROUP == "All" || GROUP == "Downstream"
48+
activate_downstream_env()
49+
@safetestset "BatchedInterface with array symbolics test" begin
50+
@time include("downstream/batchedinterface_arrayvars.jl")
51+
end
3452
end

0 commit comments

Comments
 (0)