1- using Pkg
2- Pkg. develop(path= " ../../" )
1+ # using Pkg
2+ # Pkg.develop(path="../../")
33
44using StreamingSampling
55
@@ -129,7 +129,7 @@ vref_dict = Dict(:Hf => avg_energy_per_atom,
129129 :O => avg_energy_per_atom)
130130
131131# Adjust reference energies (permanent change)
132- adjust_energies(ds_train_rnd,vref_dict)
132+ adjust_energies! (ds_train_rnd, vref_dict)
133133
134134# Define basis for fitting
135135basis_fitting = ACE(species = [:Hf, :O],
@@ -139,28 +139,26 @@ basis_fitting = ACE(species = [:Hf, :O],
139139 csp = 1.0 ,
140140 r0 = 1.43 ,
141141 rcutoff = 4.4 );
142- calc_descr!(ds_train_rnd, basis_fitting)
143- calc_descr!(ds_test_rnd, basis_fitting)
142+ ds_train_rnd = calc_descr!(ds_train_rnd, basis_fitting)
143+ ds_test_rnd = calc_descr!(ds_test_rnd, basis_fitting)
144144
145145# Initialize streaming sampling ################################################
146- read_conf(x:: Configuration ) = x
147- basis = ACE(species = [:C, :O, :H],
146+ basis = ACE(species = [:Hf, :O],
148147 body_order = 4 ,
149148 polynomial_degree = 8 ,
150149 wL = 2.0 ,
151150 csp = 1.0 ,
152151 r0 = 1.43 ,
153152 rcutoff = 4.4 );
154- function create_feature(element:: Vector ; basis= basis)
155- system = element[ 1 ]
153+ function create_feature(element:: Configuration ; basis= basis)
154+ system = get_system( element)
156155 feature = sum(compute_local_descriptors(system, basis))
157156 return feature
158157end
159158ws = compute_weights(ds_train_rnd. Configurations;
160- read_element= read_element,
161159 create_feature= create_feature,
162160 chunksize= 2000 ,
163- subchunksize= 200 )
161+ subchunksize= 4 )
164162open(" ws-hfo2.jls" , " w" ) do io
165163 serialize(io, ws)
166164 flush(io)
@@ -190,12 +188,12 @@ metrics = DataFrame([Any[] for _ in 1:length(metric_names)], metric_names)
190188for j in 1 : n_experiments
191189 println(" Experiment $j " )
192190
193- global metrics
191+ global metrics, ds_train_rnd, ds_test_rnd
194192
195193 for n in sample_sizes
196194 # Sample training dataset using streaming weighted sampling ############
197195 train_inds = StatsBase. sample(1 : length(ws), Weights(ws), n;
198- replace= false , ordered= true ))
196+ replace= false , ordered= true )
199197 # Load atomistic configurations
200198 train_ds = @views ds_train_rnd[train_inds]
201199 # Create result folder
0 commit comments