Skip to content

Commit 255c778

Browse files
committed
refactor: Speedup prod
The prod function now first reduces its argument to a 1d vector, then uses binary splitting without allocating any new matrices or arrays. Also implements the previously unimplemented second "dimension" argument to prod (albeit not as efficiently). As these changes added new dependencies to `prod`, some other minor refactors were necessary to avoid circular dependencies and/or keep unwanted dependencies from the number bundle. Also fixes a couple more instances of doc example tests.
1 parent 6ba5987 commit 255c778

File tree

16 files changed

+230
-55
lines changed

16 files changed

+230
-55
lines changed

src/factoriesNumber.js

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ export const createSubtractScalar = /* #__PURE__ */ createNumberFactory('subtrac
106106
export const createCbrt = /* #__PURE__ */ createNumberFactory('cbrt', cbrtNumber)
107107
export { createCeilNumber as createCeil } from './function/arithmetic/ceil.js'
108108
export const createCube = /* #__PURE__ */ createNumberFactory('cube', cubeNumber)
109+
export { createDotMultiplyNumber } from './function/arithmetic/dotMultiply.js'
109110
export const createExp = /* #__PURE__ */ createNumberFactory('exp', expNumber)
110111
export const createExpm1 = /* #__PURE__ */ createNumberFactory('expm1', expm1Number)
111112
export { createFixNumber as createFix } from './function/arithmetic/fix.js'
@@ -116,7 +117,7 @@ export const createLog10 = /* #__PURE__ */ createNumberFactory('log10', log10Num
116117
export const createLog2 = /* #__PURE__ */ createNumberFactory('log2', log2Number)
117118
export const createMod = /* #__PURE__ */ createNumberFactory('mod', modNumber)
118119
export const createMultiplyScalar = /* #__PURE__ */ createNumberFactory('multiplyScalar', multiplyNumber)
119-
export const createMultiply = /* #__PURE__ */ createNumberFactory('multiply', multiplyNumber)
120+
export { createMultiplyNumber } from './function/arithmetic/multiply.js'
120121
export const createNthRoot = /* #__PURE__ */
121122
createNumberOptionalSecondArgFactory('nthRoot', nthRootNumber)
122123
export const createSign = /* #__PURE__ */ createNumberFactory('sign', signNumber)
@@ -224,6 +225,8 @@ export { createSize } from './function/matrix/size.js'
224225
export const createIndex = /* #__PURE__ */ factory('index', [], () => noIndex)
225226
export const createMatrix = /* #__PURE__ */ factory('matrix', [], () => noMatrix) // FIXME: needed now because subset transform needs it. Remove the need for it in subset
226227
export const createSubset = /* #__PURE__ */ factory('subset', [], () => noSubset)
228+
export { createSqueeze } from './function/matrix/squeeze.js'
229+
227230
// TODO: provide number+array implementations for map, filter, forEach, zeros, ...?
228231
// TODO: create range implementation for range?
229232
export { createPartitionSelect } from './function/matrix/partitionSelect.js'

src/function/arithmetic/dotMultiply.js

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import { factory } from '../../utils/factory.js'
2+
import { isArray } from '../../utils/is.js'
3+
import { deepMultiply } from '../../plain/number/arithmetic.js'
24
import { createMatAlgo02xDS0 } from '../../type/matrix/utils/matAlgo02xDS0.js'
35
import { createMatAlgo09xS0Sf } from '../../type/matrix/utils/matAlgo09xS0Sf.js'
46
import { createMatAlgo11xS0s } from '../../type/matrix/utils/matAlgo11xS0s.js'
@@ -51,3 +53,30 @@ export const createDotMultiply = /* #__PURE__ */ factory(name, dependencies, ({
5153
Ss: matAlgo11xS0s
5254
}))
5355
})
56+
57+
export const createDotMultiplyNumber = /* #__PURE__ */ factory(
58+
name, ['typed'], ({ typed }) => {
59+
return typed(name, {
60+
'number, number': (m, n) => m * n,
61+
'number, Array': deepMultiply,
62+
'Array, number': (A, n) => deepMultiply(n, A),
63+
'Array, Array': _dotMult
64+
})
65+
}
66+
)
67+
68+
/* Multiply corresponding entries of A and B */
69+
function _dotMult (A, B) {
70+
if (A.length !== B.length) {
71+
throw new Error('Cannot dot-multiply arrays of differing length.')
72+
}
73+
return A.map((a, ix) => {
74+
const b = B[ix]
75+
if (isArray(a)) {
76+
if (isArray(b)) return _dotMult(a, b)
77+
} else {
78+
if (!isArray(b)) return a * b
79+
}
80+
throw new Error('Cannot dot-multiply arrays of different shape.')
81+
})
82+
}

src/function/arithmetic/multiply.js

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { factory } from '../../utils/factory.js'
22
import { isMatrix } from '../../utils/is.js'
33
import { arraySize } from '../../utils/array.js'
4+
import { deepMultiply } from '../../plain/number/arithmetic.js'
45
import { createMatAlgo11xS0s } from '../../type/matrix/utils/matAlgo11xS0s.js'
56
import { createMatAlgo14xDs } from '../../type/matrix/utils/matAlgo14xDs.js'
67

@@ -880,3 +881,14 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
880881
})
881882
})
882883
})
884+
885+
export const createMultiplyNumber = /* #__PURE__ */ factory(
886+
name, ['typed'], ({ typed }) => {
887+
return typed(name, {
888+
'number, number': (m, n) => m * n,
889+
'bigint, bigint': (m, n) => m * n,
890+
'number | bigint, Array': deepMultiply,
891+
'Array, number | bigint': (A, n) => deepMultiply(n, A)
892+
})
893+
}
894+
)

src/function/arithmetic/scalarDivide.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ const dependencies = [
99
]
1010

1111
export const createScalarDivide = /* #__PURE__ */ factory(name, dependencies, ({
12-
typed, Unit, map, multiply, equal, deepEqual,
12+
typed, map, multiply, equal, deepEqual,
1313
isInteger, isNumeric, isZero,
1414
abs, add, divide, fraction
1515
}) => {
@@ -102,7 +102,7 @@ export const createScalarDivide = /* #__PURE__ */ factory(name, dependencies, ({
102102
return quotient
103103
}
104104
if (isUnit(quotient)) {
105-
if (quotient.equalBase(Unit.BASE_UNITS.NONE)) return quotient.value
105+
if (quotient.unitless()) return quotient.value
106106
return quotient
107107
}
108108
return undefined

src/function/matrix/squeeze.js

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,20 @@ export const createSqueeze = /* #__PURE__ */ factory(name, dependencies, ({ type
1818
* math.squeeze([3]) // returns 3
1919
* math.squeeze([[3]]) // returns 3
2020
*
21-
* const A = math.zeros(3, 1) // returns [[0], [0], [0]] (size 3x1)
22-
* math.squeeze(A) // returns [0, 0, 0] (size 3)
23-
*
24-
* const B = math.zeros(1, 3) // returns [[0, 0, 0]] (size 1x3)
25-
* math.squeeze(B) // returns [0, 0, 0] (size 3)
26-
*
27-
* // only inner and outer dimensions are removed
28-
* const C = math.zeros(2, 1, 3) // returns [[[0, 0, 0]], [[0, 0, 0]]] (size 2x1x3)
29-
* math.squeeze(C) // returns [[[0, 0, 0]], [[0, 0, 0]]] (size 2x1x3)
21+
* // Squeezes size 3x1 to size 3:
22+
* const A = math.zeros(3, 1)
23+
* A // Matrix [[0], [0], [0]] ...
24+
* math.squeeze(A) // Matrix [0, 0, 0]
25+
*
26+
* // Also squeezes size 1x3 to 3:
27+
* const B = math.zeros(1, 3)
28+
* B // Matrix [[0, 0, 0]] ...
29+
* math.squeeze(B) // Matrix [0, 0, 0]
30+
*
31+
* // only inner and outer dimensions are removed:
32+
* const C = math.zeros(2, 1, 3)
33+
* C // Matrix [[[0, 0, 0]], [[0, 0, 0]]] ...
34+
* math.squeeze(C) // Matrix [[[0, 0, 0]], [[0, 0, 0]]]
3035
*
3136
* See also:
3237
*
Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,26 @@
11
import { factory } from '../../utils/factory.js'
22

33
export const createCompareUnits = /* #__PURE__ */ factory(
4-
'compareUnits', ['typed'], ({ typed }) => ({
4+
'compareUnits', ['typed'], ({ typed, Unit }) => ({
55
'Unit, Unit': typed.referToSelf(self => (x, y) => {
66
if (!x.equalBase(y)) {
77
throw new Error('Cannot compare units with different base')
88
}
99
return typed.find(self, [x.valueType(), y.valueType()])(x.value, y.value)
10-
})
10+
}),
11+
'Unit, number | bigint | BigNumber | Fraction | Complex': typed.referToSelf(
12+
self => (x, y) => {
13+
if (!x.unitless()) {
14+
throw new Error('To compare Unit with pure numeric, must be unitless')
15+
}
16+
return self(x.value, y)
17+
}),
18+
'number | bigint | BigNumber | Fraction | Complex, Unit': typed.referToSelf(
19+
self => (x, y) => {
20+
if (!y.unitless()) {
21+
throw new Error('To compare Unit with pure numeric, must be unitless')
22+
}
23+
return self(x, y.value)
24+
})
1125
})
1226
)

src/function/relational/equalScalar.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ import { createCompareUnits } from './compareUnits.js'
77
const name = 'equalScalar'
88
const dependencies = ['typed', 'config']
99

10-
export const createEqualScalar = /* #__PURE__ */ factory(name, dependencies, ({ typed, config }) => {
10+
export const createEqualScalar = /* #__PURE__ */ factory(name, dependencies, ({ typed, config, Unit }) => {
1111
const compareUnits = createCompareUnits({ typed })
1212

1313
/**
1414
* Test whether two scalar values are nearly equal.
1515
*
1616
* @param {number | BigNumber | bigint | Fraction | boolean | Complex | Unit} x First value to compare
17-
* @param {number | BigNumber | bigint | Fraction | boolean | Complex} y Second value to compare
17+
* @param {number | BigNumber | bigint | Fraction | boolean | Complex | Unit} y Second value to compare
1818
* @return {boolean} Returns true when the compared values are equal, else returns false
1919
* @private
2020
*/

src/function/set/setUnion.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ export const createSetUnion = /* #__PURE__ */ factory(name, dependencies, ({ typ
1515
*
1616
* Examples:
1717
*
18-
* math.setUnion([1, 2, 3, 4], [3, 4, 5, 6]) // returns [1, 2, 3, 4, 5, 6]
19-
* math.setUnion([[1, 2], [3, 4]], [[3, 4], [5, 6]]) // returns [1, 2, 3, 4, 5, 6]
18+
* math.sort(math.setUnion([1, 2, 3, 4], [3, 4, 5, 6])) // returns [1, 2, 3, 4, 5, 6]
19+
* math.sort(math.setUnion([[1, 2], [3, 4]], [[3, 4], [5, 6]])) // returns [1, 2, 3, 4, 5, 6]
2020
*
2121
* See also:
2222
*

src/function/statistics/prod.js

Lines changed: 89 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
1-
import { deepForEach } from '../../utils/collection.js'
21
import { factory } from '../../utils/factory.js'
2+
import { isArray, isNumber } from '../../utils/is.js'
33
import { safeNumberType } from '../../utils/number.js'
44
import { improveErrorMessage } from './utils/improveErrorMessage.js'
55

66
const name = 'prod'
77
const dependencies = [
8-
'typed', 'config', 'multiplyScalar', 'numeric'
8+
'typed', 'config', 'multiplyScalar', 'number', 'numeric', '?Index', 'Range',
9+
'squeeze', 'size', 'subset', 'dotMultiply'
910
]
1011

12+
const THRESHOLD = 16 // where to stop splitting and switch to direct multiply
13+
1114
export const createProd = /* #__PURE__ */ factory(name, dependencies, ({
12-
typed, config, multiplyScalar, numeric
15+
typed, config, multiplyScalar, number, numeric, Index, Range,
16+
squeeze, size, subset, dotMultiply
1317
}) => {
1418
/**
1519
* Compute the product of a matrix or a list with values.
@@ -42,11 +46,7 @@ export const createProd = /* #__PURE__ */ factory(name, dependencies, ({
4246
'Array | Matrix': _prod,
4347

4448
// prod([a, b, c, d, ...], dim)
45-
'Array | Matrix, number | BigNumber': function (array, dim) {
46-
// TODO: implement prod(A, dim)
47-
throw new Error('prod(A, dim) is not yet supported')
48-
// return reduce(arguments[0], arguments[1], math.prod)
49-
},
49+
'Array | Matrix, number | BigNumber': _prodAlongDim,
5050

5151
// prod(a, b, c, d, ...)
5252
'...': function (args) {
@@ -56,27 +56,94 @@ export const createProd = /* #__PURE__ */ factory(name, dependencies, ({
5656

5757
/**
5858
* Recursively calculate the product of an n-dimensional array
59-
* @param {Array} array
60-
* @return {number} prod
59+
* @param {Array | Matrix} collection
60+
* @return {scalar} prod
6161
* @private
6262
*/
63-
function _prod (array) {
63+
function _prod (collection) {
64+
let sz = size(collection)
65+
if (sz.length === 0 || sz.some(dim => dim === 0)) return 1
6466
let prod
67+
try {
68+
if (sz.every(dim => dim === 1)) prod = squeeze(collection)
69+
else {
70+
if (sz.length > 1) { // reduce to 1d
71+
const newColl = []
72+
for (let pos = 0; pos < sz[0]; ++pos) {
73+
newColl.push(_prod(subset(collection, pos)))
74+
}
75+
collection = newColl
76+
sz = [sz[0]]
77+
}
78+
if (!Index) collection = collection.valueOf()
79+
if (Array.isArray(collection)) {
80+
prod = _prodArray(collection, 0, sz[0] - 1)
81+
} else {
82+
let op = multiplyScalar
83+
const dt = collection.datatype()
84+
if (dt) op = typed.find(op, [dt, dt])
85+
prod = _prodVector(collection, 0, sz[0] - 1, op)
86+
}
87+
}
6588

66-
deepForEach(array, function (value) {
67-
try {
68-
prod = (prod === undefined) ? value : multiplyScalar(prod, value)
69-
} catch (err) {
70-
throw improveErrorMessage(err, 'prod', value)
89+
if (typeof prod === 'string') {
90+
prod = numeric(prod, safeNumberType(prod, config))
7191
}
72-
})
92+
} catch (err) {
93+
throw improveErrorMessage(err, 'prod', collection)
94+
}
95+
return prod
96+
}
7397

74-
// make sure returning numeric value: parse a string into a numeric value
75-
if (typeof prod === 'string') {
76-
prod = numeric(prod, safeNumberType(prod, config))
98+
/* Product of a 1d array arr from index first to index last, inclusive. */
99+
function _prodArray (arr, first, last) {
100+
if (last - first < THRESHOLD) {
101+
let prod = arr[first]
102+
for (let idx = first + 1; idx <= last; ++idx) {
103+
prod = multiplyScalar(prod, arr[idx])
104+
}
105+
return prod
77106
}
107+
const cutoff = Math.floor((first + last) / 2)
108+
return multiplyScalar(
109+
_prodArray(arr, first, cutoff),
110+
_prodArray(arr, cutoff + 1, last))
111+
}
78112

79-
if (prod === undefined) return 1
80-
return prod
113+
/* Product of a 1d vector v from position first to last, using op */
114+
function _prodVector (v, first, last, op) {
115+
if (last - first < THRESHOLD) {
116+
let prod = v.layer(first)
117+
for (let idx = first + 1; idx <= last; ++idx) {
118+
prod = op(prod, v.layer(idx))
119+
}
120+
return prod
121+
}
122+
const cutoff = Math.floor((first + last) / 2)
123+
return op(
124+
_prodVector(v, first, cutoff, op), _prodVector(v, cutoff + 1, last, op))
125+
}
126+
127+
function _prodAlongDim (collection, dim) {
128+
if (!isNumber(dim)) dim = number(dim)
129+
const sz = size(collection)
130+
if (dim >= sz.length) {
131+
throw new Error(
132+
`There is no dimension ${dim} in collection of size ${sz}.`)
133+
}
134+
if (sz.length === 1) return _prod(collection)
135+
if (dim === 0) {
136+
let result = subset(collection, 0)
137+
for (let i = 1; i < sz[0]; ++i) {
138+
result = dotMultiply(result, subset(collection, i))
139+
}
140+
return result
141+
}
142+
const data = []
143+
for (let i = 0; i < sz[0]; ++i) {
144+
data.push(_prodAlongDim(subset(collection, i), dim - 1).valueOf())
145+
}
146+
if (isArray(collection)) return data
147+
return collection.create(data)
81148
}
82149
})

src/plain/number/arithmetic.js

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ export function multiplyNumber (a, b) {
2323
}
2424
multiplyNumber.signature = n2
2525

26+
/* Multiply every entry of A by the number n */
27+
export function deepMultiply (n, A) {
28+
return A.map(item => {
29+
if (Array.isArray(item)) return deepMultiply(n, item)
30+
return n * item
31+
})
32+
}
33+
2634
export function divideNumber (a, b) {
2735
return a / b
2836
}

0 commit comments

Comments
 (0)