11import LinearAlgebra
2+ import Base./
23
34"""
45 VariableIndexedVector(data::Vector{T}, index::Vector{<:Forecast})
@@ -34,6 +35,9 @@ Base.length(v::VariableIndexedVector) = length(v.data)
3435Base. getindex(v:: VariableIndexedVector , i:: Int ) = v. data[i]
3536Base. setindex!(v:: VariableIndexedVector , val, i:: Int ) = (v. data[i] = val)
3637
38+ # define Base rigth divide function (v / 2)
39+ / (v:: VariableIndexedVector , i:: Number ) = VariableIndexedVector(v. data / i, v. index)
40+
3741# define dot product of two VariableIndexedVectors
3842function LinearAlgebra. dot(v1:: VariableIndexedVector , v2:: VariableIndexedVector )
3943 @assert length(v1) == length(v2) " Vectors must have the same length"
@@ -98,6 +102,23 @@ struct VariableIndexedMatrix{T} <: AbstractMatrix{T}
98102 @assert length(unique(row_index)) == length(row_index) " Variables must be unique"
99103 new{T}(data, row_index)
100104 end
105+
106+ function VariableIndexedMatrix{T}(:: UndefInitializer , index:: Vector{<:Forecast} , n:: Real ) where T
107+ return new{T}(Matrix{T}(undef, length(index), n), index)
108+ end
109+
110+ function VariableIndexedMatrix{T}(:: Nothing , index:: Vector{<:Forecast} , n:: Real ) where T
111+ return new{T}(Matrix{T}(nothing , length(index), n), index)
112+ end
113+ end
114+
115+ # helper to find index of a Forecast variable
116+ function _get_idx(m:: VariableIndexedMatrix , var:: Forecast )
117+ i = findfirst(isequal(var), m. row_index)
118+ if isnothing(i)
119+ throw(KeyError(var))
120+ end
121+ return i
101122end
102123
103124Base. size(m:: VariableIndexedMatrix ) = size(m. data)
@@ -106,11 +127,44 @@ Base.size(m::VariableIndexedMatrix) = size(m.data)
106127Base. getindex(m:: VariableIndexedMatrix , i:: Int , j:: Int ) = m. data[i, j]
107128Base. setindex!(m:: VariableIndexedMatrix , val, i:: Int , j:: Int ) = (m. data[i, j] = val)
108129
109- # column lookup (get column 2) {M[:, 2]}
130+ # column lookup (get column 2) {M[2]}
110131function Base. getindex(m:: VariableIndexedMatrix , c:: Int )
111132 return VariableIndexedVector(m. data[:, c], m. row_index)
112133end
113134
135+ # row lookup (get values from variable) {M[forecast_var]}
136+ function Base. getindex(m:: VariableIndexedMatrix , var:: Forecast )
137+ return m. data[_get_idx(m, var), :]
138+ end
139+
140+ # multi-row lookup {M[[f_var_1, f_var_2]]}
141+ function Base. getindex(m:: VariableIndexedMatrix , vars:: Vector{<:Forecast} )
142+ return VariableIndexedMatrix(m. data[[_get_idx(m, var) for var in vars], :], vars)
143+ end
144+
145+ # set column 2 values for all var indices {M[2] = [1,2,3]}
146+ function Base. setindex!(m:: VariableIndexedMatrix , values:: VariableIndexedVector , c:: Int )
147+ m. data[:, c] = values[m. row_index]
148+ end
149+
150+ # set values for a single var index {M[forecast_var] = [1,2,3]}
151+ function Base. setindex!(m:: VariableIndexedMatrix , values:: Vector , var:: Forecast )
152+ m. data[_get_idx(m, var), :] = values
153+ end
154+
155+ # set values for a subset of var indices {M[[f_var_1, f_var_2]] = [[1 2 3]; [4 5 6]]}
156+ function Base. setindex!(m:: VariableIndexedMatrix , values:: Matrix , vars:: Vector{<:Forecast} )
157+ m. data[[_get_idx(m, var) for var in vars], :] = values
158+ end
159+
160+ # define sum of matrix by summing all values for each variable
161+ function Base. sum(m:: VariableIndexedMatrix )
162+ return VariableIndexedVector(sum(m. data, dims= 2 )[:, 1 ], m. row_index)
163+ end
164+
165+ # define Base rigth divide function (M / 2)
166+ / (m:: ApplicationDrivenLearning.VariableIndexedMatrix , i:: Number ) = ApplicationDrivenLearning. VariableIndexedMatrix(m. data / i, m. row_index)
167+
114168# define dot product of a VariableIndexedVectors and a VariableIndexedMatrix
115169function LinearAlgebra. dot(v1:: VariableIndexedVector , m2:: VariableIndexedMatrix )
116170 @assert length(v1) == size(m2, 1 ) " Vector must have the same length as the number of rows in matrix"
0 commit comments