Skip to content

Commit 6ab4e1c

Browse files
committed
tweak(vrf): Optimize to_group by avoiding the recomputation of constants
1 parent f5c17bb commit 6ab4e1c

File tree

1 file changed

+49
-29
lines changed

1 file changed

+49
-29
lines changed

vrf/src/message.rs

Lines changed: 49 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::OnceLock;
2+
13
use ark_ff::{One, SquareRootField, Zero};
24

35
use ledger::{proofs::transaction::legacy_input::to_bits, ToInputs};
@@ -18,6 +20,25 @@ pub struct VrfMessage {
1820
delegator_index: u64,
1921
}
2022

23+
struct CachedFields {
24+
two: BaseField,
25+
three: BaseField,
26+
five: BaseField,
27+
projection_point_z: BaseField,
28+
}
29+
30+
static CACHED_FIELDS: OnceLock<CachedFields> = OnceLock::new();
31+
32+
#[inline(always)]
33+
fn get_y(x: BaseField, five: BaseField) -> Option<BaseField> {
34+
let mut res = x;
35+
res *= &x; // x^2
36+
res += BaseField::zero(); // x^2 + A x
37+
res *= &x; // x^3 + A x
38+
res += five; // x^3 + A x + B
39+
res.sqrt()
40+
}
41+
2142
impl VrfMessage {
2243
pub fn new(global_slot: u32, epoch_seed: EpochSeed, delegator_index: u64) -> Self {
2344
Self {
@@ -32,27 +53,40 @@ impl VrfMessage {
3253
}
3354

3455
pub fn to_group(&self) -> VrfResult<CurvePoint> {
35-
// helpers
36-
let two = BaseField::one() + BaseField::one();
37-
let three = two + BaseField::one();
38-
39-
// params, according to ocaml
40-
let mut projection_point_z_bytes =
41-
hex::decode("1AF731EC3CA2D77CC5D13EDC8C9A0A77978CB5F4FBFCC470B5983F5B6336DB69")?;
42-
projection_point_z_bytes.reverse();
43-
let projection_point_z = BaseField::from_bytes(&projection_point_z_bytes)?;
56+
let cached = CACHED_FIELDS.get_or_init(|| {
57+
let one = BaseField::one();
58+
let two = one + one;
59+
let three = two + one;
60+
let five = three + two;
61+
62+
// according to ocaml
63+
let mut projection_point_z_bytes =
64+
hex::decode("1AF731EC3CA2D77CC5D13EDC8C9A0A77978CB5F4FBFCC470B5983F5B6336DB69")
65+
.expect("Failed to decode hex string");
66+
projection_point_z_bytes.reverse();
67+
let projection_point_z = BaseField::from_bytes(&projection_point_z_bytes)
68+
.expect("Failed to convert bytes to BaseField");
69+
70+
CachedFields {
71+
two,
72+
three,
73+
five,
74+
projection_point_z,
75+
}
76+
});
77+
4478
let projection_point_y = BaseField::one();
45-
let conic_c = three;
79+
let conic_c = cached.three;
4680
let u_over_2 = BaseField::one();
47-
let u = two;
81+
let u = cached.two;
4882

4983
let t = self.hash();
5084

5185
// field to conic
5286
let ct = conic_c * t;
53-
let s =
54-
two * ((ct * projection_point_y) + projection_point_z) / ((ct * t) + BaseField::one());
55-
let conic_z = projection_point_z - s;
87+
let s = cached.two * ((ct * projection_point_y) + cached.projection_point_z)
88+
/ ((ct * t) + BaseField::one());
89+
let conic_z = cached.projection_point_z - s;
5690
let conic_y = projection_point_y - (s * t);
5791

5892
// conic to s
@@ -64,22 +98,8 @@ impl VrfMessage {
6498
let x2 = -(u + v);
6599
let x3 = u + (y * y);
66100

67-
let get_y = |x: BaseField| -> Option<BaseField> {
68-
let five = BaseField::one()
69-
+ BaseField::one()
70-
+ BaseField::one()
71-
+ BaseField::one()
72-
+ BaseField::one();
73-
let mut res = x;
74-
res *= &x; // x^2
75-
res += BaseField::zero(); // x^2 + A x
76-
res *= &x; // x^3 + A x
77-
res += five; // x^3 + A x + B
78-
res.sqrt()
79-
};
80-
81101
for x in [x1, x2, x3] {
82-
if let Some(y) = get_y(x) {
102+
if let Some(y) = get_y(x, cached.five) {
83103
return Ok(CurvePoint::new(x, y, false));
84104
}
85105
}

0 commit comments

Comments
 (0)