Skip to content

Commit f02113f

Browse files
authored
update HNSW to use scale_modification
Add scale_modification in HNSW to have HubNSW
1 parent c4b9b1e commit f02113f

File tree

1 file changed

+39
-24
lines changed

1 file changed

+39
-24
lines changed

src/bin/annembed.rs

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,22 @@
77
//!
88
//! hnsw is an optional subcommand to change default parameters of the Hnsw structure. See [hnsw_rs](https://crates.io/crates/hnsw_rs).
99
//!
10-
//! - Parameters for embedding.
11-
//! The options are optional and give access to some fields of the [EmbedderParams] structure.
12-
//!
13-
//! --batch : optional, a integer giving the number of batch to run. Default to 15.
14-
//! --stepg : optional, a float value , initial gradient step, default is 2.
15-
//! --scale : optional, a float value, scale modification factor, default is 1.
16-
//! --nbsample : optional, a number of edge sampling , default is 10
17-
//! --layer : optional, in case of hierarchical embedding num of the lower layer we consider to run preliminary step.
18-
//! default is set to 0 meaning one pass embedding.
19-
//! --dim : optional, dimension of the embedding , default to 2.
20-
//!
21-
//! --quality : optional, asks for quality estimation.
22-
//! --sampling : optional, for large data defines the fraction of sampled data as 1./sampling
10+
//! - Parameters for the embedding part are all optional. The options give access to some fields of the [EmbedderParams] structure.
11+
//! --batch : optional, an integer giving the number of batches to run. Default to 15.
12+
//! --stepg : optional, a float value , initial gradient step, default is 2.
13+
//! --scale : optional, a float value, scale modification factor, default is 1.
14+
//! --nbsample : optional, a number of edge sampling , default is 10
15+
//! --layer : optional, in case of hierarchical embedding num of the lower layer we consider to run preliminary step.
16+
//! default is set to 0 meaning one pass embedding
17+
//! --dim : optional, dimension of the embedding , default to 2.
18+
//! --quality : optional, asks for quality estimation
19+
//! --sampling : optional, for large data defines the fraction of sampled data as 1./sampling
2320
//!
2421
//! - Parameters for the hnsw subcommand. For more details see [hnsw_rs](https://crates.io/crates/hnsw_rs).
25-
//! --nbconn : defines the number of connections by node in a layer. Can range from 4 to 64 or more if necessary and enough memory.
26-
//! --dist : name of distance to use: "DistL1", "DistL2", "DistCosine", "DistJeyffreys".
27-
//! --ef : controls the with of the search, a good guess is between 24 and 64 or more if necessary.
28-
//! --knbn : the number of nodes to use in retrieval requests.
22+
//! --nbconn : defines the number of connections by node in a layer. Can range from 4 to 64 or more if necessary and enough memory
23+
//! --dist : name of distance to use: "DistL1", "DistL2", "DistCosine", "DistJeffreys"
24+
//! --ef : controls the width of the search, a good guess is between 24 and 64 or more if necessary
25+
//! --knbn : the number of nodes to use in retrieval requests.
2926
//!
3027
//! The csv file must have one record by vector to embed. The default delimiter is ','.
3128
//! The output is a csv file with embedded vectors.
@@ -55,25 +52,29 @@ pub struct HnswParams {
5552
knbn: usize,
5653
/// distance to use in Hnsw. Default is "DistL2". Other choices are "DistL1", "DistCosine", DistJeffreys
5754
distance: String,
55+
// scale_modification factor, must be in [0.2, 1]
56+
scale_modification : f64,
5857
} // end of struct HnswParams
5958

6059
impl HnswParams {
61-
pub fn my_default() -> Self {
60+
pub fn default() -> Self {
6261
HnswParams {
6362
max_conn: 48,
6463
ef_c: 400,
6564
knbn: 10,
6665
distance: String::from("DistL2"),
66+
scale_modification: 1.0,
6767
}
6868
}
6969

7070
#[allow(unused)]
71-
pub fn new(max_conn: usize, ef_c: usize, knbn: usize, distance: String) -> Self {
71+
pub fn new(max_conn: usize, ef_c: usize, knbn: usize, distance: String, scale_modification: f64) -> Self {
7272
HnswParams {
7373
max_conn,
7474
ef_c,
7575
knbn,
7676
distance,
77+
scale_modification,
7778
}
7879
}
7980
} // end impl block
@@ -98,10 +99,11 @@ impl Default for QualityParams {
9899
fn parse_hnsw_cmd(matches: &ArgMatches) -> Result<HnswParams, anyhow::Error> {
99100
log::debug!("in parse_hnsw_cmd");
100101

101-
let mut hnswparams = HnswParams::my_default();
102+
let mut hnswparams = HnswParams::default();
102103
hnswparams.max_conn = *matches.get_one::<usize>("nbconn").unwrap();
103104
hnswparams.ef_c = *matches.get_one::<usize>("ef").unwrap();
104105
hnswparams.knbn = *matches.get_one::<usize>("knbn").unwrap();
106+
hnswparams.scale_modification = *matches.get_one::<f64>("scale_modification").unwrap();
105107

106108
match matches.get_one::<String>("dist") {
107109
Some(str) => match str.as_str() {
@@ -169,6 +171,7 @@ pub fn main() {
169171
let embedparams: EmbedderParams;
170172
//
171173
let hnswcmd = Command::new("hnsw")
174+
.about("Build HNSW graph")
172175
.arg(Arg::new("dist")
173176
.long("dist")
174177
.short('d')
@@ -193,15 +196,25 @@ pub fn main() {
193196
.required(true)
194197
.action(ArgAction::Set)
195198
.value_parser(clap::value_parser!(usize))
196-
.help("search factor"));
199+
.help("search factor"))
200+
.arg(Arg::new("scale_modification")
201+
.long("scale_modify_f")
202+
.help("scale modification factor in HNSW or HubNSW, must be in [0.2,1]")
203+
.value_name("scale_modify")
204+
.default_value("1.0")
205+
.action(ArgAction::Set)
206+
.value_parser(clap::value_parser!(f64))
207+
);
197208

198209
//
199210
// Now the command line
200211
// ===================
201212
//
202213
let matches = Command::new("annembed")
203214
// .subcommand_required(true)
215+
.version("0.1.7")
204216
.arg_required_else_help(true)
217+
.about("Non-linear Dimension Reduction/Embedding via Approximate Nearest Neighbor Graph")
205218
.arg(
206219
Arg::new("csvfile")
207220
.long("csv")
@@ -311,7 +324,7 @@ pub fn main() {
311324
}
312325
}
313326
} else {
314-
hnswparams = HnswParams::my_default();
327+
hnswparams = HnswParams::default();
315328
}
316329
log::debug!("hnswparams : {:?}", hnswparams);
317330

@@ -433,13 +446,14 @@ where
433446
{
434447
//
435448
let nb_data = data_with_id.len();
436-
let hnsw = Hnsw::<f64, Dist>::new(
449+
let mut hnsw = Hnsw::<f64, Dist>::new(
437450
hnswparams.max_conn,
438451
nb_data,
439452
nb_layer,
440453
hnswparams.ef_c,
441454
Dist::default(),
442455
);
456+
hnsw.modify_level_scale(hnswparams.scale_modification);
443457
hnsw.parallel_insert(data_with_id);
444458
hnsw.dump_layer_info();
445459
let kgraph = kgraph_from_hnsw_all(&hnsw, hnswparams.knbn).unwrap();
@@ -496,13 +510,14 @@ where
496510
{
497511
//
498512
let nb_data = data_with_id.len();
499-
let hnsw = Hnsw::<f64, Dist>::new(
513+
let mut hnsw = Hnsw::<f64, Dist>::new(
500514
hnswparams.max_conn,
501515
nb_data,
502516
nb_layer,
503517
hnswparams.ef_c,
504518
Dist::default(),
505519
);
520+
hnsw.modify_level_scale(hnswparams.scale_modification);
506521
hnsw.parallel_insert(data_with_id);
507522
hnsw.dump_layer_info();
508523
KGraphProjection::<f64>::new(&hnsw, hnswparams.knbn, layer_proj)

0 commit comments

Comments
 (0)