Skip to content

Commit f5599b7

Browse files
committed
SIMD-284: alt_bn128 little endian mode
1 parent 20ffa22 commit f5599b7

File tree

8 files changed

+533
-450
lines changed

8 files changed

+533
-450
lines changed

src/core/features.zon

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,4 +252,5 @@
252252
.{ .name = "fix_alt_bn128_pairing_length_check", .pubkey = "bnYzodLwmybj7e1HAe98yZrdJTd7we69eMMLgCXqKZm" },
253253
.{ .name = "increase_cpi_account_info_limit", .pubkey = "H6iVbVaDZgDphcPbcZwc5LoznMPWQfnJ1AM7L1xzqvt5" },
254254
.{ .name = "vote_state_v4", .pubkey = "Gx4XFcrVMt4HUvPzTpTSVkdDVgcDSjKhDN1RqRS6KDuZ" },
255+
.{ .name = "alt_bn128_little_endian", .pubkey = "bnS3pWfLrxHRJvMyLm6EaYQkP7A2Fe9DxoKv4aGA8YM" },
255256
}

src/crypto/bn254/fields.zig

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,25 @@ pub const Fp = struct {
8686

8787
pub fn fromBytes(
8888
input: *const [32]u8,
89+
endian: std.builtin.Endian,
8990
maybe_flags: ?*Flags,
9091
) !Fp {
9192
if (maybe_flags) |flags| {
92-
flags.* = @bitCast(input[0]);
93+
const offset: u32 = switch (endian) {
94+
.big => 0,
95+
.little => 31,
96+
};
97+
flags.* = @bitCast(input[offset]);
9398
// If both flags are set, return an error.
9499
// https://github.com/arkworks-rs/algebra/blob/v0.4.2/ec/src/models/short_weierstrass/serialization_flags.rs#L75
95100
if (flags.is_inf and flags.is_neg) return error.BothFlags;
96101
}
97102

98-
var limbs: [32]u8 = byteSwap(input.*);
103+
var limbs: [32]u8 = switch (endian) {
104+
.big => byteSwap(input.*),
105+
.little => input.*,
106+
};
107+
// NOTE: Both branches above produce little-endian limbs, so the flag bits always end up in `limbs[31]`; masking *after* the conversion means we don't need to select the offset again.
99108
if (maybe_flags != null) limbs[31] &= Flags.MASK;
100109

101110
// Check that we've decoded a valid field element.
@@ -105,8 +114,11 @@ pub const Fp = struct {
105114
return .{ .limbs = @bitCast(limbs) };
106115
}
107116

108-
pub fn toBytes(f: Fp, out: *[32]u8) void {
109-
out.* = byteSwap(@bitCast(f.limbs));
117+
pub fn toBytes(f: Fp, out: *[32]u8, endian: std.builtin.Endian) void {
118+
out.* = switch (endian) {
119+
.little => @bitCast(f.limbs),
120+
.big => byteSwap(@bitCast(f.limbs)),
121+
};
110122
}
111123

112124
pub fn byteSwap(a: [32]u8) [32]u8 {
@@ -120,6 +132,7 @@ pub const Fp = struct {
120132
return @bitCast(array);
121133
}
122134

135+
/// Well-defined for both montgomery and normal form.
123136
pub fn isZero(f: Fp) bool {
124137
return f.eql(.zero);
125138
}
@@ -384,16 +397,24 @@ pub const Fp2 = struct {
384397
} };
385398
};
386399

387-
pub fn fromBytes(input: *const [64]u8, maybe_flags: ?*Flags) !Fp2 {
400+
pub fn fromBytes(input: *const [64]u8, endian: std.builtin.Endian, maybe_flags: ?*Flags) !Fp2 {
401+
const el0: u32, const el1: u32 = switch (endian) {
402+
.little => .{ 0, 32 },
403+
.big => .{ 32, 0 },
404+
};
388405
return .{
389-
.c0 = try .fromBytes(input[32..64], null),
390-
.c1 = try .fromBytes(input[0..32], maybe_flags),
406+
.c0 = try .fromBytes(input[el0..][0..32], endian, null),
407+
.c1 = try .fromBytes(input[el1..][0..32], endian, maybe_flags),
391408
};
392409
}
393410

394-
pub fn toBytes(f: Fp2, out: *[64]u8) void {
395-
f.c0.toBytes(out[32..64]);
396-
f.c1.toBytes(out[0..32]);
411+
pub fn toBytes(f: Fp2, out: *[64]u8, endian: std.builtin.Endian) void {
412+
const el0: u32, const el1: u32 = switch (endian) {
413+
.little => .{ 0, 32 },
414+
.big => .{ 32, 0 },
415+
};
416+
f.c0.toBytes(out[el0..][0..32], endian);
417+
f.c1.toBytes(out[el1..][0..32], endian);
397418
}
398419

399420
pub fn isZero(f: Fp2) bool {

src/crypto/bn254/lib.zig

Lines changed: 63 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,12 @@ pub const G1 = struct {
2626
.z = .zero,
2727
};
2828

29-
fn fromBytesInternal(input: *const [64]u8) !G1 {
29+
fn fromBytesInternal(input: *const [64]u8, endian: std.builtin.Endian) !G1 {
3030
if (std.mem.allEqual(u8, input, 0)) return .zero;
31-
3231
var flags: Flags = undefined;
3332
return .{
34-
.x = try .fromBytes(input[0..32], null),
35-
.y = try .fromBytes(input[32..64], &flags),
33+
.x = try .fromBytes(input[0..32], endian, null),
34+
.y = try .fromBytes(input[32..64], endian, &flags),
3635
.z = if (flags.is_inf) .zero else .one,
3736
};
3837
}
@@ -49,8 +48,8 @@ pub const G1 = struct {
4948
// G1 has prime order so we do not need a subgroup membership check.
5049
}
5150

52-
pub fn fromBytes(input: *const [64]u8) !G1 {
53-
var g1 = try fromBytesInternal(input);
51+
pub fn fromBytes(input: *const [64]u8, endian: std.builtin.Endian) !G1 {
52+
var g1 = try fromBytesInternal(input, endian);
5453
if (g1.isZero()) return g1;
5554

5655
g1.x.toMont();
@@ -62,19 +61,17 @@ pub const G1 = struct {
6261
return g1;
6362
}
6463

65-
fn toBytes(p: G1, out: *[64]u8) void {
64+
fn toBytes(p: G1, out: *[64]u8, endian: std.builtin.Endian) void {
6665
if (p.isZero()) {
67-
// no flags
68-
@memset(out, 0);
66+
@memset(out, 0); // no flags
6967
return;
7068
}
7169

7270
var r = p.toAffine();
7371
r.x.fromMont();
7472
r.y.fromMont();
75-
76-
r.x.toBytes(out[0..32]);
77-
r.y.toBytes(out[32..64]);
73+
r.x.toBytes(out[0..32], endian);
74+
r.y.toBytes(out[32..64], endian);
7875
}
7976

8077
fn isZero(p: G1) bool {
@@ -99,8 +96,8 @@ pub const G1 = struct {
9996
};
10097
}
10198

102-
pub fn compress(out: *[32]u8, input: *const [64]u8) !void {
103-
const p: G1 = try .fromBytesInternal(input);
99+
pub fn compress(out: *[32]u8, input: *const [64]u8, endian: std.builtin.Endian) !void {
100+
const p: G1 = try .fromBytesInternal(input, endian);
104101

105102
const is_inf = p.isZero();
106103
const flag_inf = input[32] & Flags.INF;
@@ -115,33 +112,37 @@ pub const G1 = struct {
115112

116113
const is_neg = p.y.isNegative();
117114
@memcpy(out, input[0..32]);
118-
if (is_neg) out[0] |= Flags.NEG;
115+
const offset: u32 = switch (endian) {
116+
.little => 31,
117+
.big => 0,
118+
};
119+
if (is_neg) out[offset] |= Flags.NEG;
119120
return;
120121
}
121122

122-
pub fn decompress(out: *[64]u8, input: *const [32]u8) !void {
123+
pub fn decompress(out: *[64]u8, input: *const [32]u8, endian: std.builtin.Endian) !void {
123124
// All zeroes input, all zeroes out, no flags.
124125
if (std.mem.allEqual(u8, input, 0)) return @memset(out, 0);
125126

126127
var flags: Flags = undefined;
127-
var x: Fp = try .fromBytes(input, &flags);
128+
const x: Fp = try .fromBytes(input, endian, &flags);
128129

129130
// If the point at infinity flag is set, return the point at infinity without any
130131
// checks on the coordinates (X, Y) and no flags set.
131132
if (flags.is_inf) return @memset(out, 0);
132133

134+
var xm = x;
135+
xm.toMont();
133136
// y^2 = x^3+b
134-
x.toMont();
135-
const x3b = x.sq().mul(x).add(Fp.constants.b_mont);
137+
const x3b = xm.sq().mul(xm).add(Fp.constants.b_mont);
136138
var y = try x3b.sqrt();
137139
y.fromMont();
138140
if (flags.is_neg != y.isNegative()) {
139141
y.negateNotMontgomery(y); // correct the sign to the requested one
140142
}
141143

142-
@memcpy(out[0..32], input);
143-
out[0] &= Flags.MASK;
144-
y.toBytes(out[32..64]);
144+
x.toBytes(out[0..32], endian);
145+
y.toBytes(out[32..64], endian);
145146
// no flags on y
146147
}
147148

@@ -198,13 +199,13 @@ pub const G2 = struct {
198199
.z = .zero,
199200
};
200201

201-
fn fromBytesInternal(input: *const [128]u8) !G2 {
202+
fn fromBytesInternal(input: *const [128]u8, endian: std.builtin.Endian) !G2 {
202203
if (std.mem.allEqual(u8, input, 0)) return .zero;
203204

204205
var flags: Flags = undefined;
205206
return .{
206-
.x = try .fromBytes(input[0..64], null),
207-
.y = try .fromBytes(input[64..128], &flags),
207+
.x = try .fromBytes(input[0..64], endian, null),
208+
.y = try .fromBytes(input[64..128], endian, &flags),
208209
.z = if (flags.is_inf) .zero else .one,
209210
};
210211
}
@@ -232,8 +233,8 @@ pub const G2 = struct {
232233
if (!l.eql(r)) return error.NotWellFormed;
233234
}
234235

235-
fn fromBytes(input: *const [128]u8) !G2 {
236-
var g2: G2 = try .fromBytesInternal(input);
236+
fn fromBytes(input: *const [128]u8, endian: std.builtin.Endian) !G2 {
237+
var g2: G2 = try .fromBytesInternal(input, endian);
237238
if (g2.isZero()) return g2;
238239

239240
g2.x.toMont();
@@ -249,8 +250,8 @@ pub const G2 = struct {
249250
return p.z.isZero();
250251
}
251252

252-
pub fn compress(out: *[64]u8, input: *const [128]u8) !void {
253-
const p: G2 = try .fromBytesInternal(input);
253+
pub fn compress(out: *[64]u8, input: *const [128]u8, endian: std.builtin.Endian) !void {
254+
const p: G2 = try .fromBytesInternal(input, endian);
254255

255256
const is_inf = p.isZero();
256257
const flag_inf = input[64] & Flags.INF;
@@ -263,33 +264,37 @@ pub const G2 = struct {
263264
}
264265

265266
const is_neg = p.y.isNegative();
266-
@memcpy(out, input[0..64]);
267-
if (is_neg) out[0] |= Flags.NEG;
267+
p.x.toBytes(out, endian);
268+
const offset: u32 = switch (endian) {
269+
.little => 63,
270+
.big => 0,
271+
};
272+
if (is_neg) out[offset] |= Flags.NEG;
268273
return;
269274
}
270275

271-
pub fn decompress(out: *[128]u8, input: *const [64]u8) !void {
276+
pub fn decompress(out: *[128]u8, input: *const [64]u8, endian: std.builtin.Endian) !void {
272277
if (std.mem.allEqual(u8, input, 0)) return @memset(out, 0);
273278

274279
var flags: Flags = undefined;
275-
var x: Fp2 = try .fromBytes(input, &flags);
280+
const x: Fp2 = try .fromBytes(input, endian, &flags);
276281

277282
// no flags
278283
if (flags.is_inf) return @memset(out, 0);
279284

280285
// y^2 = x^3+b
281-
x.toMont();
282-
const x3b = x.sq().mul(x).add(Fp2.constants.twist_b_mont);
286+
var xm = x;
287+
xm.toMont();
288+
const x3b = xm.sq().mul(xm).add(Fp2.constants.twist_b_mont);
283289
var y = try x3b.sqrt();
284290

285291
y.fromMont();
286292
if (flags.is_neg != y.isNegative()) {
287293
y.negateNotMontgomery(y);
288294
}
289295

290-
@memcpy(out[0..64], input);
291-
out[0] &= Flags.MASK;
292-
y.toBytes(out[64..128]);
296+
x.toBytes(out[0..64], endian);
297+
y.toBytes(out[64..128], endian);
293298
}
294299

295300
fn eql(a: G2, b: G2) bool {
@@ -479,31 +484,34 @@ fn mulScalar(a: anytype, scalar: u256) @TypeOf(a) {
479484
return r;
480485
}
481486

482-
pub fn addSyscall(out: *[64]u8, input: *const [128]u8) !void {
483-
const x: G1 = try .fromBytes(input[0..64]);
484-
const y: G1 = try .fromBytes(input[64..128]);
487+
pub fn addSyscall(out: *[64]u8, input: *const [128]u8, endian: std.builtin.Endian) !void {
488+
const x: G1 = try .fromBytes(input[0..64], endian);
489+
const y: G1 = try .fromBytes(input[64..128], endian);
485490
const result = x.affineAdd(y);
486-
result.toBytes(out);
491+
result.toBytes(out, endian);
487492
}
488493

489-
pub fn mulSyscall(out: *[64]u8, input: *const [96]u8) !void {
490-
const a: G1 = try .fromBytes(input[0..64]);
494+
pub fn mulSyscall(out: *[64]u8, input: *const [96]u8, endian: std.builtin.Endian) !void {
495+
const a: G1 = try .fromBytes(input[0..64], endian);
491496
// Scalar is provided in the byte order selected by `endian` and we do *not* validate it.
492-
const b: u256 = @bitCast(Fp.byteSwap(input[64..][0..32].*));
497+
const b: u256 = @bitCast(switch (endian) {
498+
.big => Fp.byteSwap(input[64..][0..32].*),
499+
.little => input[64..][0..32].*,
500+
});
493501
const result = mulScalar(a, b);
494-
result.toBytes(out);
502+
result.toBytes(out, endian);
495503
}
496504

497-
pub fn pairingSyscall(out: *[32]u8, input: []const u8) !void {
505+
pub fn pairingSyscall(out: *[32]u8, input: []const u8, endian: std.builtin.Endian) !void {
498506
const num_elements = input.len / 192;
499507

500508
var p: std.BoundedArray(G1, pairing.BATCH_SIZE) = .{};
501509
var q: std.BoundedArray(G2, pairing.BATCH_SIZE) = .{};
502510

503511
var r: Fp12 = .one;
504512
for (0..num_elements) |i| {
505-
const a: G1 = try .fromBytes(input[i * 192 ..][0..64]);
506-
const b: G2 = try .fromBytes(input[i * 192 ..][64..][0..128]);
513+
const a: G1 = try .fromBytes(input[i * 192 ..][0..64], endian);
514+
const b: G2 = try .fromBytes(input[i * 192 ..][64..][0..128], endian);
507515

508516
// Skip any pair where either A or B are points at infinity.
509517
if (a.isZero() or b.isZero()) continue;
@@ -521,7 +529,11 @@ pub fn pairingSyscall(out: *[32]u8, input: []const u8) !void {
521529
}
522530

523531
r = pairing.finalExp(r);
524-
// Output is 0 or 1 as a big-endian u256.
532+
// Output is 0 or 1 as a u256, encoded in the byte order selected by `endian`.
525533
@memset(out, 0);
526-
if (r.isOne()) out[31] = 1;
534+
const offset: u32 = switch (endian) {
535+
.little => 0,
536+
.big => 31,
537+
};
538+
if (r.isOne()) out[offset] = 1;
527539
}

0 commit comments

Comments
 (0)