Skip to content

Commit 4324b40

Browse files
wrapper: Add modular arithmetic functions to ArbitraryPrecisionInteger (#284)
1 parent f98e46e commit 4324b40

File tree

2 files changed

+231
-0
lines changed

2 files changed

+231
-0
lines changed

Sources/CryptoBoringWrapper/Util/ArbitraryPrecisionInteger.swift

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,119 @@ extension ArbitraryPrecisionInteger: Numeric {
386386
}
387387
}
388388

389+
// MARK: - Modular arithmetic
390+
391+
extension ArbitraryPrecisionInteger {
392+
@usableFromInline
393+
package func modulo(_ mod: ArbitraryPrecisionInteger, nonNegative: Bool = false) throws -> ArbitraryPrecisionInteger {
394+
var result = ArbitraryPrecisionInteger()
395+
396+
let rc = result.withUnsafeMutableBignumPointer { resultPtr in
397+
self.withUnsafeBignumPointer { selfPtr in
398+
mod.withUnsafeBignumPointer { modPtr in
399+
ArbitraryPrecisionInteger.withUnsafeBN_CTX { bnCtx in
400+
if nonNegative {
401+
CCryptoBoringSSL_BN_nnmod(resultPtr, selfPtr, modPtr, bnCtx)
402+
} else {
403+
CCryptoBoringSSLShims_BN_mod(resultPtr, selfPtr, modPtr, bnCtx)
404+
}
405+
}
406+
}
407+
}
408+
}
409+
guard rc == 1 else { throw CryptoBoringWrapperError.internalBoringSSLError() }
410+
411+
return result
412+
}
413+
414+
@usableFromInline
415+
package func inverse(modulo mod: ArbitraryPrecisionInteger) throws -> ArbitraryPrecisionInteger {
416+
var result = ArbitraryPrecisionInteger()
417+
418+
let rc = result.withUnsafeMutableBignumPointer { resultPtr in
419+
self.withUnsafeBignumPointer { selfPtr in
420+
mod.withUnsafeBignumPointer { modPtr in
421+
ArbitraryPrecisionInteger.withUnsafeBN_CTX { bnCtx in
422+
CCryptoBoringSSL_BN_mod_inverse(resultPtr, selfPtr, modPtr, bnCtx)
423+
}
424+
}
425+
}
426+
}
427+
guard rc != nil else { throw CryptoBoringWrapperError.internalBoringSSLError() }
428+
429+
return result
430+
}
431+
432+
433+
@usableFromInline
434+
package static func inverse(lhs: ArbitraryPrecisionInteger, modulo mod: ArbitraryPrecisionInteger) throws -> ArbitraryPrecisionInteger {
435+
try ArbitraryPrecisionInteger(lhs).inverse(modulo: mod)
436+
}
437+
438+
@usableFromInline
439+
package func add(_ rhs: ArbitraryPrecisionInteger, modulo modulus: ArbitraryPrecisionInteger? = nil) throws -> ArbitraryPrecisionInteger {
440+
guard let modulus else { return self + rhs }
441+
var result = ArbitraryPrecisionInteger()
442+
443+
let rc = result.withUnsafeMutableBignumPointer { resultPtr in
444+
self.withUnsafeBignumPointer { selfPtr in
445+
rhs.withUnsafeBignumPointer { rhsPtr in
446+
modulus.withUnsafeBignumPointer { modulusPtr in
447+
ArbitraryPrecisionInteger.withUnsafeBN_CTX { bnCtx in
448+
return CCryptoBoringSSL_BN_mod_add(resultPtr, selfPtr, rhsPtr, modulusPtr, bnCtx)
449+
}
450+
}
451+
}
452+
}
453+
}
454+
guard rc == 1 else { throw CryptoBoringWrapperError.internalBoringSSLError() }
455+
456+
return result
457+
}
458+
459+
@usableFromInline
460+
package func sub(_ rhs: ArbitraryPrecisionInteger, modulo modulus: ArbitraryPrecisionInteger? = nil) throws -> ArbitraryPrecisionInteger {
461+
guard let modulus else { return self - rhs }
462+
var result = ArbitraryPrecisionInteger()
463+
464+
let rc = result.withUnsafeMutableBignumPointer { resultPtr in
465+
self.withUnsafeBignumPointer { selfPtr in
466+
rhs.withUnsafeBignumPointer { rhsPtr in
467+
modulus.withUnsafeBignumPointer { modulusPtr in
468+
ArbitraryPrecisionInteger.withUnsafeBN_CTX { bnCtx in
469+
CCryptoBoringSSL_BN_mod_sub(resultPtr, selfPtr, rhsPtr, modulusPtr, bnCtx)
470+
}
471+
}
472+
}
473+
}
474+
}
475+
guard rc == 1 else { throw CryptoBoringWrapperError.internalBoringSSLError() }
476+
477+
return result
478+
}
479+
480+
@usableFromInline
481+
package func mul(_ rhs: ArbitraryPrecisionInteger, modulo modulus: ArbitraryPrecisionInteger? = nil) throws -> ArbitraryPrecisionInteger {
482+
guard let modulus else { return self * rhs }
483+
var result = ArbitraryPrecisionInteger()
484+
485+
let rc = result.withUnsafeMutableBignumPointer { resultPtr in
486+
self.withUnsafeBignumPointer { selfPtr in
487+
rhs.withUnsafeBignumPointer { rhsPtr in
488+
modulus.withUnsafeBignumPointer { modulusPtr in
489+
ArbitraryPrecisionInteger.withUnsafeBN_CTX { bnCtx in
490+
return CCryptoBoringSSL_BN_mod_mul(resultPtr, selfPtr, rhsPtr, modulusPtr, bnCtx)
491+
}
492+
}
493+
}
494+
}
495+
}
496+
guard rc == 1 else { throw CryptoBoringWrapperError.internalBoringSSLError() }
497+
498+
return result
499+
}
500+
}
501+
389502
// MARK: - SignedNumeric
390503

391504
extension ArbitraryPrecisionInteger: SignedNumeric {

Tests/CryptoBoringWrapperTests/ArbitraryPrecisionIntegerTests.swift

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,4 +165,122 @@ final class ArbitraryPrecisionIntegerTests: XCTestCase {
165165
XCTAssertEqual(try ArbitraryPrecisionInteger(bytes: bytes), integer)
166166
}
167167
}
168+
169+
func testMoudlo() throws {
170+
typealias I = ArbitraryPrecisionInteger
171+
typealias Vector = (input: I, mod: I, expectedResult: (standard: I, nonNegative: I))
172+
for vector: Vector in [
173+
(input: 0, mod: 2, expectedResult: (standard: 0, nonNegative: 0)),
174+
(input: 1, mod: 2, expectedResult: (standard: 1, nonNegative: 1)),
175+
(input: 2, mod: 2, expectedResult: (standard: 0, nonNegative: 0)),
176+
(input: 3, mod: 2, expectedResult: (standard: 1, nonNegative: 1)),
177+
(input: 4, mod: 2, expectedResult: (standard: 0, nonNegative: 0)),
178+
(input: 5, mod: 2, expectedResult: (standard: 1, nonNegative: 1)),
179+
(input: 7, mod: 5, expectedResult: (standard: 2, nonNegative: 2)),
180+
(input: 7, mod: -5, expectedResult: (standard: 2, nonNegative: 2)),
181+
(input: -7, mod: 5, expectedResult: (standard: -2, nonNegative: 3)),
182+
(input: -7, mod: -5, expectedResult: (standard: -2, nonNegative: 3)),
183+
] {
184+
XCTAssertEqual(
185+
try vector.input.modulo(vector.mod, nonNegative: false),
186+
vector.expectedResult.standard,
187+
"\(vector.input) (mod \(vector.mod))"
188+
)
189+
XCTAssertEqual(
190+
try vector.input.modulo(vector.mod, nonNegative: true),
191+
vector.expectedResult.nonNegative,
192+
"\(vector.input) (nnmod \(vector.mod))"
193+
)
194+
}
195+
}
196+
197+
func testModularInverse() throws {
198+
typealias I = ArbitraryPrecisionInteger
199+
enum O { case ok(I), throwsError }
200+
typealias Vector = (a: I, mod: I, expectedResult: O)
201+
for vector: Vector in [
202+
(a: 3, mod: 7, expectedResult: .ok(5)),
203+
(a: 10, mod: 17, expectedResult: .ok(12)),
204+
(a: 7, mod: 26, expectedResult: .ok(15)),
205+
(a: 7, mod: 7, expectedResult: .throwsError),
206+
] {
207+
switch vector.expectedResult {
208+
case .ok(let expectedValue):
209+
XCTAssertEqual(try vector.a.inverse(modulo: vector.mod), expectedValue, "inverse(\(vector.a), modulo: \(vector.mod))")
210+
case .throwsError:
211+
XCTAssertThrowsError(try vector.a.inverse(modulo: vector.mod), "inverse(\(vector.a), modulo: \(vector.mod)")
212+
}
213+
}
214+
}
215+
216+
func testModularAddition() throws {
217+
typealias I = ArbitraryPrecisionInteger
218+
enum O { case ok(I), throwsError }
219+
typealias Vector = (a: I, b: I, mod: I, expectedResult: O)
220+
for vector: Vector in [
221+
(a: 0, b: 0, mod: 0, expectedResult: .throwsError),
222+
(a: 0, b: 0, mod: 2, expectedResult: .ok(0)),
223+
(a: 1, b: 0, mod: 2, expectedResult: .ok(1)),
224+
(a: 0, b: 1, mod: 2, expectedResult: .ok(1)),
225+
(a: 1, b: 1, mod: 2, expectedResult: .ok(0)),
226+
(a: 4, b: 3, mod: 5, expectedResult: .ok(2)),
227+
(a: 4, b: 3, mod: -5, expectedResult: .ok(2)),
228+
(a: -4, b: -3, mod: 5, expectedResult: .ok(3)),
229+
] {
230+
switch vector.expectedResult {
231+
case .ok(let expectedValue):
232+
XCTAssertEqual(try vector.a.add(vector.b, modulo: vector.mod), expectedValue, "\(vector.a) + \(vector.b) (mod \(vector.mod))")
233+
case .throwsError:
234+
XCTAssertThrowsError(try vector.a.add(vector.b, modulo: vector.mod), "\(vector.a) + \(vector.b) (mod \(vector.mod))")
235+
}
236+
}
237+
}
238+
239+
func testModularSubtraction() throws {
240+
typealias I = ArbitraryPrecisionInteger
241+
enum O { case ok(I), throwsError }
242+
typealias Vector = (a: I, b: I, mod: I, expectedResult: O)
243+
for vector: Vector in [
244+
(a: 0, b: 0, mod: 0, expectedResult: .throwsError),
245+
(a: 0, b: 0, mod: 2, expectedResult: .ok(0)),
246+
(a: 1, b: 0, mod: 2, expectedResult: .ok(1)),
247+
(a: 0, b: 1, mod: 2, expectedResult: .ok(1)),
248+
(a: 1, b: 1, mod: 2, expectedResult: .ok(0)),
249+
(a: 4, b: 3, mod: 5, expectedResult: .ok(1)),
250+
(a: 3, b: 4, mod: 5, expectedResult: .ok(4)),
251+
(a: 3, b: 4, mod: -5, expectedResult: .ok(4)),
252+
(a: -3, b: 4, mod: 5, expectedResult: .ok(3)),
253+
(a: 3, b: -4, mod: 5, expectedResult: .ok(2)),
254+
] {
255+
switch vector.expectedResult {
256+
case .ok(let expectedValue):
257+
XCTAssertEqual(try vector.a.sub(vector.b, modulo: vector.mod), expectedValue, "\(vector.a) - \(vector.b) (mod \(vector.mod))")
258+
case .throwsError:
259+
XCTAssertThrowsError(try vector.a.sub(vector.b, modulo: vector.mod), "\(vector.a) - \(vector.b) (mod \(vector.mod))")
260+
}
261+
}
262+
}
263+
264+
func testModularMultiplication() throws {
265+
typealias I = ArbitraryPrecisionInteger
266+
enum O { case ok(I), throwsError }
267+
typealias Vector = (a: I, b: I, mod: I, expectedResult: O)
268+
for vector: Vector in [
269+
(a: 0, b: 0, mod: 0, expectedResult: .throwsError),
270+
(a: 0, b: 0, mod: 2, expectedResult: .ok(0)),
271+
(a: 1, b: 0, mod: 2, expectedResult: .ok(0)),
272+
(a: 0, b: 1, mod: 2, expectedResult: .ok(0)),
273+
(a: 1, b: 1, mod: 2, expectedResult: .ok(1)),
274+
(a: 4, b: 3, mod: 5, expectedResult: .ok(2)),
275+
(a: 4, b: 3, mod: -5, expectedResult: .ok(2)),
276+
(a: -4, b: -3, mod: 5, expectedResult: .ok(2)),
277+
] {
278+
switch vector.expectedResult {
279+
case .ok(let expectedValue):
280+
XCTAssertEqual(try vector.a.mul(vector.b, modulo: vector.mod), expectedValue, "\(vector.a) × \(vector.b) (mod \(vector.mod))")
281+
case .throwsError:
282+
XCTAssertThrowsError(try vector.a.mul(vector.b, modulo: vector.mod), "\(vector.a) × \(vector.b) (mod \(vector.mod))")
283+
}
284+
}
285+
}
168286
}

0 commit comments

Comments
 (0)