Skip to content

Commit fc30fb4

Browse files
feat(cross entropy): introduce (#14)
1 parent 4daaca5 commit fc30fb4

File tree

8 files changed

+159
-10
lines changed

8 files changed

+159
-10
lines changed

examples/cifar-10.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import { readFileSync } from "node:fs";
2+
23
import { Model } from "../src/core/mod.ts";
34
import { Dense, ReLU, Softmax } from "../src/layers/mod.ts";
45
import { Adam } from "../src/optimizes/mod.ts";
5-
import { MeanSquaredError } from "../src/losses/mod.ts";
6+
import { CrossEntropyLoss } from "../src/losses/mod.ts";
67

78
const CIFAR_IMAGE_HEIGHT = 32;
89
const CIFAR_IMAGE_WIDTH = 32;
@@ -200,7 +201,7 @@ model.addLayer(new Softmax());
200201
// 3. Compile the Model
201202
model.compile(
202203
new Adam(0.001), // Adam optimizer
203-
new MeanSquaredError(), // Using MSE as it's available, though CrossEntropy is typical for classification
204+
new CrossEntropyLoss(), // Cross-entropy loss for multi-class classification
204205
["accuracy"], // Metric
205206
);
206207

examples/cifar-100.ts

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import { readFileSync } from "node:fs";
2+
23
import { Model } from "../src/core/mod.ts";
34
import { Dense, ReLU, Softmax } from "../src/layers/mod.ts";
45
import { Adam } from "../src/optimizes/mod.ts";
5-
import { MeanSquaredError } from "../src/losses/mod.ts";
6+
import { CrossEntropyLoss } from "../src/losses/mod.ts";
67

78
const CIFAR100_IMAGE_HEIGHT = 32;
89
const CIFAR100_IMAGE_WIDTH = 32;
@@ -199,11 +200,7 @@ model.addLayer(new Dense(128, CIFAR100_NUM_FINE_CLASSES)); // Hidden layer to ou
199200
model.addLayer(new Softmax());
200201

201202
// 3. Compile the Model
202-
model.compile(
203-
new Adam(0.001),
204-
new MeanSquaredError(), // CrossEntropyLoss would be more appropriate for multi-class
205-
["accuracy"],
206-
);
203+
model.compile(new Adam(0.001), new CrossEntropyLoss(), ["accuracy"]);
207204

208205
// 4. Train the Model
209206
console.log("Starting model training...");

examples/mnist.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { readFileSync } from "node:fs";
22
import { Model } from "../src/core/mod.ts";
33
import { Dense, ReLU, Softmax } from "../src/layers/mod.ts";
44
import { Adam } from "../src/optimizes/mod.ts";
5-
import { MeanSquaredError } from "../src/losses/mod.ts";
5+
import { CrossEntropyLoss } from "../src/losses/mod.ts";
66

77
const MNIST_IMAGE_MAGIC_NUMBER = 2051;
88
const MNIST_LABEL_MAGIC_NUMBER = 2049;
@@ -174,7 +174,7 @@ model.addLayer(new Softmax()); // Softmax for multi-class probability output
174174
// 3. Compile the Model
175175
model.compile(
176176
new Adam(0.001), // Adam optimizer
177-
new MeanSquaredError(), // Using MSE as it's available. CrossEntropyLoss is often preferred for classification.
177+
new CrossEntropyLoss(), // Cross-entropy loss for multi-class classification
178178
["accuracy"], // Metric
179179
);
180180

src/losses/binary_cross_entropy.test.ts

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import assert from "node:assert/strict";
2+
import { describe, it } from "node:test";
3+
import { BinaryCrossEntropyLoss } from "./binary_cross_entropy.ts";
4+
5+
describe("BinaryCrossEntropyLoss", () => {
6+
it("should calculate the binary cross-entropy loss correctly", () => {
7+
const predictions = [0.9, 0.2, 0.8];
8+
const targets = [1, 0, 1];
9+
const binaryCrossEntropy = new BinaryCrossEntropyLoss();
10+
const loss = binaryCrossEntropy.calculate(predictions, targets);
11+
assert.strictEqual(loss, 0.18388253942874858);
12+
});
13+
14+
it("should throw an error for different length arrays", () => {
15+
const predictions = [0.9, 0.2];
16+
const targets = [1, 0, 1];
17+
const binaryCrossEntropy = new BinaryCrossEntropyLoss();
18+
assert.throws(
19+
() => {
20+
binaryCrossEntropy.calculate(predictions, targets);
21+
},
22+
{
23+
message: "Predictions and targets must have the same length.",
24+
},
25+
);
26+
});
27+
28+
it("should return 0 for empty arrays", () => {
29+
const predictions: number[] = [];
30+
const targets: number[] = [];
31+
const binaryCrossEntropy = new BinaryCrossEntropyLoss();
32+
const loss = binaryCrossEntropy.calculate(predictions, targets);
33+
assert.strictEqual(loss, 0);
34+
});
35+
});

src/losses/binary_cross_entropy.ts

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/**
2+
* BinaryCrossEntropyLoss calculates the binary cross-entropy loss between predictions and target values.
3+
* This loss is commonly used for binary classification tasks.
4+
*
5+
* Formula:
6+
* `L = -Σ(y_true * log(y_pred) + (1 - y_true) * log(1 - y_pred))`
7+
*
8+
* @example
9+
* ```typescript
10+
* const binaryCrossEntropy = new BinaryCrossEntropyLoss();
11+
* const predictions = [0.9, 0.2, 0.8];
12+
* const targets = [1, 0, 1];
13+
* const loss = binaryCrossEntropy.calculate(predictions, targets);
14+
* console.log("Binary CrossEntropy Loss:", loss); // Output: ~0.1839
15+
* ```
16+
*/
17+
export class BinaryCrossEntropyLoss {
18+
/**
19+
* Calculates the binary cross-entropy loss.
20+
* @param predictions An array of predicted probabilities (values between 0 and 1).
21+
* @param targets An array of binary target values (0 or 1).
22+
* @returns The calculated binary cross-entropy loss.
23+
* @throws Error if the predictions and targets arrays do not have the same length.
24+
*/
25+
calculate(predictions: number[], targets: number[]): number {
26+
if (predictions.length !== targets.length) {
27+
throw new Error("Predictions and targets must have the same length.");
28+
}
29+
30+
if (predictions.length === 0 || targets.length === 0) {
31+
return 0; // Return 0 for empty arrays
32+
}
33+
34+
const epsilon = 1e-12; // To avoid log(0)
35+
let loss = 0;
36+
37+
for (let i = 0; i < predictions.length; i++) {
38+
const yTrue = targets[i];
39+
const yPred = Math.min(Math.max(predictions[i], epsilon), 1 - epsilon); // Clamp predictions to avoid log(0)
40+
loss -= yTrue * Math.log(yPred) + (1 - yTrue) * Math.log(1 - yPred);
41+
}
42+
43+
return loss / predictions.length; // Average loss
44+
}
45+
}

src/losses/cross_entropy.test.ts

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import assert from "node:assert/strict";
2+
import { describe, it } from "node:test";
3+
import { CrossEntropyLoss } from "./cross_entropy.ts";
4+
5+
describe("CrossEntropyLoss", () => {
6+
it("should calculate the correct loss for given predictions and targets", () => {
7+
const crossEntropy = new CrossEntropyLoss();
8+
const predictions = [0.7, 0.2, 0.1];
9+
const targets = [1, 0, 0];
10+
const loss = crossEntropy.calculate(predictions, targets);
11+
assert.strictEqual(loss, 0.1188916479791013);
12+
});
13+
14+
it("should throw an error if predictions and targets have different lengths", () => {
15+
const crossEntropy = new CrossEntropyLoss();
16+
const predictions = [0.7, 0.2];
17+
const targets = [1, 0, 0];
18+
assert.throws(
19+
() => {
20+
crossEntropy.calculate(predictions, targets);
21+
},
22+
{
23+
message: "Predictions and targets must have the same length.",
24+
},
25+
);
26+
});
27+
});

src/losses/cross_entropy.ts

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/**
2+
* CrossEntropyLoss calculates the cross-entropy loss between predictions and target values.
3+
* This loss is commonly used for classification tasks.
4+
*
5+
* Formula:
6+
* `L = -Σ(y_true * log(y_pred))`
7+
*
8+
* @example
9+
* ```typescript
10+
* const crossEntropy = new CrossEntropyLoss();
11+
* const predictions = [0.7, 0.2, 0.1];
12+
* const targets = [1, 0, 0];
13+
* const loss = crossEntropy.calculate(predictions, targets);
14+
* console.log("CrossEntropy Loss:", loss); // Output: ~0.3567
15+
* ```
16+
*/
17+
export class CrossEntropyLoss {
18+
/**
19+
* Calculates the cross-entropy loss.
20+
* @param predictions An array of predicted probabilities (must sum to 1).
21+
* @param targets An array of one-hot encoded target values.
22+
* @returns The calculated cross-entropy loss, summed across all samples.
23+
* @throws Error if the predictions and targets arrays do not have the same length.
24+
*/
25+
calculate(predictions: number[], targets: number[]): number {
26+
if (predictions.length !== targets.length) {
27+
throw new Error("Predictions and targets must have the same length.");
28+
}
29+
30+
let loss = 0;
31+
for (let i = 0; i < predictions.length; i++) {
32+
if (targets[i] === 1) {
33+
// Avoid log(0) by adding a small epsilon
34+
const epsilon = 1e-12;
35+
loss -= Math.log(predictions[i] + epsilon);
36+
}
37+
}
38+
39+
// Normalize the loss by the number of samples
40+
return loss / predictions.length;
41+
}
42+
}

src/losses/mod.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
export * from "./mse.ts";
2+
export * from "./cross_entropy.ts";
3+
export * from "./binary_cross_entropy.ts";

0 commit comments

Comments
 (0)