Skip to content

Commit 75f557b

Browse files
taha-yassinexenova
andauthored
Implement numerically stable log_softmax() (#812)
* Implement numerically stable log_softmax() * Add unit tests * Update src/utils/maths.js --------- Co-authored-by: Joshua Lochner <[email protected]>
1 parent fc34517 commit 75f557b

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

src/utils/maths.js

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,11 +158,20 @@ export function softmax(arr) {
158158
* @returns {T} The resulting log_softmax array.
159159
*/
160160
export function log_softmax(arr) {
161-
// Compute the softmax values
162-
const softmaxArr = softmax(arr);
161+
// Compute the maximum value in the array
162+
const maxVal = max(arr)[0];
163+
164+
// Compute the sum of the exponentials
165+
let sumExps = 0;
166+
for(let i = 0; i < arr.length; ++i) {
167+
sumExps += Math.exp(arr[i] - maxVal);
168+
}
163169

164-
// Apply log formula to each element
165-
const logSoftmaxArr = softmaxArr.map(x => Math.log(x));
170+
// Compute the log of the sum
171+
const logSum = Math.log(sumExps);
172+
173+
// Compute the softmax values
174+
const logSoftmaxArr = arr.map(x => x - maxVal - logSum);
166175

167176
return /** @type {T} */(logSoftmaxArr);
168177
}

tests/maths.test.js

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import { compare } from './test_utils.js';
33

44
import { getFile } from '../src/utils/hub.js';
5-
import { FFT, medianFilter, bankers_round } from '../src/utils/maths.js';
5+
import { FFT, medianFilter, bankers_round, log_softmax } from '../src/utils/maths.js';
66

77

88
const fft = (arr, complex = false) => {
@@ -136,4 +136,21 @@ describe('Mathematical operations', () => {
136136
});
137137
}
138138
});
139+
140+
describe('log softmax', () => {
141+
// Should match output of scipy log_softmax
142+
it('should compute log softmax correctly for usual values', () => {
143+
const input = [0, 1, 2, 3];
144+
const expected = [-3.4401896985611953, -2.4401896985611953, -1.4401896985611953, -0.44018969856119533];
145+
const output = log_softmax(input);
146+
compare(output, expected, 1e-13);
147+
});
148+
149+
it('should compute log softmax correctly for values with large differences', () => {
150+
const input = [1000, 1];
151+
const expected = [0, -999];
152+
const output = log_softmax(input);
153+
compare(output, expected, 1e-13);
154+
});
155+
});
139156
});

0 commit comments

Comments
 (0)