Skip to content

Commit 1b9fe0f

Browse files
committed
Replace Math.max with iterative version in softmax function (#83)
TODO: In future, replace all min/max/indexOfMin/indexOfMax functions with functions that return index and value of the minimum/maximum at the same time (like PyTorch).
1 parent ee1056b commit 1b9fe0f

File tree

1 file changed

+24
-5
lines changed

1 file changed

+24
-5
lines changed

src/utils.js

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ const { env } = require('./env.js');
66
if (global.ReadableStream === undefined && typeof process !== 'undefined') {
77
try {
88
// @ts-ignore
9-
global.ReadableStream = require('node:stream/web').ReadableStream; // ReadableStream is not a global with Node 16
10-
} catch(err) {
9+
global.ReadableStream = require('node:stream/web').ReadableStream; // ReadableStream is not a global with Node 16
10+
} catch (err) {
1111
console.warn("ReadableStream not defined and unable to import from node:stream/web");
1212
}
1313
}
@@ -429,10 +429,10 @@ function indexOfMax(arr) {
429429
*/
430430
function softmax(arr) {
431431
// Compute the maximum value in the array
432-
const max = Math.max(...arr);
432+
const maxVal = max(arr);
433433

434434
// Compute the exponentials of the array values
435-
const exps = arr.map(x => Math.exp(x - max));
435+
const exps = arr.map(x => Math.exp(x - maxVal));
436436

437437
// Compute the sum of the exponentials
438438
const sumExps = exps.reduce((acc, val) => acc + val, 0);
@@ -584,6 +584,24 @@ function min(arr) {
584584
return min;
585585
}
586586

587+
588+
/**
589+
* Returns the maximum item.
590+
* @param {number[]} arr - array of numbers.
591+
* @returns {number} - the maximum number.
592+
* @throws {Error} If array is empty.
593+
*/
594+
function max(arr) {
595+
if (arr.length === 0) throw Error('Array must not be empty');
596+
let max = arr[0];
597+
for (let i = 1; i < arr.length; ++i) {
598+
if (arr[i] > max) {
599+
max = arr[i];
600+
}
601+
}
602+
return max;
603+
}
604+
587605
/**
588606
* Check if a value is a string.
589607
* @param {*} text - The value to check.
@@ -630,5 +648,6 @@ module.exports = {
630648
isIntegralNumber,
631649
isString,
632650
exists,
633-
min
651+
min,
652+
max,
634653
};

0 commit comments

Comments
 (0)