1- import { deepForEach } from '../../utils/collection.js'
21import { factory } from '../../utils/factory.js'
2+ import { isArray , isNumber } from '../../utils/is.js'
33import { safeNumberType } from '../../utils/number.js'
44import { improveErrorMessage } from './utils/improveErrorMessage.js'
55
66const name = 'prod'
77const dependencies = [
8- 'typed' , 'config' , 'multiplyScalar' , 'numeric'
8+ 'typed' , 'config' , 'multiplyScalar' , 'number' , 'numeric' , '?Index' , 'Range' ,
9+ 'squeeze' , 'size' , 'subset' , 'dotMultiply'
910]
1011
12+ const THRESHOLD = 16 // where to stop splitting and switch to direct multiply
13+
1114export const createProd = /* #__PURE__ */ factory ( name , dependencies , ( {
12- typed, config, multiplyScalar, numeric
15+ typed, config, multiplyScalar, number, numeric, Index, Range,
16+ squeeze, size, subset, dotMultiply
1317} ) => {
1418 /**
1519 * Compute the product of a matrix or a list with values.
@@ -42,11 +46,7 @@ export const createProd = /* #__PURE__ */ factory(name, dependencies, ({
4246 'Array | Matrix' : _prod ,
4347
4448 // prod([a, b, c, d, ...], dim)
45- 'Array | Matrix, number | BigNumber' : function ( array , dim ) {
46- // TODO: implement prod(A, dim)
47- throw new Error ( 'prod(A, dim) is not yet supported' )
48- // return reduce(arguments[0], arguments[1], math.prod)
49- } ,
49+ 'Array | Matrix, number | BigNumber' : _prodAlongDim ,
5050
5151 // prod(a, b, c, d, ...)
5252 '...' : function ( args ) {
@@ -56,27 +56,94 @@ export const createProd = /* #__PURE__ */ factory(name, dependencies, ({
5656
5757 /**
5858 * Recursively calculate the product of an n-dimensional array
59- * @param {Array } array
60- * @return {number } prod
59+ * @param {Array | Matrix } collection
60+ * @return {scalar } prod
6161 * @private
6262 */
63- function _prod ( array ) {
63+ function _prod ( collection ) {
64+ let sz = size ( collection )
65+ if ( sz . length === 0 || sz . some ( dim => dim === 0 ) ) return 1
6466 let prod
67+ try {
68+ if ( sz . every ( dim => dim === 1 ) ) prod = squeeze ( collection )
69+ else {
70+ if ( sz . length > 1 ) { // reduce to 1d
71+ const newColl = [ ]
72+ for ( let pos = 0 ; pos < sz [ 0 ] ; ++ pos ) {
73+ newColl . push ( _prod ( subset ( collection , pos ) ) )
74+ }
75+ collection = newColl
76+ sz = [ sz [ 0 ] ]
77+ }
78+ if ( ! Index ) collection = collection . valueOf ( )
79+ if ( Array . isArray ( collection ) ) {
80+ prod = _prodArray ( collection , 0 , sz [ 0 ] - 1 )
81+ } else {
82+ let op = multiplyScalar
83+ const dt = collection . datatype ( )
84+ if ( dt ) op = typed . find ( op , [ dt , dt ] )
85+ prod = _prodVector ( collection , 0 , sz [ 0 ] - 1 , op )
86+ }
87+ }
6588
66- deepForEach ( array , function ( value ) {
67- try {
68- prod = ( prod === undefined ) ? value : multiplyScalar ( prod , value )
69- } catch ( err ) {
70- throw improveErrorMessage ( err , 'prod' , value )
89+ if ( typeof prod === 'string' ) {
90+ prod = numeric ( prod , safeNumberType ( prod , config ) )
7191 }
72- } )
92+ } catch ( err ) {
93+ throw improveErrorMessage ( err , 'prod' , collection )
94+ }
95+ return prod
96+ }
7397
74- // make sure returning numeric value: parse a string into a numeric value
75- if ( typeof prod === 'string' ) {
76- prod = numeric ( prod , safeNumberType ( prod , config ) )
98+ /* Product of a 1d array arr from index first to index last, inclusive. */
99+ function _prodArray ( arr , first , last ) {
100+ if ( last - first < THRESHOLD ) {
101+ let prod = arr [ first ]
102+ for ( let idx = first + 1 ; idx <= last ; ++ idx ) {
103+ prod = multiplyScalar ( prod , arr [ idx ] )
104+ }
105+ return prod
77106 }
107+ const cutoff = Math . floor ( ( first + last ) / 2 )
108+ return multiplyScalar (
109+ _prodArray ( arr , first , cutoff ) ,
110+ _prodArray ( arr , cutoff + 1 , last ) )
111+ }
78112
79- if ( prod === undefined ) return 1
80- return prod
113+ /* Product of a 1d vector v from position first to last, using op */
114+ function _prodVector ( v , first , last , op ) {
115+ if ( last - first < THRESHOLD ) {
116+ let prod = v . layer ( first )
117+ for ( let idx = first + 1 ; idx <= last ; ++ idx ) {
118+ prod = op ( prod , v . layer ( idx ) )
119+ }
120+ return prod
121+ }
122+ const cutoff = Math . floor ( ( first + last ) / 2 )
123+ return op (
124+ _prodVector ( v , first , cutoff , op ) , _prodVector ( v , cutoff + 1 , last , op ) )
125+ }
126+
127+ function _prodAlongDim ( collection , dim ) {
128+ if ( ! isNumber ( dim ) ) dim = number ( dim )
129+ const sz = size ( collection )
130+ if ( dim >= sz . length ) {
131+ throw new Error (
132+ `There is no dimension ${ dim } in collection of size ${ sz } .` )
133+ }
134+ if ( sz . length === 1 ) return _prod ( collection )
135+ if ( dim === 0 ) {
136+ let result = subset ( collection , 0 )
137+ for ( let i = 1 ; i < sz [ 0 ] ; ++ i ) {
138+ result = dotMultiply ( result , subset ( collection , i ) )
139+ }
140+ return result
141+ }
142+ const data = [ ]
143+ for ( let i = 0 ; i < sz [ 0 ] ; ++ i ) {
144+ data . push ( _prodAlongDim ( subset ( collection , i ) , dim - 1 ) . valueOf ( ) )
145+ }
146+ if ( isArray ( collection ) ) return data
147+ return collection . create ( data )
81148 }
82149} )
0 commit comments