Skip to content

Commit 5655d71

Browse files
gwhitneyDelaney
andcommitted
Feat: Enhance Kronecker product to handle arbitrary dimension (#3461)
Previously `math.kron()` always returned a 2D matrix, and could not handle 3D or greater arrays. Now it always returns an array of the max dimension of its arguments. Resolves #1753. --------- Co-authored-by: Delaney Sylvans <delaneysylvans@gmail.com>
1 parent e8caa78 commit 5655d71

File tree

4 files changed

+80
-40
lines changed

4 files changed

+80
-40
lines changed

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,5 +269,6 @@ Don McCurdy <dm@donmccurdy.com>
269269
Jay Chang <96050090+JayChang4w@users.noreply.github.com>
270270
mrft <977655+mrft@users.noreply.github.com>
271271
Kip Robinson <91914404+kiprobinsonknack@users.noreply.github.com>
272+
Delaney Sylvans <delaneysylvans@gmail.com>
272273

273274
# Generated by tools/update-authors.js

HISTORY.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
!!! BE CAREFUL: BREAKING CHANGES !!!
66

7+
- Fix: #1753 Correct dimensionality of Kronecker product on vectors (and
8+
extend to arbitrary dimension) (#3455). Thanks @Delaney.
79
- Feat: #3349 Decouple precedence of unary percentage operator and binary
810
modulus operator (that both use symbol `%`), and raise the former (#3432).
911
Thanks @kiprobinsonknack.

src/function/matrix/kron.js

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ export const createKron = /* #__PURE__ */ factory(name, dependencies, ({ typed,
2222
* // returns [ [ 1, 2, 0, 0 ], [ 3, 4, 0, 0 ], [ 0, 0, 1, 2 ], [ 0, 0, 3, 4 ] ]
2323
*
2424
* math.kron([1,1], [2,3,4])
25-
* // returns [ [ 2, 3, 4, 2, 3, 4 ] ]
25+
* // returns [2, 3, 4, 2, 3, 4]
2626
*
2727
* See also:
2828
*
@@ -49,39 +49,38 @@ export const createKron = /* #__PURE__ */ factory(name, dependencies, ({ typed,
4949
})
5050

5151
/**
52-
* Calculate the Kronecker product of two matrices / vectors
53-
* @param {Array} a First vector
54-
* @param {Array} b Second vector
55-
* @returns {Array} Returns the Kronecker product of x and y
56-
* @private
52+
* Calculate the Kronecker product of two (1-dimensional) vectors,
53+
* with no dimension checking
54+
* @param {Array} a First vector
55+
* @param {Array} b Second vector
56+
* @returns {Array} the 1-dimensional Kronecker product of a and b
57+
* @private
58+
*/
59+
function _kron1d (a, b) {
60+
// TODO in core overhaul: would be faster to see if we can choose a
61+
// particular implementation of multiplyScalar at the beginning,
62+
// rather than re-dispatch for _every_ ordered pair of entries.
63+
return a.flatMap(x => b.map(y => multiplyScalar(x, y)))
64+
}
65+
66+
/**
67+
* Calculate the Kronecker product of two possibly multidimensional arrays
68+
* @param {Array} a First array
69+
* @param {Array} b Second array
70+
* @param {number} [d] common dimension; if missing, compute and match args
71+
* @returns {Array} Returns the Kronecker product of x and y
72+
* @private
5773
*/
58-
function _kron (a, b) {
59-
// Deal with the dimensions of the matricies.
60-
if (size(a).length === 1) {
61-
// Wrap it in a 2D Matrix
62-
a = [a]
63-
}
64-
if (size(b).length === 1) {
65-
// Wrap it in a 2D Matrix
66-
b = [b]
67-
}
68-
if (size(a).length > 2 || size(b).length > 2) {
69-
throw new RangeError('Vectors with dimensions greater then 2 are not supported expected ' +
70-
'(Size x = ' + JSON.stringify(a.length) + ', y = ' + JSON.stringify(b.length) + ')')
74+
function _kron (a, b, d = -1) {
75+
if (d < 0) {
76+
let adim = size(a).length
77+
let bdim = size(b).length
78+
d = Math.max(adim, bdim)
79+
while (adim++ < d) a = [a]
80+
while (bdim++ < d) b = [b]
7181
}
72-
const t = []
73-
let r = []
7482

75-
return a.map(function (a) {
76-
return b.map(function (b) {
77-
r = []
78-
t.push(r)
79-
return a.map(function (y) {
80-
return b.map(function (x) {
81-
return r.push(multiplyScalar(y, x))
82-
})
83-
})
84-
})
85-
}) && t
83+
if (d === 1) return _kron1d(a, b)
84+
return a.flatMap(aSlice => b.map(bSlice => _kron(aSlice, bSlice, d - 1)))
8685
}
8786
})

test/unit-tests/function/matrix/kron.test.js

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import math from '../../../../src/defaultInstance.js'
55

66
describe('kron', function () {
77
it('should calculate the Kronecker product of two arrays', function () {
8+
assert.deepStrictEqual(math.kron([[2]], [[3]]), [[6]])
89
assert.deepStrictEqual(math.kron([
910
[1, -2, 1],
1011
[1, 1, 0]
@@ -28,14 +29,27 @@ describe('kron', function () {
2829
])
2930
})
3031

32+
it('should calculate product for empty 1D Arrays', function () {
33+
assert.deepStrictEqual(math.kron([], []), [])
34+
})
35+
3136
it('should calculate product for empty 2D Arrays', function () {
3237
assert.deepStrictEqual(math.kron([[]], [[]]), [[]])
3338
})
3439

3540
it('should calculate product for 1D Arrays', function () {
41+
assert.deepStrictEqual(math.kron([2], [3]), [6])
42+
assert.deepStrictEqual(math.kron([1, 2], [3, 4]), [3, 4, 6, 8])
43+
assert.deepStrictEqual(math.kron([1, 2, 6, 8], [12, 1, 2, 3]), [12, 1, 2, 3, 24, 2, 4, 6, 72, 6, 12, 18, 96, 8, 16, 24])
44+
})
45+
46+
it('should calculate product for 1D & 2D Arrays', function () {
3647
assert.deepStrictEqual(math.kron([1, 1], [[1, 0], [0, 1]]), [[1, 0, 1, 0], [0, 1, 0, 1]])
3748
assert.deepStrictEqual(math.kron([[1, 0], [0, 1]], [1, 1]), [[1, 1, 0, 0], [0, 0, 1, 1]])
38-
assert.deepStrictEqual(math.kron([1, 2, 6, 8], [12, 1, 2, 3]), [[12, 1, 2, 3, 24, 2, 4, 6, 72, 6, 12, 18, 96, 8, 16, 24]])
49+
assert.deepStrictEqual(math.kron([[1, 2]], [[1, 2, 3]]), [[1, 2, 3, 2, 4, 6]])
50+
assert.deepStrictEqual(math.kron([[1], [2]], [[1], [2], [3]]), [[1], [2], [3], [2], [4], [6]])
51+
assert.deepStrictEqual(math.kron([[1, 2]], [[1], [2], [3]]), [[1, 2], [2, 4], [3, 6]])
52+
assert.deepStrictEqual(math.kron([[1], [2]], [[1, 2, 3]]), [[1, 2, 3], [2, 4, 6]])
3953
})
4054

4155
it('should support complex numbers', function () {
@@ -55,10 +69,32 @@ describe('kron', function () {
5569
])
5670
})
5771

58-
it('should throw an error for greater then 2 dimensions', function () {
59-
assert.throws(function () {
60-
math.kron([[[1, 1], [1, 1]], [[1, 1], [1, 1]]], [[[1, 2, 3], [4, 5, 6]], [[6, 7, 8], [9, 10, 11]]])
61-
})
72+
it('should calculate a 3D Kronecker product', function () {
73+
assert.deepStrictEqual(
74+
math.kron([
75+
[[1, 2], [2, 3]],
76+
[[2, 3], [3, 4]]
77+
], [
78+
[[4, 3], [3, 2]],
79+
[[3, 2], [2, 1]]
80+
]), [
81+
/* eslint-disable no-multi-spaces, array-bracket-spacing */
82+
[[4, 3, 8, 6], [3, 2, 6, 4], [ 8, 6, 12, 9], [6, 4, 9, 6]],
83+
[[3, 2, 6, 4], [2, 1, 4, 2], [ 6, 4, 9, 6], [4, 2, 6, 3]],
84+
[[8, 6, 12, 9], [6, 4, 9, 6], [12, 9, 16, 12], [9, 6, 12, 8]],
85+
[[6, 4, 9, 6], [4, 2, 6, 3], [ 9, 6, 12, 8], [6, 3, 8, 4]]
86+
/* eslint-enable */
87+
]
88+
)
89+
})
90+
91+
it('should allow mixed-dimensional Kronecker products', function () {
92+
const b = [[[4, 3], [3, 2]], [[3, 2], [2, 1]]]
93+
const a = [1, 2]
94+
assert.deepStrictEqual(math.kron(a, b), math.kron([[a]], b))
95+
assert.deepStrictEqual(math.kron([a], b), math.kron([[a]], b))
96+
assert.deepStrictEqual(math.kron(b, a), math.kron(b, [[a]]))
97+
assert.deepStrictEqual(math.kron(b, [a]), math.kron(b, [[a]]))
6298
})
6399

64100
it('should throw an error if called with an invalid number of arguments', function () {
@@ -81,10 +117,12 @@ describe('kron', function () {
81117
assert.deepStrictEqual(product.toArray(), [[13, 26, 0, 0], [715, -13, 0, -0], [0, 0, -1, -2], [0, -0, -55, 1]])
82118
})
83119

84-
it('should throw an error for invalid Kronecker product of matrix', function () {
85-
const y = math.matrix([[[]]])
120+
it('should calculate the Kronecker product of 3d matrices', function () {
121+
const y = math.matrix([[[3]]])
86122
const x = math.matrix([[[1, 1], [1, 1]], [[1, 1], [1, 1]]])
87-
assert.throws(function () { math.kron(y, x) })
123+
const product = math.kron(x, y)
124+
assert.deepStrictEqual(
125+
product.toArray(), [[[3, 3], [3, 3]], [[3, 3], [3, 3]]])
88126
})
89127
})
90128

0 commit comments

Comments
 (0)