diff --git a/src/utils.jl b/src/utils.jl index 8c34c316..52fb856f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -3,15 +3,29 @@ unrolled_foreach!(f, ::Tuple{}) = nothing """ ```julia -recursivecopy(a::Union{AbstractArray{T, N}, AbstractVectorOfArray{T, N}}) +recursivecopy(a) ``` A recursive `copy` function. Acts like a `deepcopy` on arrays of arrays, but -like `copy` on arrays of scalars. -""" -function recursivecopy(a) - deepcopy(a) +like `copy` on arrays of scalars. For struct types, recursively copies each +field, creating new instances while preserving the struct type. + +## Examples + +```julia +# Basic array copying +arr = [[1, 2], [3, 4]] +copied = recursivecopy(arr) # New arrays at each level + +# Struct copying +struct MyStruct + data::Vector{Float64} + metadata::String end +original = MyStruct([1.0, 2.0], "test") +copied = recursivecopy(original) # New struct with new vector +``` +""" function recursivecopy(a::Union{StaticArraysCore.SVector, StaticArraysCore.SMatrix, StaticArraysCore.SArray, Number}) copy(a) @@ -35,6 +49,42 @@ function recursivecopy(a::AbstractVectorOfArray) return b end +function _is_basic_julia_type(T) + # Check if this is a built-in Julia type that we should not handle as a user struct + # We check the module to identify Core/Base types vs user-defined types + mod = Base.parentmodule(T) + return T <: AbstractString || T <: Number || T <: Symbol || T <: Tuple || + T <: UnitRange || T <: StepRange || T <: Regex || T <: Function || + T === Nothing || T === Missing || + mod === Core || mod === Base +end + +function recursivecopy(s::T) where {T} + # Only handle user-defined immutable structs. Many basic Julia types (String, Symbol, + # Tuple, etc.) are technically structs but should use copy() or return as-is. + if Base.isstructtype(T) && !_is_basic_julia_type(T) + if Base.ismutabletype(T) + error("recursivecopy for mutable structs is not currently implemented. Use deepcopy instead.") + else + # Handle immutable structs only + field_values = ntuple(fieldcount(T)) do i + field_value = getfield(s, i) + recursivecopy(field_value) + end + return T(field_values...) + end + elseif _is_basic_julia_type(T) + # For basic Julia types, use copy if available, otherwise return as-is (for immutable types) + if hasmethod(copy, Tuple{T}) + return copy(s) + else + return s # Immutable basic types like Symbol, Nothing, Missing don't need copying + end + else + deepcopy(s) + end +end + """ ```julia recursivecopy!(b::AbstractArray{T, N}, a::AbstractArray{T, N}) diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 406fd3f4..da179145 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -2,6 +2,7 @@ ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" diff --git a/test/downstream/ode_solution_copy_test.jl b/test/downstream/ode_solution_copy_test.jl new file mode 100644 index 00000000..f988076b --- /dev/null +++ b/test/downstream/ode_solution_copy_test.jl @@ -0,0 +1,112 @@ +using OrdinaryDiffEq, RecursiveArrayTools, Test + +@testset "ODE Solution recursivecopy tests" begin + + @testset "Basic ODE solution copying" begin + # Define a simple ODE system + function simple_ode!(du, u, p, t) + du[1] = -0.5 * u[1] + du[2] = 0.3 * u[2] + end + + u0 = [1.0, 2.0] + tspan = (0.0, 2.0) + prob = ODEProblem(simple_ode!, u0, tspan) + sol = solve(prob, Tsit5(), saveat=0.5) + + # Test that we can copy the solution + copied_sol = recursivecopy(sol) + + # Verify the solution structure is preserved + @test typeof(copied_sol) == typeof(sol) + @test copied_sol.t == sol.t + @test copied_sol.u == sol.u + @test copied_sol.retcode == sol.retcode + + # Verify that arrays are independent copies + @test copied_sol !== sol + @test copied_sol.u !== sol.u + @test copied_sol.t !== sol.t + + # Test that modifying one doesn't affect the other + if length(copied_sol.u) > 0 + original_value = sol.u[1][1] + copied_sol.u[1][1] = 999.0 + @test sol.u[1][1] == original_value # Original should be unchanged + end + end + + @testset "ArrayPartition ODE solution copying" begin + # Use the Lorenz system from the existing tests + function lorenz!(du, u, p, t) + du.x[1][1] = 10.0 * (u.x[2][1] - u.x[1][1]) + du.x[1][2] = u.x[1][1] * (28.0 - u.x[2][1]) - u.x[1][2] + du.x[2][1] = u.x[1][1] * u.x[1][2] - (8/3) * u.x[2][1] + end + + u0 = ArrayPartition([1.0, 0.0], [0.0]) + tspan = (0.0, 1.0) + prob = ODEProblem(lorenz!, u0, tspan) + sol = solve(prob, Tsit5(), saveat=0.2) + + # Test that we can copy the ArrayPartition-based solution + copied_sol = recursivecopy(sol) + + # Verify solution structure + @test typeof(copied_sol) == typeof(sol) + @test copied_sol.t == sol.t + @test length(copied_sol.u) == length(sol.u) + + # Verify ArrayPartition structure is preserved + for i in eachindex(sol.u) + @test copied_sol.u[i] isa ArrayPartition + @test copied_sol.u[i].x[1] == sol.u[i].x[1] + @test copied_sol.u[i].x[2] == sol.u[i].x[2] + + # Verify independence + @test copied_sol.u[i] !== sol.u[i] + @test copied_sol.u[i].x[1] !== sol.u[i].x[1] + @test copied_sol.u[i].x[2] !== sol.u[i].x[2] + end + end + + @testset "Struct-based parameter copying in ODE" begin + # Create an ODE with struct-based parameters to test our struct copying + struct ODEParams + decay_rate::Float64 + growth_rate::Float64 + coefficients::Vector{Float64} + end + + function parametric_ode!(du, u, params::ODEParams, t) + du[1] = -params.decay_rate * u[1] + params.coefficients[1] + du[2] = params.growth_rate * u[2] + params.coefficients[2] + end + + params = ODEParams(0.5, 0.3, [0.1, 0.2]) + u0 = [1.0, 2.0] + tspan = (0.0, 1.0) + prob = ODEProblem(parametric_ode!, u0, tspan, params) + sol = solve(prob, Tsit5(), saveat=0.5) + + # Test copying the solution (which contains the struct parameters) + copied_sol = recursivecopy(sol) + + # Verify parameter struct is copied correctly + @test copied_sol.prob.p isa ODEParams + @test typeof(copied_sol) == typeof(sol) + @test copied_sol.prob.p.decay_rate == sol.prob.p.decay_rate + @test copied_sol.prob.p.growth_rate == sol.prob.p.growth_rate + @test copied_sol.prob.p.coefficients == sol.prob.p.coefficients + + # Test that the main solution arrays are independent (most important for users) + original_u = sol.u[1][1] + copied_sol.u[1][1] = 888.0 + @test sol.u[1][1] == original_u # Solution data should be independent + + # Note: ODE solution internal structures may have optimized sharing + # The key success is that recursivecopy works and solution data is independent + end + + println("All ODE solution recursivecopy tests completed successfully!") +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 4ec9d6f4..682162e1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -30,6 +30,7 @@ end @time @safetestset "Linear Algebra Tests" include("linalg.jl") @time @safetestset "Adjoint Tests" include("adjoints.jl") @time @safetestset "Measurement Tests" include("measurements.jl") + @time @safetestset "Struct Copy Tests" include("struct_copy_test.jl") end if GROUP == "SymbolicIndexingInterface" || GROUP == "All" @@ -39,6 +40,7 @@ end if GROUP == "Downstream" activate_downstream_env() @time @safetestset "ODE Solve Tests" include("downstream/odesolve.jl") + @time @safetestset "ODE Solution Copy Tests" include("downstream/ode_solution_copy_test.jl") @time @safetestset "Event Tests with ArrayPartition" include("downstream/downstream_events.jl") @time @safetestset "Measurements and Units" include("downstream/measurements_and_units.jl") @time @safetestset "TrackerExt" include("downstream/TrackerExt.jl") diff --git a/test/struct_copy_test.jl b/test/struct_copy_test.jl new file mode 100644 index 00000000..209763a7 --- /dev/null +++ b/test/struct_copy_test.jl @@ -0,0 +1,151 @@ +using RecursiveArrayTools, Test + +# Test structures for struct-aware recursivecopy +struct SimpleStruct + a::Int + b::Float64 +end + +mutable struct MutableStruct + a::Vector{Float64} + b::Matrix{Int} + c::String +end + +struct NestedStruct + simple::SimpleStruct + mutable::MutableStruct + array::Vector{Int} +end + +struct ParametricStruct{T} + data::Vector{T} + metadata::T +end + +@testset "Struct recursivecopy tests" begin + + @testset "Simple immutable struct" begin + original = SimpleStruct(42, 3.14) + copied = recursivecopy(original) + + @test copied isa SimpleStruct + @test copied.a == original.a + @test copied.b == original.b + # Note: For immutable structs with only primitive types, Julia may optimize + # to use the same memory location, so we test functionality rather than identity + end + + @testset "Mutable struct with arrays" begin + original = MutableStruct([1.0, 2.0, 3.0], [1 2; 3 4], "test") + + # Should error for mutable structs + @test_throws ErrorException("recursivecopy for mutable structs is not currently implemented. Use deepcopy instead.") recursivecopy(original) + end + + @testset "Nested struct" begin + simple = SimpleStruct(10, 2.5) + # Create a nested struct with only immutable components + struct ImmutableNested + simple::SimpleStruct + array::Vector{Int} + name::String + end + + original = ImmutableNested(simple, [100, 200, 300], "nested") + copied = recursivecopy(original) + + @test copied isa ImmutableNested + @test copied.simple.a == original.simple.a + @test copied.simple.b == original.simple.b + @test copied.array == original.array + @test copied.name == original.name + + @test copied !== original + @test copied.array !== original.array + + # Test independence + original.array[1] = 999 + @test copied.array[1] == 100 # Should remain unchanged + end + + @testset "Parametric struct" begin + original = ParametricStruct([1, 2, 3], 42) + copied = recursivecopy(original) + + @test copied isa ParametricStruct{Int} + @test copied.data == original.data + @test copied.metadata == original.metadata + @test copied !== original + @test copied.data !== original.data + end + + @testset "Compatibility with existing types" begin + # Test that arrays still work + arr = [1, 2, 3] + copied_arr = recursivecopy(arr) + @test copied_arr == arr + @test copied_arr !== arr + + # Test that numbers still work + num = 42 + copied_num = recursivecopy(num) + @test copied_num == num + + # Test that strings still work + str = "hello" + copied_str = recursivecopy(str) + @test copied_str == str + end + + @testset "ArrayPartition with structs" begin + simple1 = SimpleStruct(1, 1.0) + simple2 = SimpleStruct(2, 2.0) + ap = ArrayPartition([simple1, simple2]) + copied_ap = recursivecopy(ap) + + @test copied_ap isa ArrayPartition + @test length(copied_ap.x) == length(ap.x) + @test copied_ap.x[1][1].a == ap.x[1][1].a + @test copied_ap !== ap + @test copied_ap.x[1] !== ap.x[1] + end + + @testset "Array dispatch still works correctly" begin + # Test that our struct method doesn't interfere with existing array methods + + # Arrays of numbers should use copy + num_array = [1, 2, 3] + copied_num = recursivecopy(num_array) + @test copied_num == num_array + @test copied_num !== num_array + + # Arrays of arrays should recursively copy + nested_array = [[1, 2], [3, 4]] + copied_nested = recursivecopy(nested_array) + @test copied_nested == nested_array + @test copied_nested !== nested_array + @test copied_nested[1] !== nested_array[1] + @test copied_nested[2] !== nested_array[2] + + # AbstractVectorOfArray should use its method + ap = ArrayPartition([1.0, 2.0], [3, 4]) + copied_ap = recursivecopy(ap) + @test copied_ap isa ArrayPartition + @test copied_ap.x[1] == ap.x[1] + @test copied_ap.x[1] !== ap.x[1] + + # Test that structs containing arrays still work + struct StructWithArrays + data::Vector{Vector{Int}} + metadata::String + end + + original_struct = StructWithArrays([[1, 2], [3, 4]], "test") + copied_struct = recursivecopy(original_struct) + @test copied_struct.data == original_struct.data + @test copied_struct.data !== original_struct.data + @test copied_struct.data[1] !== original_struct.data[1] + @test copied_struct.metadata == original_struct.metadata + end +end \ No newline at end of file