Skip to content

Commit 78e50d8

Browse files
committed
wip
1 parent d81e625 commit 78e50d8

File tree

4 files changed

+46
-1
lines changed

4 files changed

+46
-1
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ SCIP = "82193955-e24f-5292-bf16-6f2c5261a85f"
2525
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
2626
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2727
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
28+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2829

2930
[compat]
3031
ConstrainedShortestPaths = "0.6.0"
@@ -47,6 +48,7 @@ Requires = "1.3.0"
4748
SimpleWeightedGraphs = "1.4"
4849
SparseArrays = "1"
4950
Statistics = "1.11.1"
51+
StatsBase = "0.34.4"
5052
julia = "1.6"
5153

5254
[extras]

src/PortfolioOptimization/PortfolioOptimization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
55
using Distributions: Uniform, Bernoulli
66
using Flux: Chain, Dense
77
using Ipopt: Ipopt
8-
using JuMP: @variable, @objective, @constraint, optimize!, value
8+
using JuMP: @variable, @objective, @constraint, optimize!, value, Model
99
using LinearAlgebra: I
1010
using Random: MersenneTwister
1111

src/Utils/Utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using JuMP: Model
77
using LinearAlgebra: dot
88
using SCIP: SCIP
99
using SimpleWeightedGraphs: SimpleWeightedDiGraph
10+
using StatsBase: StatsBase
1011

1112
include("data_sample.jl")
1213
include("interface.jl")
@@ -24,5 +25,6 @@ export grid_graph, get_path, path_to_matrix
2425
export neg_tensor, squeeze_last_dims, average_tensor
2526
export scip_model, highs_model
2627
export objective_value
28+
export compute_normalizer
2729

2830
end

src/Utils/data_sample.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,44 @@ $TYPEDFIELDS
1818
"instance object (optional)"
1919
instance::I = nothing
2020
end
21+
22+
function _transform(t, sample::DataSample; kwargs...)
23+
(; instance, x, θ_true, y_true) = sample
24+
return DataSample(; instance, x=StatsBase.transform(t, x; kwargs...), θ_true, y_true)
25+
end
26+
27+
function _reconstruct(t, sample::DataSample; kwargs...)
28+
(; instance, x, θ_true, y_true) = sample
29+
return DataSample(; instance, x=StatsBase.reconstruct(t, x; kwargs...), θ_true, y_true)
30+
end
31+
32+
"""
33+
$TYPEDSIGNATURES
34+
35+
Compute the mean and standard deviation of the features in the dataset.
36+
"""
37+
function StatsBase.fit(transform_type, dataset::AbstractVector{<:DataSample}; kwargs...)
38+
x = hcat([d.x for d in dataset]...)
39+
return StatsBase.fit(transform_type, x; kwargs...)
40+
end
41+
42+
"""
43+
$TYPEDSIGNATURES
44+
45+
Transform the features in the dataset.
46+
"""
47+
function StatsBase.transform(t, dataset::AbstractVector{<:DataSample})
48+
return map(dataset) do d
49+
(; instance, x, θ_true, y_true) = d
50+
DataSample(; instance, x=StatsBase.transform(t, x), θ_true, y_true)
51+
end
52+
end
53+
54+
# TODO: reconstruct, transform!, reconstruct!
55+
56+
function StatsBase.reconstruct(t, dataset::AbstractVector{<:DataSample})
57+
return map(dataset) do d
58+
(; instance, x, θ_true, y_true) = d
59+
DataSample(; instance, x=StatsBase.reconstruct(t, x), θ_true, y_true)
60+
end
61+
end

0 commit comments

Comments
 (0)