
Commit d390d0d

Add Logistic Regression Algorithm for Machine Learning
1 parent e2c20ed commit d390d0d

File tree: 3 files changed, +157 -0 lines changed

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
using NUnit.Framework;
using Algorithms.MachineLearning;
using System;

namespace Algorithms.Tests.MachineLearning;

[TestFixture]
public class LogisticRegressionTests
{
    [Test]
    public void Fit_ThrowsOnEmptyInput()
    {
        var model = new LogisticRegression();
        Assert.Throws<ArgumentException>(() => model.Fit(Array.Empty<double[]>(), Array.Empty<int>()));
    }

    [Test]
    public void Fit_ThrowsOnMismatchedLabels()
    {
        var model = new LogisticRegression();
        double[][] X = { new double[] { 1, 2 } };
        int[] y = { 1, 0 };
        Assert.Throws<ArgumentException>(() => model.Fit(X, y));
    }

    [Test]
    public void FitAndPredict_WorksOnSimpleData()
    {
        // Simple AND logic
        double[][] X =
        {
            new[] { 0.0, 0.0 },
            new[] { 0.0, 1.0 },
            new[] { 1.0, 0.0 },
            new[] { 1.0, 1.0 }
        };
        int[] y = { 0, 0, 0, 1 };
        var model = new LogisticRegression();
        model.Fit(X, y, epochs: 2000, learningRate: 0.1);
        Assert.That(model.Predict(new double[] { 0, 0 }), Is.EqualTo(0));
        Assert.That(model.Predict(new double[] { 0, 1 }), Is.EqualTo(0));
        Assert.That(model.Predict(new double[] { 1, 0 }), Is.EqualTo(0));
        Assert.That(model.Predict(new double[] { 1, 1 }), Is.EqualTo(1));
    }

    [Test]
    public void PredictProbability_ThrowsOnFeatureMismatch()
    {
        var model = new LogisticRegression();
        double[][] X = { new double[] { 1, 2 } };
        int[] y = { 1 };
        model.Fit(X, y);
        Assert.Throws<ArgumentException>(() => model.PredictProbability(new double[] { 1 }));
    }

    [Test]
    public void FeatureCount_ReturnsCorrectValue()
    {
        var model = new LogisticRegression();
        double[][] X = { new double[] { 1, 2, 3 } };
        int[] y = { 1 };
        model.Fit(X, y);
        Assert.That(model.FeatureCount, Is.EqualTo(3));
    }
}
Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
using System;
using System.Linq;

namespace Algorithms.MachineLearning;

/// <summary>
/// Logistic Regression for binary classification.
/// </summary>
public class LogisticRegression
{
    private double[] weights = [];
    private double bias = 0;

    public int FeatureCount => weights.Length;

    public LogisticRegression()
    {
    }

    /// <summary>
    /// Fit the model using gradient descent.
    /// </summary>
    /// <param name="x">2D array of features (samples x features).</param>
    /// <param name="y">Array of labels (0 or 1).</param>
    /// <param name="epochs">Number of iterations.</param>
    /// <param name="learningRate">Step size.</param>
    public void Fit(double[][] x, int[] y, int epochs = 1000, double learningRate = 0.01)
    {
        if (x.Length == 0 || x[0].Length == 0)
        {
            throw new ArgumentException("Input features cannot be empty.");
        }

        if (x.Length != y.Length)
        {
            throw new ArgumentException("Number of samples and labels must match.");
        }

        int nSamples = x.Length;
        int nFeatures = x[0].Length;
        weights = new double[nFeatures];
        bias = 0;

        for (int epoch = 0; epoch < epochs; epoch++)
        {
            double[] dw = new double[nFeatures];
            double db = 0;
            for (int i = 0; i < nSamples; i++)
            {
                double linear = Dot(x[i], weights) + bias;
                double pred = Sigmoid(linear);
                double error = pred - y[i];
                for (int j = 0; j < nFeatures; j++)
                {
                    dw[j] += error * x[i][j];
                }

                db += error;
            }

            for (int j = 0; j < nFeatures; j++)
            {
                weights[j] -= learningRate * dw[j] / nSamples;
            }

            bias -= learningRate * db / nSamples;
        }
    }

    /// <summary>
    /// Predict probability for a single sample.
    /// </summary>
    public double PredictProbability(double[] x)
    {
        if (x.Length != weights.Length)
        {
            throw new ArgumentException("Feature count mismatch.");
        }

        return Sigmoid(Dot(x, weights) + bias);
    }

    /// <summary>
    /// Predict class label (0 or 1) for a single sample.
    /// </summary>
    public int Predict(double[] x) => PredictProbability(x) >= 0.5 ? 1 : 0;

    private static double Sigmoid(double z) => 1.0 / (1.0 + Math.Exp(-z));

    private static double Dot(double[] a, double[] b) => a.Zip(b).Sum(pair => pair.First * pair.Second);
}
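
Not part of the commit, but as a minimal usage sketch: the snippet below trains the new class on the same AND-gate data used in the test suite and prints a probability and two class labels. The LogisticRegressionDemo wrapper is hypothetical; the epoch and learning-rate values simply mirror the test.

using System;
using Algorithms.MachineLearning;

public static class LogisticRegressionDemo
{
    public static void Main()
    {
        // AND-gate training data, matching FitAndPredict_WorksOnSimpleData.
        double[][] x =
        {
            new[] { 0.0, 0.0 },
            new[] { 0.0, 1.0 },
            new[] { 1.0, 0.0 },
            new[] { 1.0, 1.0 },
        };
        int[] y = { 0, 0, 0, 1 };

        var model = new LogisticRegression();
        model.Fit(x, y, epochs: 2000, learningRate: 0.1);

        // For (1, 1) the probability should land at or above the 0.5 threshold,
        // so Predict returns 1; the other inputs map to 0.
        Console.WriteLine(model.PredictProbability(new[] { 1.0, 1.0 }));
        Console.WriteLine(model.Predict(new[] { 1.0, 1.0 })); // 1
        Console.WriteLine(model.Predict(new[] { 0.0, 1.0 })); // 0
    }
}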

README.md

Lines changed: 1 addition & 0 deletions
@@ -108,6 +108,7 @@ find more than one implementation for the same objective but using different alg
  * [CollaborativeFiltering](./Algorithms/RecommenderSystem/CollaborativeFiltering)
  * [Machine Learning](./Algorithms/MachineLearning)
  * [Linear Regression](./Algorithms/MachineLearning/LinearRegression.cs)
+ * [Logistic Regression](./Algorithms/MachineLearning/LogisticRegression.cs)
  * [Searches](./Algorithms/Search)
  * [A-Star](./Algorithms/Search/AStar/)
  * [Binary Search](./Algorithms/Search/BinarySearcher.cs)
