Skip to content

Commit bb3c91c

Browse files
committed
Move multiplyAndDivide to common math helpers
- Introduce a specialized pair of Long class DivRemResult - Avoid exception on long multiplication path - Add more tests for multiplyAndDivide
1 parent 2a326ef commit bb3c91c

File tree

4 files changed

+227
-157
lines changed

4 files changed

+227
-157
lines changed

core/commonMain/src/math.kt

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,152 @@ internal fun Long.clampToInt(): Int =
1212
else -> toInt()
1313
}
1414

15+
internal const val NANOS_PER_MILLI = 1_000_000
16+
internal const val MILLIS_PER_ONE = 1_000
17+
internal const val NANOS_PER_ONE = 1_000_000_000
1518

1619
internal expect fun safeMultiply(a: Long, b: Long): Long
1720
internal expect fun safeMultiply(a: Int, b: Int): Int
1821
internal expect fun safeAdd(a: Long, b: Long): Long
1922
internal expect fun safeAdd(a: Int, b: Int): Int
23+
24+
/** Multiplies two non-zero long values. */
25+
internal fun safeMultiplyOrZero(a: Long, b: Long): Long {
26+
when (b) {
27+
-1L -> {
28+
if (a == Long.MIN_VALUE) {
29+
return 0L
30+
}
31+
return -a
32+
}
33+
1L -> return a
34+
}
35+
val total = a * b
36+
if (total / b != a) {
37+
return 0L
38+
}
39+
return total
40+
}
41+
42+
/**
43+
* Calculates [a] * [b] / [c]. Returns a pair of the quotient and the remainder.
44+
* [c] must be greater than zero.
45+
*
46+
* @throws ArithmeticException if the result overflows a long
47+
*/
48+
internal fun multiplyAndDivide(a: Long, b: Long, c: Long): DivRemResult {
49+
if (a == 0L || b == 0L) return DivRemResult(0, 0)
50+
val ab = safeMultiplyOrZero(a, b)
51+
if (ab != 0L) return DivRemResult(ab / c, ab % c)
52+
53+
/* Not just optimizations: this is needed for multiplyAndDivide(Long.MIN_VALUE, x, x) to work. */
54+
if (b == c) return DivRemResult(a, 0)
55+
if (a == c) return DivRemResult(b, 0)
56+
57+
58+
/* a * b = (ae * 2^64 + ah * 2^32 + al) * (be * 2^64 + bh * 2^32 + bl)
59+
= ae * be * 2^128 + (ae * bh + ah * be) * 2^96 + (ae * bl + ah * bh + al * be) * 2^64
60+
+ (ah * bl + al * bh) * 2^32 + al * bl
61+
= 0 + w * 2^96 + x * 2^64 + y * 2^32 + z = xh * 2^96 + (xl + yh) * 2^64 + (yl + zh) * 2^32 + zl
62+
= r1 * 2^96 | r2 * 2^64 | r3 * 2^32 | r4
63+
= abh * 2^64 | abl */
64+
// a, b in [0; 2^64)
65+
66+
// sign extensions to 128 bits:
67+
val ae = if (a >= 0) 0L else -1L // all ones or all zeros
68+
val be = if (b >= 0) 0L else -1L // all ones or all zeros
69+
70+
val al = low(a) // [0; 2^32)
71+
val ah = high(a) // [0; 2^32)
72+
val bl = low(b) // [0; 2^32)
73+
val bh = high(b) // [0; 2^32)
74+
75+
/* even though the language operates on signed Long values, we can add and multiply them as if they were unsigned
76+
due to the fact that they are encoded as 2's complement (hence the need to use sign extensions). The only operation
77+
here where sign matters is division. */
78+
val w = ae * bh + ah * be // we will only use the lower 32 bits of this value, so overflow is fine
79+
val x = ae * bl + ah * bh + al * be // may overflow, but overflow here goes beyond 128 bit
80+
val y1 = ah * bl
81+
val y2 = al * bh // y is split into y1 and y2 because y1 + y2 may overflow 2^64, which loses information
82+
val z = al * bl
83+
84+
val r4 = low(z)
85+
val r3c = low(y1) + low(y2) + high(z)
86+
val r3 = low(r3c)
87+
val r2c = high(r3c) + low(x) + high(y1) + high(y2)
88+
val r2 = low(r2c)
89+
/* If r1 overflows 2^32 - 1, it's because of sign extension: we don't lose any significant bits because multiplying
90+
[0; 2^64) by [0; 2^64) may never exceed 2^128 - 1. */
91+
val r1 = high(r2c) + high(x) + low(w)
92+
93+
var abl = (r3 shl 32) or r4 // low 64 bits of a * b
94+
var abh = (r1 shl 32) or r2 // high 64 bits of a * b
95+
96+
97+
val sign = if (indexBit(abh, 63) == 1L) -1 else 1
98+
99+
if (sign == -1) {
100+
// negate, so that we operate on a positive number
101+
abl = abl.inv() + 1
102+
abh = abh.inv()
103+
if (abl == 0L) // abl overflowed
104+
abh += 1
105+
}
106+
107+
/* The resulting quotient. This division is unsigned, so if the result doesn't fit in 63 bits, it means that
108+
overflow occurred. */
109+
var q = 0L
110+
// The remainder, always less than c and so fits in a Long.
111+
var r = 0L
112+
// Simple long division algorithm
113+
for (bitNo in 127 downTo 0) {
114+
// bit #bitNo of the numerator
115+
val nextBit = if (bitNo < 64) indexBit(abl, bitNo) else indexBit(abh, bitNo - 64)
116+
// left-shift R by one bit, setting the least significant bit to nextBit
117+
r = (r shl 1) or nextBit
118+
// if (R >= c). If R < 0, R >= 2^63 > Long.MAX_VALUE >= c
119+
if (r >= c || r < 0) {
120+
r -= c
121+
// set bit #bitNo of Q to 1
122+
if (bitNo < 63)
123+
q = q or (1L shl bitNo)
124+
else
125+
throw ArithmeticException("The result of a multiplication followed by division overflows a long")
126+
}
127+
}
128+
return DivRemResult(sign * q, sign * r)
129+
}
130+
131+
internal class DivRemResult(val q: Long, val r: Long) {
132+
operator fun component1(): Long = q
133+
operator fun component2(): Long = r
134+
}
135+
136+
private inline fun low(x: Long) = x and 0xffffffff
137+
private inline fun high(x: Long) = (x shr 32) and 0xffffffff
138+
/** For [bit] in [0; 63], return bit #[bit] of [value], counting from the least significant bit */
139+
private inline fun indexBit(value: Long, bit: Int): Long = (value shr bit and 1)
140+
141+
142+
/**
143+
* Calculates ([d] * [n] + [r]) / [m], where [n], [m] > 0 and |[r]| <= [n].
144+
*
145+
* @throws ArithmeticException if the result overflows a long
146+
*/
147+
internal fun multiplyAddAndDivide(d: Long, n: Long, r: Long, m: Long): Long {
148+
var md = d
149+
var mr = r
150+
// make sure [md] and [mr] have the same sign
151+
if (d > 0 && r < 0) {
152+
md--
153+
mr += n
154+
} else if (d < 0 && r > 0) {
155+
md++
156+
mr -= n
157+
}
158+
if (md == 0L) {
159+
return mr / m
160+
}
161+
val (rd, rr) = multiplyAndDivide(md, n, m)
162+
return safeAdd(rd, safeAdd(mr / m, safeAdd(mr % m, rr) / m))
163+
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Copyright 2019-2020 JetBrains s.r.o.
3+
* Use of this source code is governed by the Apache 2.0 License that can be found in the LICENSE.txt file.
4+
*/
5+
6+
package kotlinx.datetime.test.math
7+
import kotlin.random.*
8+
import kotlin.test.*
9+
import kotlinx.datetime.*
10+
11+
class MultiplyAndDivideTest {
12+
13+
private fun mulDiv(a: Long, b: Long, m: Long): Pair<Long, Long> =
14+
multiplyAndDivide(a, b, m).run { q to r }
15+
16+
@Test
17+
fun small() {
18+
assertEquals(4L to 3L, mulDiv(5L, 15L, 18L))
19+
assertEquals(15L to 0L, mulDiv(5L, 12L, 4L))
20+
}
21+
22+
@Test
23+
fun smallNegative() {
24+
assertEquals(4L to 3L, mulDiv(-5L, -15L, 18L))
25+
assertEquals(597308323L to 475144067L, mulDiv(-1057588554, -1095571653, 1939808965))
26+
}
27+
28+
@Test
29+
fun large() {
30+
val l = Long.MAX_VALUE
31+
val result = mulDiv(l - 1, l - 2, l)
32+
assertEquals(9223372036854775804L to 2L, result) // https://www.wolframalpha.com/input/?i=floor%28%282%5E63+-+2%29+*+%282%5E63+-+3%29+%2F+%282%5E63+-1%29%29
33+
}
34+
35+
@Test
36+
fun largeNegative() {
37+
val r1 = mulDiv(Long.MIN_VALUE, Long.MAX_VALUE, Long.MAX_VALUE)
38+
val r2 = mulDiv(Long.MAX_VALUE, Long.MIN_VALUE, Long.MAX_VALUE)
39+
assertEquals(Long.MIN_VALUE to 0L, r1)
40+
assertEquals(r1, r2)
41+
42+
val r3 = mulDiv(Long.MIN_VALUE, Long.MAX_VALUE - 1, Long.MAX_VALUE)
43+
val r4 = mulDiv(Long.MAX_VALUE - 1, Long.MIN_VALUE, Long.MAX_VALUE)
44+
45+
assertEquals(-9223372036854775806 to -9223372036854775806, r3)
46+
assertEquals(r3, r4)
47+
}
48+
49+
@Test
50+
fun halfLarge() {
51+
val l = Long.MAX_VALUE / 2 + 1
52+
val result = mulDiv(l - 2, l - 3, l)
53+
assertEquals(4611686018427387899L to 6L, result)
54+
}
55+
56+
@Test
57+
fun randomProductFitsInLong() {
58+
repeat(1000) {
59+
val a = Random.nextInt().toLong()
60+
val b = Random.nextInt().toLong()
61+
val m = Random.nextInt(1, Int.MAX_VALUE).toLong()
62+
// println("$a, $b, $c: ${a * b / c}, ${a * b % c}")
63+
val (q, r) = mulDiv(a, b, m)
64+
// println("$d, $r")
65+
assertEquals(a * b / m, q)
66+
assertEquals(a * b % m, r)
67+
}
68+
}
69+
70+
71+
@Test
72+
fun nearIntBoundary() {
73+
val (d, r) = mulDiv((1L shl 32) + 1693, (1L shl 32) - 3, (1L shl 33) - 1)
74+
assertEquals(2_147_484_493, d)
75+
assertEquals(2_147_479_414, r)
76+
}
77+
78+
@Test
79+
fun largeOverflows() {
80+
assertFailsWith<ArithmeticException> { mulDiv(Long.MIN_VALUE, Long.MIN_VALUE, Long.MAX_VALUE) }
81+
assertFailsWith<ArithmeticException> { mulDiv(Long.MAX_VALUE, 4, 3) }
82+
}
83+
}

core/nativeMain/src/Util.kt renamed to core/nativeMain/src/mathNative.kt

Lines changed: 0 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -9,130 +9,10 @@ package kotlinx.datetime
99

1010
import kotlin.math.abs
1111

12-
/**
13-
* Calculates [a] * [b] / [c]. Returns a pair of the quotient and the remainder.
14-
* [c] must be greater than zero.
15-
*
16-
* @throws ArithmeticException if the result overflows a long
17-
*/
18-
internal fun multiplyAndDivide(a: Long, b: Long, c: Long): Pair<Long, Long> {
19-
try {
20-
return safeMultiply(a, b).let { Pair(it / c, it % c) }
21-
} catch (e: ArithmeticException) {
22-
// this body is intentionally left blank
23-
}
24-
/* Not just optimizations: this is needed for multiplyAndDivide(Long.MIN_VALUE, x, x) to work. */
25-
if (b == c) return Pair(a, 0)
26-
if (a == c) return Pair(b, 0)
27-
28-
inline fun low(x: Long) = x and 0xffffffff
29-
inline fun high(x: Long) = (x shr 32) and 0xffffffff
30-
31-
/* a * b = (ae * 2^64 + ah * 2^32 + al) * (be * 2^64 + bh * 2^32 + bl)
32-
= ae * be * 2^128 + (ae * bh + ah * be) * 2^96 + (ae * bl + ah * bh + al * be) * 2^64
33-
+ (ah * bl + al * bh) * 2^32 + al * bl
34-
= 0 + w * 2^96 + x * 2^64 + y * 2^32 + z = xh * 2^96 + (xl + yh) * 2^64 + (yl + zh) * 2^32 + zl
35-
= r1 * 2^96 | r2 * 2^64 | r3 * 2^32 | r4
36-
= abh * 2^64 | abl */
37-
// a, b in [0; 2^64)
38-
39-
// sign extensions to 128 bits:
40-
val ae = if (a >= 0) 0L else -1L // all ones or all zeros
41-
val be = if (b >= 0) 0L else -1L // all ones or all zeros
42-
43-
val al = low(a) // [0; 2^32)
44-
val ah = high(a) // [0; 2^32)
45-
val bl = low(b) // [0; 2^32)
46-
val bh = high(b) // [0; 2^32)
47-
48-
/* even though the language operates on signed Long values, we can add and multiply them as if they were unsigned
49-
due to the fact that they are encoded as 2's complement (hence the need to use sign extensions). The only operation
50-
here where sign matters is division. */
51-
val w = ae * bh + ah * be // we will only use the lower 32 bits of this value, so overflow is fine
52-
val x = ae * bl + ah * bh + al * be // may overflow, but overflow here goes beyond 128 bit
53-
val y1 = ah * bl
54-
val y2 = al * bh // y is split into y1 and y2 because y1 + y2 may overflow 2^64, which loses information
55-
val z = al * bl
56-
57-
val r4 = low(z)
58-
val r3c = low(y1) + low(y2) + high(z)
59-
val r3 = low(r3c)
60-
val r2c = high(r3c) + low(x) + high(y1) + high(y2)
61-
val r2 = low(r2c)
62-
/* If r1 overflows 2^32 - 1, it's because of sign extension: we don't lose any significant bits because multiplying
63-
[0; 2^64) by [0; 2^64) may never exceed 2^128 - 1. */
64-
val r1 = high(r2c) + high(x) + low(w)
65-
66-
var abl = (r3 shl 32) or r4 // low 64 bits of a * b
67-
var abh = (r1 shl 32) or r2 // high 64 bits of a * b
68-
69-
/** For [bit] in [0; 63], return bit #[bit] of [value], counting from the least significant bit */
70-
inline fun indexBit(value: Long, bit: Int) = (value shr bit) and 1
71-
72-
val sign = if (indexBit(abh, 63) == 1L) -1 else 1
73-
74-
if (sign == -1) {
75-
// negate, so that we operate on a positive number
76-
abl = abl.inv() + 1
77-
abh = abh.inv()
78-
if (abl == 0L) // abl overflowed
79-
abh += 1
80-
}
81-
82-
/* The resulting quotient. This division is unsigned, so if the result doesn't fit in 63 bits, it means that
83-
overflow occurred. */
84-
var q = 0L
85-
// The remainder, always less than c and so fits in a Long.
86-
var r = 0L
87-
// Simple long division algorithm
88-
for (bitNo in 127 downTo 0) {
89-
// bit #bitNo of the numerator
90-
val nextBit = if (bitNo < 64) indexBit(abl, bitNo) else indexBit(abh, bitNo - 64)
91-
// left-shift R by one bit, setting the least significant bit to nextBit
92-
r = (r shl 1) or nextBit
93-
// if (R >= c). If R < 0, R >= 2^63 > Long.MAX_VALUE >= c
94-
if (r >= c || r < 0) {
95-
r -= c
96-
// set bit #bitNo of Q to 1
97-
if (bitNo < 63)
98-
q = q or (1L shl bitNo)
99-
else
100-
throw ArithmeticException("The result of a multiplication followed by division overflows a long")
101-
}
102-
}
103-
return Pair(sign * q, sign * r)
104-
}
105-
106-
/**
107-
* Calculates ([d] * [n] + [r]) / [m], where [n], [m] > 0 and |[r]| <= [n].
108-
*
109-
* @throws ArithmeticException if the result overflows a long
110-
*/
111-
internal fun multiplyAddAndDivide(d: Long, n: Long, r: Long, m: Long): Long {
112-
var md = d
113-
var mr = r
114-
// make sure [md] and [mr] have the same sign
115-
if (d > 0 && r < 0) {
116-
md--
117-
mr += n
118-
} else if (d < 0 && r > 0) {
119-
md++
120-
mr -= n
121-
}
122-
if (md == 0L) {
123-
return mr / m
124-
}
125-
val (rd, rr) = multiplyAndDivide(md, n, m)
126-
return safeAdd(rd, safeAdd(mr / m, safeAdd(mr % m, rr) / m))
127-
}
128-
12912
/**
13013
* All code below was taken from various places of https://github.com/ThreeTen/threetenbp with few changes
13114
*/
13215

133-
internal const val NANOS_PER_MILLI = 1_000_000
134-
internal const val MILLIS_PER_ONE = 1_000
135-
internal const val NANOS_PER_ONE = 1_000_000_000
13616

13717
/**
13818
* The number of seconds per hour.

0 commit comments

Comments
 (0)