Skip to content

Commit 5b7259e

Browse files
In Convex Hull test (#15)
* add inconvexhull function * uncomment * fix tests * uncom tests * relocate scripts
1 parent 11465f3 commit 5b7259e

File tree

9 files changed

+79
-64
lines changed

9 files changed

+79
-64
lines changed
File renamed without changes.
File renamed without changes.

src/FullyConnected.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ function MLJFlux.train!(
118118
return training_loss / n_batches
119119
end
120120

121-
function train!(model, loss, opt_state, X, Y; batchsize=32, shuffle=true)
121+
function train!(model, loss, opt_state, X, Y; _batchsize=32, shuffle=true)
122+
batchsize = min(size(X, 2), _batchsize)
122123
X = X |> gpu
123124
Y = Y |> gpu
124125
data = Flux.DataLoader((X, Y),

src/L2O.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ export ArrowFile,
3636
make_convex,
3737
ConvexRule,
3838
relative_rmse,
39-
relative_mae
39+
relative_mae,
40+
inconvexhull
4041

4142
include("datasetgen.jl")
4243
include("csvrecorder.jl")
@@ -46,5 +47,6 @@ include("worst_case_iter.jl")
4647
include("FullyConnected.jl")
4748
include("nn_expression.jl")
4849
include("metrics.jl")
50+
include("inconvexhull.jl")
4951

5052
end

src/inconvexhull.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""
2+
3+
inconvexhull(training_set::Matrix{Float64}, test_set::Matrix{Float64})
4+
5+
Check if new points are inside the convex hull of the given points. Solves a linear programming problem to check if the points are inside the convex hull.
6+
"""
7+
function inconvexhull(training_set::Matrix{Float64}, test_set::Matrix{Float64}, solver)
8+
# Get the number of points and dimensions
9+
n, d = size(training_set)
10+
m, d = size(test_set)
11+
12+
# Create the model
13+
model = JuMP.Model(solver)
14+
15+
# Create the variables
16+
@variable(model, lambda[1:n, 1:m] >= 0)
17+
@constraint(model, convex_combination[i=1:m], sum(lambda[j, i] for j in 1:n) == 1)
18+
19+
# slack variables
20+
@variable(model, slack[1:m] >= 0)
21+
22+
# Create the constraints
23+
@constraint(model, in_convex_hull[i=1:m, k=1:d], sum(lambda[j, i] * training_set[j, k] for j in 1:n) == test_set[i, k] + slack[i])
24+
25+
# Create the objective
26+
@objective(model, Min, sum(slack[i] for i in 1:m))
27+
28+
# solve the model
29+
optimize!(model)
30+
31+
# return if the points are inside the convex hull
32+
return isapprox.(value.(slack), 0)
33+
end

test/inconvexhull.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""
2+
test_inconvexhull()
3+
4+
Test the inconvexhull function: inconvexhull(training_set::Matrix{Float64}, test_set::Matrix{Float64})
5+
"""
6+
function test_inconvexhull()
7+
@testset "inconvexhull" begin
8+
# Create the training set
9+
training_set = [0. 0; 1 0; 0 1; 1 1]
10+
11+
# Create the test set
12+
test_set = [0.5 0.5; -0.5 0.5; 0.5 -0.5; 0.0 0.5]
13+
14+
# Test the inconvexhull function
15+
@test inconvexhull(training_set, test_set, HiGHS.Optimizer) == [true, false, false, true]
16+
end
17+
end

test/nn_expression.jl

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -28,64 +28,3 @@ function test_flux_jump_basic()
2828
end
2929
end
3030
end
31-
32-
"""
33-
test_FullyConnected_jump()
34-
35-
Tests running a jump model with a FullyConnected Network expression.
36-
"""
37-
function test_FullyConnected_jump()
38-
for i in 1:10
39-
X = rand(100, 3)
40-
Y = rand(100, 1)
41-
42-
nn = MultitargetNeuralNetworkRegressor(;
43-
builder=FullyConnectedBuilder([8, 8, 8]),
44-
rng=123,
45-
epochs=100,
46-
optimiser=optimiser,
47-
acceleration=CUDALibs(),
48-
batch_size=32,
49-
)
50-
51-
mach = machine(nn, X, y)
52-
fit!(mach; verbosity=2)
53-
54-
flux_model = mach.fitresult[1]
55-
56-
model = JuMP.Model(Gurobi.Optimizer)
57-
58-
@variable(model, x[i = 1:3]>= 2.3)
59-
60-
ex = flux_model(x)[1]
61-
62-
# @constraint(model, ex >= -100.0)
63-
@constraint(model, sum(x) <= 10)
64-
65-
@objective(model, Min, ex)
66-
67-
JuMP.optimize!(model)
68-
69-
@test termination_status(model) === OPTIMAL
70-
if flux_model(value.(x))[1] <= 1.0
71-
@test isapprox(flux_model(value.(x))[1], value(ex); atol=0.01)
72-
else
73-
@test isapprox(flux_model(value.(x))[1], value(ex); rtol=0.001)
74-
end
75-
end
76-
end
77-
78-
function print_conflict!(model)
79-
JuMP.compute_conflict!(model)
80-
ctypes = list_of_constraint_types(model)
81-
for (F, S) in ctypes
82-
cons = all_constraints(model, F, S)
83-
for i in eachindex(cons)
84-
isassigned(cons, i) || continue
85-
con = cons[i]
86-
cst = MOI.get(model, MOI.ConstraintConflictStatus(), con)
87-
cst == MOI.IN_CONFLICT && @info JuMP.name(con) con
88-
end
89-
end
90-
return nothing
91-
end

test/runtests.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ using Flux
1313
using MLJ
1414
using CSV
1515
using DataFrames
16+
using Optimisers
1617

1718
using NonconvexNLopt
1819

@@ -29,9 +30,14 @@ include(joinpath(test_dir, "test_flux_forecaster.jl"))
2930

3031
include(joinpath(test_dir, "nn_expression.jl"))
3132

33+
include(joinpath(test_dir, "inconvexhull.jl"))
34+
3235
@testset "L2O.jl" begin
36+
test_fully_connected()
37+
test_flux_jump_basic()
38+
test_inconvexhull()
39+
3340
mktempdir() do path
34-
test_flux_jump_basic()
3541
test_problem_iterator(path)
3642
test_worst_case_problem_iterator(path)
3743
file_in, file_out = test_pglib_datasetgen(path, "pglib_opf_case5_pjm", 20)

test/test_flux_forecaster.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,20 @@ function test_flux_forecaster(file_in::AbstractString, file_out::AbstractString)
3838
rm(file_out)
3939
end
4040
end
41+
42+
# Test the Flux.jl forecaster outside MLJ.jl
43+
function test_fully_connected(;num_samples::Int=100, num_features::Int=10)
44+
X = rand(num_features, num_samples)
45+
y = rand(1, num_samples)
46+
47+
model = FullyConnected(10, [10, 10], 1)
48+
49+
# Train the model
50+
optimizer = Optimisers.Adam()
51+
opt_state = Optimisers.setup(optimizer, model)
52+
L2O.train!(model, Flux.mse, opt_state, X, y)
53+
54+
# Make predictions
55+
predictions = model(X)
56+
@test predictions isa Array
57+
end

0 commit comments

Comments
 (0)