Skip to content

Commit db3fd06

Browse files
committed
Update of rmd17 example
1 parent dd6f3c2 commit db3fd06

File tree

6 files changed

+42
-36
lines changed

6 files changed

+42
-36
lines changed

examples/atomistic/srs-vs-sme-aspirin-rmd17.jl

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ using StreamingSampling
66
include("utils/utils.jl")
77

88
# Define paths and create experiment folder
9-
train_path = ["data/md17/aspirin-train.xyz"]
10-
test_path = ["data/md17/aspirin-test.xyz"]
11-
res_path = "results-aspirin-md17/"
9+
train_path = ["data/rmd17/aspirin-train.xyz"]
10+
test_path = ["data/rmd17/aspirin-test.xyz"]
11+
res_path = "results-aspirin-rmd17/"
1212
run(`mkdir -p $res_path`)
1313

1414
# Initialize streaming sampling ################################################
@@ -101,33 +101,39 @@ for j in 1:n_experiments
101101
chunksize=m,
102102
buffersize=1,
103103
randomized=true)
104-
_, test_inds = take!(ch)
104+
cs, test_inds = take!(ch)
105105
close(ch)
106-
test_inds = sort(test_inds)
107-
test_confs = get_confs(test_path, test_inds)
108-
test_ds = calc_descr(test_confs, basis_fitting)
106+
test_confs = []
107+
for c in cs
108+
system, energy, forces = c
109+
conf = Configuration(system, Energy(energy),
110+
Forces([Force(f) for f in forces]))
111+
push!(test_confs, conf)
112+
end
113+
ds_test = DataSet(test_confs)
114+
ds_test = calc_descr!(ds_test, basis_fitting)
109115
open("test-ds-aspirin-rmd17.jls", "w") do io
110-
serialize(io, test_ds)
111-
flush(io)
116+
serialize(io, ds_test)
117+
flush(io)
112118
end
113-
#test_ds = deserialize("test-ds-aspirin-rmd17.jls")
119+
#ds_test = deserialize("test-ds-aspirin-rmd17.jls")
114120

115121
for n in sample_sizes
116122
# Sample training dataset using streaming weighted sampling ############
117123
train_inds = StatsBase.sample(1:length(ws), Weights(ws), n;
118-
replace=false, ordered=true))
124+
replace=false, ordered=true)
119125
#Load atomistic configurations
120-
train_confs = get_confs(train_path, train_inds)
126+
ds_train = get_confs(train_path, read_element, train_inds)
121127
#Adjust reference energies (permanent change)
122-
adjust_energies(train_confs, vref_dict)
128+
adjust_energies!(ds_train, vref_dict)
123129
# Compute dataset with energy and force descriptors
124-
train_ds = calc_descr(train_confs, basis_fitting)
130+
ds_train = calc_descr!(ds_train, basis_fitting)
125131
# Create result folder
126132
curr_sampler = "sws"
127133
exp_path = "$res_path/$j-$curr_sampler-n$n/"
128134
run(`mkdir -p $exp_path`)
129135
# Fit and save results
130-
metrics_j = fit(exp_path, train_ds, test_ds, basis_fitting; vref_dict=vref_dict)
136+
metrics_j = fit(exp_path, ds_train, ds_test, basis_fitting; vref_dict=vref_dict)
131137
metrics_j = merge(OrderedDict("exp_number" => j,
132138
"method" => "$curr_sampler",
133139
"batch_size_prop" => n/N,
@@ -142,17 +148,17 @@ for j in 1:n_experiments
142148
train_inds = randperm(N)[1:n]
143149

144150
#Load atomistic configurations
145-
train_confs = get_confs(train_path, train_inds)
151+
ds_train = get_confs(train_path, read_element, train_inds)
146152
#Adjust reference energies (permanent change)
147-
adjust_energies(train_confs, vref_dict)
153+
adjust_energies!(ds_train, vref_dict)
148154
# Compute dataset with energy and force descriptors
149-
train_ds = calc_descr(train_confs, basis_fitting)
155+
ds_train = calc_descr!(ds_train, basis_fitting)
150156
# Create result folder
151157
curr_sampler = "srs"
152158
exp_path = "$res_path/$j-$curr_sampler-n$n/"
153159
run(`mkdir -p $exp_path`)
154160
# Fit and save results
155-
metrics_j = fit(exp_path, train_ds, test_ds, basis_fitting; vref_dict=vref_dict)
161+
metrics_j = fit(exp_path, ds_train, ds_test, basis_fitting; vref_dict=vref_dict)
156162
metrics_j = merge(OrderedDict("exp_number" => j,
157163
"method" => "$curr_sampler",
158164
"batch_size_prop" => n/N,

examples/atomistic/utils/fitting-utils.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
function get_confs(path, inds)
1+
function get_confs(path, read_element, inds)
22
confs = []
3-
ch, N = chunk_iterator(train_path; chunksize=1000, randomized=false)
3+
ch, N = chunk_iterator(path;
4+
read_element=read_element,
5+
chunksize=1000,
6+
randomized=false)
47
k = 1
58
for (c, ci) in ch
69
j = 1

examples/atomistic/utils/plot-err-per-sample.jl

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ function plot_err_per_sample(res_path, metrics_filename)
44
sort!(df, [:batch_size])
55

66
srs = filter(:method => ==("srs"), df)
7-
sme = filter(:method => ==("sme"), df)
7+
ss = filter(:method => ==("sws"), df)
88

99
# ---------------- Percent formatting (round UP, fixed) ----------------
1010
# ≥ 1% -> ceil to integer (no decimals)
@@ -24,9 +24,9 @@ function plot_err_per_sample(res_path, metrics_filename)
2424
end
2525

2626
# ---------------- X tick labels ----------------
27-
xs = sme.batch_size
27+
xs = ss.batch_size
2828
xtick_labels = [string(bs, "\n", format_percent_roundup(prop))
29-
for (bs, prop) in zip(sme.batch_size, sme.batch_size_prop)]
29+
for (bs, prop) in zip(ss.batch_size, ss.batch_size_prop)]
3030

3131
# ---------------- Colors ----------------
3232
black = RGB(0,0,0)
@@ -69,14 +69,14 @@ function plot_err_per_sample(res_path, metrics_filename)
6969
)
7070

7171
pE_bottom = plot(
72-
sme.batch_size, sme.e_test_mae;
72+
ss.batch_size, ss.e_test_mae;
7373
color = red, lw = 5.5, marker = :utriangle,
7474
xlabel = "Training Dataset Size (Sample Size)",
7575
ylabel = "E MAE | eV/atom",
76-
label = "SME",
76+
label = "SWS",
7777
xticks = (xs, xtick_labels),
7878
legend = :topright,
79-
ylims = padlims(sme.e_test_mae),
79+
ylims = padlims(ss.e_test_mae),
8080
)
8181

8282
energy_plot = plot(pE_top, pE_bottom; layout=(2,1), size=(1100,1100))
@@ -93,22 +93,18 @@ function plot_err_per_sample(res_path, metrics_filename)
9393
)
9494

9595
pF_bottom = plot(
96-
sme.batch_size, sme.f_test_mae;
96+
ss.batch_size, ss.f_test_mae;
9797
color = red, lw = 5.5, marker = :utriangle,
9898
xlabel = "Training Dataset Size (Sample Size)",
9999
ylabel = "F MAE | eV/Å",
100-
label = "SME",
100+
label = "SWS",
101101
xticks = (xs, xtick_labels),
102102
legend = :topright,
103-
ylims = padlims(SME.f_test_mae),
103+
ylims = padlims(ss.f_test_mae),
104104
)
105105

106106
force_plot = plot(pF_top, pF_bottom; layout=(2,1), size=(1100,1100))
107107
savefig(force_plot, "$res_path/f_test_mae_by_sample.pdf")
108-
109-
println("✅ Saved:")
110-
println(" - e_test_mae_by_sample.pdf")
111-
println(" - f_test_mae_by_sample.pdf")
112108
end
113109

114110
function plot_err_per_sample_2(res_path, metrics_filename)

examples/atomistic/utils/subtract-peratom-e.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ function subtract_peratom_e(config::Configuration, vref_dict)
1010
Energy(new_e,e_unit)
1111
end
1212

13-
function adjust_energies(ds, vref_dict)
13+
function adjust_energies!(ds, vref_dict)
1414
for config in ds
1515
new_energy = subtract_peratom_e(config,vref_dict)
1616
config.data[Energy] = new_energy

examples/atomistic/utils/utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using DataFrames
55
using DelimitedFiles
66
using Determinantal
77
using InteratomicPotentials
8+
using LinearAlgebra
89
using LowRankApprox
910
using Measures
1011
using OrderedCollections

src/Weights.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ function compute_chunk_weights(features::Matrix{Float64})
171171
# Form an L-ensemble based on the kernel matrix K
172172
dpp = EllEnsemble(K)
173173
# Scale so that the expected size is 1
174-
rescale!(dpp, 1)
174+
rescale!(dpp, N ÷ 2)
175175
# Compute inclusion probabilities.
176176
inclusion_probs = Determinantal.inclusion_prob(dpp)
177177
return inclusion_probs

0 commit comments

Comments
 (0)