-
-
Notifications
You must be signed in to change notification settings - Fork 72
feat: Implement recursivecopy for structs
#468
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
RomeoV
wants to merge
5
commits into
SciML:master
Choose a base branch
from
RomeoV:rv/deepcopy
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
88b9a25
feat: Implement `recursivecopy` for structs
RomeoV 148265e
Add tests, generated by Claude Code, checked by me
RomeoV 302f91f
Add downstream ODE solution recursivecopy tests
RomeoV 13dcfdd
Add Function to _is_basic_julia_type and update tests
RomeoV 9a43119
Improve ODE solution copy tests for independence verification
RomeoV File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,15 +3,29 @@ unrolled_foreach!(f, ::Tuple{}) = nothing | |
|
|
||
| """ | ||
| ```julia | ||
| recursivecopy(a::Union{AbstractArray{T, N}, AbstractVectorOfArray{T, N}}) | ||
| recursivecopy(a) | ||
| ``` | ||
|
|
||
| A recursive `copy` function. Acts like a `deepcopy` on arrays of arrays, but | ||
| like `copy` on arrays of scalars. | ||
| """ | ||
| function recursivecopy(a) | ||
| deepcopy(a) | ||
| like `copy` on arrays of scalars. For struct types, recursively copies each | ||
| field, creating new instances while preserving the struct type. | ||
|
|
||
| ## Examples | ||
|
|
||
| ```julia | ||
| # Basic array copying | ||
| arr = [[1, 2], [3, 4]] | ||
| copied = recursivecopy(arr) # New arrays at each level | ||
|
|
||
| # Struct copying | ||
| struct MyStruct | ||
| data::Vector{Float64} | ||
| metadata::String | ||
| end | ||
| original = MyStruct([1.0, 2.0], "test") | ||
| copied = recursivecopy(original) # New struct with new vector | ||
| ``` | ||
| """ | ||
| function recursivecopy(a::Union{StaticArraysCore.SVector, StaticArraysCore.SMatrix, | ||
| StaticArraysCore.SArray, Number}) | ||
| copy(a) | ||
|
|
@@ -35,6 +49,42 @@ function recursivecopy(a::AbstractVectorOfArray) | |
| return b | ||
| end | ||
|
|
||
| function _is_basic_julia_type(T) | ||
| # Check if this is a built-in Julia type that we should not handle as a user struct | ||
| # We check the module to identify Core/Base types vs user-defined types | ||
| mod = Base.parentmodule(T) | ||
| return T <: AbstractString || T <: Number || T <: Symbol || T <: Tuple || | ||
| T <: UnitRange || T <: StepRange || T <: Regex || T <: Function || | ||
| T === Nothing || T === Missing || | ||
| mod === Core || mod === Base | ||
| end | ||
|
|
||
| function recursivecopy(s::T) where {T} | ||
| # Only handle user-defined immutable structs. Many basic Julia types (String, Symbol, | ||
| # Tuple, etc.) are technically structs but should use copy() or return as-is. | ||
| if Base.isstructtype(T) && !_is_basic_julia_type(T) | ||
| if Base.ismutabletype(T) | ||
| error("recursivecopy for mutable structs is not currently implemented. Use deepcopy instead.") | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what's the issue there? |
||
| else | ||
| # Handle immutable structs only | ||
| field_values = ntuple(fieldcount(T)) do i | ||
| field_value = getfield(s, i) | ||
| recursivecopy(field_value) | ||
| end | ||
| return T(field_values...) | ||
| end | ||
| elseif _is_basic_julia_type(T) | ||
| # For basic Julia types, use copy if available, otherwise return as-is (for immutable types) | ||
| if hasmethod(copy, Tuple{T}) | ||
| return copy(s) | ||
| else | ||
| return s # Immutable basic types like Symbol, Nothing, Missing don't need copying | ||
| end | ||
| else | ||
| deepcopy(s) | ||
| end | ||
| end | ||
|
|
||
| """ | ||
| ```julia | ||
| recursivecopy!(b::AbstractArray{T, N}, a::AbstractArray{T, N}) | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| using OrdinaryDiffEq, RecursiveArrayTools, Test | ||
|
|
||
| @testset "ODE Solution recursivecopy tests" begin | ||
|
|
||
| @testset "Basic ODE solution copying" begin | ||
| # Define a simple ODE system | ||
| function simple_ode!(du, u, p, t) | ||
| du[1] = -0.5 * u[1] | ||
| du[2] = 0.3 * u[2] | ||
| end | ||
|
|
||
| u0 = [1.0, 2.0] | ||
| tspan = (0.0, 2.0) | ||
| prob = ODEProblem(simple_ode!, u0, tspan) | ||
| sol = solve(prob, Tsit5(), saveat=0.5) | ||
|
|
||
| # Test that we can copy the solution | ||
| copied_sol = recursivecopy(sol) | ||
|
|
||
| # Verify the solution structure is preserved | ||
| @test typeof(copied_sol) == typeof(sol) | ||
| @test copied_sol.t == sol.t | ||
| @test copied_sol.u == sol.u | ||
| @test copied_sol.retcode == sol.retcode | ||
|
|
||
| # Verify that arrays are independent copies | ||
| @test copied_sol !== sol | ||
| @test copied_sol.u !== sol.u | ||
| @test copied_sol.t !== sol.t | ||
|
|
||
| # Test that modifying one doesn't affect the other | ||
| if length(copied_sol.u) > 0 | ||
| original_value = sol.u[1][1] | ||
| copied_sol.u[1][1] = 999.0 | ||
| @test sol.u[1][1] == original_value # Original should be unchanged | ||
| end | ||
| end | ||
|
|
||
| @testset "ArrayPartition ODE solution copying" begin | ||
| # Use the Lorenz system from the existing tests | ||
| function lorenz!(du, u, p, t) | ||
| du.x[1][1] = 10.0 * (u.x[2][1] - u.x[1][1]) | ||
| du.x[1][2] = u.x[1][1] * (28.0 - u.x[2][1]) - u.x[1][2] | ||
| du.x[2][1] = u.x[1][1] * u.x[1][2] - (8/3) * u.x[2][1] | ||
| end | ||
|
|
||
| u0 = ArrayPartition([1.0, 0.0], [0.0]) | ||
| tspan = (0.0, 1.0) | ||
| prob = ODEProblem(lorenz!, u0, tspan) | ||
| sol = solve(prob, Tsit5(), saveat=0.2) | ||
|
|
||
| # Test that we can copy the ArrayPartition-based solution | ||
| copied_sol = recursivecopy(sol) | ||
|
|
||
| # Verify solution structure | ||
| @test typeof(copied_sol) == typeof(sol) | ||
| @test copied_sol.t == sol.t | ||
| @test length(copied_sol.u) == length(sol.u) | ||
|
|
||
| # Verify ArrayPartition structure is preserved | ||
| for i in eachindex(sol.u) | ||
| @test copied_sol.u[i] isa ArrayPartition | ||
| @test copied_sol.u[i].x[1] == sol.u[i].x[1] | ||
| @test copied_sol.u[i].x[2] == sol.u[i].x[2] | ||
|
|
||
| # Verify independence | ||
| @test copied_sol.u[i] !== sol.u[i] | ||
| @test copied_sol.u[i].x[1] !== sol.u[i].x[1] | ||
| @test copied_sol.u[i].x[2] !== sol.u[i].x[2] | ||
| end | ||
| end | ||
|
|
||
| @testset "Struct-based parameter copying in ODE" begin | ||
| # Create an ODE with struct-based parameters to test our struct copying | ||
| struct ODEParams | ||
| decay_rate::Float64 | ||
| growth_rate::Float64 | ||
| coefficients::Vector{Float64} | ||
| end | ||
|
|
||
| function parametric_ode!(du, u, params::ODEParams, t) | ||
| du[1] = -params.decay_rate * u[1] + params.coefficients[1] | ||
| du[2] = params.growth_rate * u[2] + params.coefficients[2] | ||
| end | ||
|
|
||
| params = ODEParams(0.5, 0.3, [0.1, 0.2]) | ||
| u0 = [1.0, 2.0] | ||
| tspan = (0.0, 1.0) | ||
| prob = ODEProblem(parametric_ode!, u0, tspan, params) | ||
| sol = solve(prob, Tsit5(), saveat=0.5) | ||
|
|
||
| # Test copying the solution (which contains the struct parameters) | ||
| copied_sol = recursivecopy(sol) | ||
|
|
||
| # Verify parameter struct is copied correctly | ||
| @test copied_sol.prob.p isa ODEParams | ||
| @test typeof(copied_sol) == typeof(sol) | ||
| @test copied_sol.prob.p.decay_rate == sol.prob.p.decay_rate | ||
| @test copied_sol.prob.p.growth_rate == sol.prob.p.growth_rate | ||
| @test copied_sol.prob.p.coefficients == sol.prob.p.coefficients | ||
|
|
||
| # Test that the main solution arrays are independent (most important for users) | ||
| original_u = sol.u[1][1] | ||
| copied_sol.u[1][1] = 888.0 | ||
| @test sol.u[1][1] == original_u # Solution data should be independent | ||
|
|
||
| # Note: ODE solution internal structures may have optimized sharing | ||
| # The key success is that recursivecopy works and solution data is independent | ||
| end | ||
|
|
||
| println("All ODE solution recursivecopy tests completed successfully!") | ||
| end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,151 @@ | ||
| using RecursiveArrayTools, Test | ||
|
|
||
| # Test structures for struct-aware recursivecopy | ||
| struct SimpleStruct | ||
| a::Int | ||
| b::Float64 | ||
| end | ||
|
|
||
| mutable struct MutableStruct | ||
| a::Vector{Float64} | ||
| b::Matrix{Int} | ||
| c::String | ||
| end | ||
|
|
||
| struct NestedStruct | ||
| simple::SimpleStruct | ||
| mutable::MutableStruct | ||
| array::Vector{Int} | ||
| end | ||
|
|
||
| struct ParametricStruct{T} | ||
| data::Vector{T} | ||
| metadata::T | ||
| end | ||
|
|
||
| @testset "Struct recursivecopy tests" begin | ||
|
|
||
| @testset "Simple immutable struct" begin | ||
| original = SimpleStruct(42, 3.14) | ||
| copied = recursivecopy(original) | ||
|
|
||
| @test copied isa SimpleStruct | ||
| @test copied.a == original.a | ||
| @test copied.b == original.b | ||
| # Note: For immutable structs with only primitive types, Julia may optimize | ||
| # to use the same memory location, so we test functionality rather than identity | ||
| end | ||
|
|
||
| @testset "Mutable struct with arrays" begin | ||
| original = MutableStruct([1.0, 2.0, 3.0], [1 2; 3 4], "test") | ||
|
|
||
| # Should error for mutable structs | ||
| @test_throws ErrorException("recursivecopy for mutable structs is not currently implemented. Use deepcopy instead.") recursivecopy(original) | ||
| end | ||
|
|
||
| @testset "Nested struct" begin | ||
| simple = SimpleStruct(10, 2.5) | ||
| # Create a nested struct with only immutable components | ||
| struct ImmutableNested | ||
| simple::SimpleStruct | ||
| array::Vector{Int} | ||
| name::String | ||
| end | ||
|
|
||
| original = ImmutableNested(simple, [100, 200, 300], "nested") | ||
| copied = recursivecopy(original) | ||
|
|
||
| @test copied isa ImmutableNested | ||
| @test copied.simple.a == original.simple.a | ||
| @test copied.simple.b == original.simple.b | ||
| @test copied.array == original.array | ||
| @test copied.name == original.name | ||
|
|
||
| @test copied !== original | ||
| @test copied.array !== original.array | ||
|
|
||
| # Test independence | ||
| original.array[1] = 999 | ||
| @test copied.array[1] == 100 # Should remain unchanged | ||
| end | ||
|
|
||
| @testset "Parametric struct" begin | ||
| original = ParametricStruct([1, 2, 3], 42) | ||
| copied = recursivecopy(original) | ||
|
|
||
| @test copied isa ParametricStruct{Int} | ||
| @test copied.data == original.data | ||
| @test copied.metadata == original.metadata | ||
| @test copied !== original | ||
| @test copied.data !== original.data | ||
| end | ||
|
|
||
| @testset "Compatibility with existing types" begin | ||
| # Test that arrays still work | ||
| arr = [1, 2, 3] | ||
| copied_arr = recursivecopy(arr) | ||
| @test copied_arr == arr | ||
| @test copied_arr !== arr | ||
|
|
||
| # Test that numbers still work | ||
| num = 42 | ||
| copied_num = recursivecopy(num) | ||
| @test copied_num == num | ||
|
|
||
| # Test that strings still work | ||
| str = "hello" | ||
| copied_str = recursivecopy(str) | ||
| @test copied_str == str | ||
| end | ||
|
|
||
| @testset "ArrayPartition with structs" begin | ||
| simple1 = SimpleStruct(1, 1.0) | ||
| simple2 = SimpleStruct(2, 2.0) | ||
| ap = ArrayPartition([simple1, simple2]) | ||
| copied_ap = recursivecopy(ap) | ||
|
|
||
| @test copied_ap isa ArrayPartition | ||
| @test length(copied_ap.x) == length(ap.x) | ||
| @test copied_ap.x[1][1].a == ap.x[1][1].a | ||
| @test copied_ap !== ap | ||
| @test copied_ap.x[1] !== ap.x[1] | ||
| end | ||
|
|
||
| @testset "Array dispatch still works correctly" begin | ||
| # Test that our struct method doesn't interfere with existing array methods | ||
|
|
||
| # Arrays of numbers should use copy | ||
| num_array = [1, 2, 3] | ||
| copied_num = recursivecopy(num_array) | ||
| @test copied_num == num_array | ||
| @test copied_num !== num_array | ||
|
|
||
| # Arrays of arrays should recursively copy | ||
| nested_array = [[1, 2], [3, 4]] | ||
| copied_nested = recursivecopy(nested_array) | ||
| @test copied_nested == nested_array | ||
| @test copied_nested !== nested_array | ||
| @test copied_nested[1] !== nested_array[1] | ||
| @test copied_nested[2] !== nested_array[2] | ||
|
|
||
| # AbstractVectorOfArray should use its method | ||
| ap = ArrayPartition([1.0, 2.0], [3, 4]) | ||
| copied_ap = recursivecopy(ap) | ||
| @test copied_ap isa ArrayPartition | ||
| @test copied_ap.x[1] == ap.x[1] | ||
| @test copied_ap.x[1] !== ap.x[1] | ||
|
|
||
| # Test that structs containing arrays still work | ||
| struct StructWithArrays | ||
| data::Vector{Vector{Int}} | ||
| metadata::String | ||
| end | ||
|
|
||
| original_struct = StructWithArrays([[1, 2], [3, 4]], "test") | ||
| copied_struct = recursivecopy(original_struct) | ||
| @test copied_struct.data == original_struct.data | ||
| @test copied_struct.data !== original_struct.data | ||
| @test copied_struct.data[1] !== original_struct.data[1] | ||
| @test copied_struct.metadata == original_struct.metadata | ||
| end | ||
| end |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.