diff --git a/src/utils/wrap_tuples.jl b/src/utils/wrap_tuples.jl index 082b642f..21210085 100644 --- a/src/utils/wrap_tuples.jl +++ b/src/utils/wrap_tuples.jl @@ -10,7 +10,7 @@ struct WrappedTuples{T <: AbstractVector{<:NamedTuple}} <: AbstractVector{NamedT end # Required methods for AbstractVector -Base.size(w::WrappedTuples) = (length(w.data), length(first(w.data))) +Base.size(w::WrappedTuples) = (length(w.data),) Base.getindex(w::WrappedTuples, i::Int) = w.data[i] Base.getindex(w::WrappedTuples, r::AbstractRange) = WrappedTuples(w.data[r]) Base.IndexStyle(::Type{<:WrappedTuples}) = IndexLinear() @@ -33,8 +33,9 @@ function Base.propertynames(w::WrappedTuples, private::Bool = false) return (:data,) ∪ propertynames(first(w.data), private) end function Base.Matrix(w::WrappedTuples) - n, m = size(w) + n = length(w) fields = propertynames(first(w.data)) + m = length(fields) T = promote_type(map(f -> eltype(getproperty(w, f)), fields)...) mat = Array{T}(undef, n, m) for (j, f) in enumerate(fields) diff --git a/test/runtests.jl b/test/runtests.jl index a4bac5f9..a772ca70 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,6 +18,7 @@ include("test_compute_loss.jl") include("test_loss_fn.jl") include("test_show_train.jl") include("test_show_generic_hybrid.jl") +include("test_wrap_tuples.jl") @testset "LinearHM" begin # test model instantiation diff --git a/test/test_wrap_tuples.jl b/test/test_wrap_tuples.jl new file mode 100644 index 00000000..b079d3f3 --- /dev/null +++ b/test/test_wrap_tuples.jl @@ -0,0 +1,55 @@ +using Test +using EasyHybrid: WrappedTuples + +@testset "WrappedTuples" begin + vec = [(a = 1, b = 2.0), (a = 3, b = 4.0)] + wt = WrappedTuples(vec) + + # Basic properties + @test typeof(wt) <: AbstractVector{NamedTuple} + @test size(wt) == (2,) + @test length(wt) == 2 + + # Indexing + @test wt[1] == vec[1] + @test wt[1:1] isa WrappedTuples + @test wt[1:1].data == vec[1:1] + + # Iteration + @test collect(wt) == vec + # Test iterate explicitly + @test iterate(wt) == (vec[1], 2) + @test iterate(wt, 2) == (vec[2], 3) + @test iterate(wt, 3) === nothing + + result = NamedTuple[] + for item in wt + push!(result, item) + end + @test result == vec + + # Index style + @test IndexStyle(WrappedTuples) isa IndexLinear + + # Dot-access to fields + @test wt.a == [1, 3] + @test wt.b == [2.0, 4.0] + + # Keys and propertynames + @test keys(wt) == propertynames(vec[1]) + pn = propertynames(wt) + @test :data in pn && :a in pn && :b in pn + + # Matrix conversion (checks promotion and column layout) + M = Matrix(wt) + @test size(M) == (2, 2) + @test M[1, 1] == 1.0 && M[2, 1] == 3.0 && M[1, 2] == 2.0 && M[2, 2] == 4.0 + + # Missing field raises Exception + @test_throws Exception wt.x + + # Slicing preserves behavior + sub = wt[2:2] + @test sub.a == [3] + @test length(sub) == 1 +end