Skip to content

Commit 1344893

Browse files
committed
update of hfo2 example and small changes of weight interface
1 parent b755848 commit 1344893

File tree

2 files changed

+15
-19
lines changed

2 files changed

+15
-19
lines changed

examples/atomistic/srs-vs-sme-hfo2.jl

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
using Pkg
2-
Pkg.develop(path="../../")
1+
#using Pkg
2+
#Pkg.develop(path="../../")
33

44
using StreamingSampling
55

@@ -129,7 +129,7 @@ vref_dict = Dict(:Hf => avg_energy_per_atom,
129129
:O => avg_energy_per_atom)
130130

131131
#Adjust reference energies (permanent change)
132-
adjust_energies(ds_train_rnd,vref_dict)
132+
adjust_energies!(ds_train_rnd, vref_dict)
133133

134134
# Define basis for fitting
135135
basis_fitting = ACE(species = [:Hf, :O],
@@ -139,28 +139,26 @@ basis_fitting = ACE(species = [:Hf, :O],
139139
csp = 1.0,
140140
r0 = 1.43,
141141
rcutoff = 4.4 );
142-
calc_descr!(ds_train_rnd, basis_fitting)
143-
calc_descr!(ds_test_rnd, basis_fitting)
142+
ds_train_rnd = calc_descr!(ds_train_rnd, basis_fitting)
143+
ds_test_rnd = calc_descr!(ds_test_rnd, basis_fitting)
144144

145145
# Initialize streaming sampling ################################################
146-
read_conf(x::Configuration) = x
147-
basis = ACE(species = [:C, :O, :H],
146+
basis = ACE(species = [:Hf, :O],
148147
body_order = 4,
149148
polynomial_degree = 8,
150149
wL = 2.0,
151150
csp = 1.0,
152151
r0 = 1.43,
153152
rcutoff = 4.4 );
154-
function create_feature(element::Vector; basis=basis)
155-
system = element[1]
153+
function create_feature(element::Configuration; basis=basis)
154+
system = get_system(element)
156155
feature = sum(compute_local_descriptors(system, basis))
157156
return feature
158157
end
159158
ws = compute_weights(ds_train_rnd.Configurations;
160-
read_element=read_element,
161159
create_feature=create_feature,
162160
chunksize=2000,
163-
subchunksize=200)
161+
subchunksize=4)
164162
open("ws-hfo2.jls", "w") do io
165163
serialize(io, ws)
166164
flush(io)
@@ -190,12 +188,12 @@ metrics = DataFrame([Any[] for _ in 1:length(metric_names)], metric_names)
190188
for j in 1:n_experiments
191189
println("Experiment $j")
192190

193-
global metrics
191+
global metrics, ds_train_rnd, ds_test_rnd
194192

195193
for n in sample_sizes
196194
# Sample training dataset using streaming weighted sampling ############
197195
train_inds = StatsBase.sample(1:length(ws), Weights(ws), n;
198-
replace=false, ordered=true))
196+
replace=false, ordered=true)
199197
#Load atomistic configurations
200198
train_ds = @views ds_train_rnd[train_inds]
201199
# Create result folder

src/Weights.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,14 @@ function compute_weights(file_paths::Vector{String};
3030
end
3131

3232
function compute_weights(A::Vector;
33-
read_element=read_element,
3433
create_feature=create_feature,
35-
chunksize=1000,
36-
subchunksize=100,
34+
chunksize=2000,
35+
subchunksize=200,
3736
buffersize=32,
3837
max=Inf,
3938
randomized=true,
4039
normalize=true)
4140
ch, N = chunk_iterator(A;
42-
read_element=read_element,
4341
chunksize=subchunksize,
4442
buffersize=buffersize,
4543
randomized=randomized)
@@ -60,8 +58,8 @@ end
6058

6159
function compute_weights(ch::Channel;
6260
create_feature=create_feature,
63-
chunksize=1000,
64-
subchunksize=100,
61+
chunksize=2000,
62+
subchunksize=200,
6563
max=Inf)
6664
# Step 1: Setup stage ######################################################
6765
@printf("Computing sampler weights...\n")

0 commit comments

Comments
 (0)