Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions Algorithms.Tests/MachineLearning/LogisticRegressionTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
using NUnit.Framework;
using Algorithms.MachineLearning;
using System;

namespace Algorithms.Tests.MachineLearning;

[TestFixture]
public class LogisticRegressionTests
{
    [Test]
    public void Fit_ThrowsOnEmptyInput()
    {
        // Training with zero samples must be rejected outright.
        var model = new LogisticRegression();

        Assert.That(
            () => model.Fit(Array.Empty<double[]>(), Array.Empty<int>()),
            Throws.ArgumentException);
    }

    [Test]
    public void Fit_ThrowsOnMismatchedLabels()
    {
        // One sample but two labels: the sizes disagree.
        var model = new LogisticRegression();
        double[][] features = { new double[] { 1, 2 } };
        int[] labels = { 1, 0 };

        Assert.That(() => model.Fit(features, labels), Throws.ArgumentException);
    }

    [Test]
    public void FitAndPredict_WorksOnSimpleData()
    {
        // The AND truth table is linearly separable, so the model
        // should classify every row correctly after training.
        double[][] samples =
        {
            new[] { 0.0, 0.0 },
            new[] { 0.0, 1.0 },
            new[] { 1.0, 0.0 },
            new[] { 1.0, 1.0 },
        };
        int[] labels = { 0, 0, 0, 1 };

        var model = new LogisticRegression();
        model.Fit(samples, labels, epochs: 2000, learningRate: 0.1);

        for (var i = 0; i < samples.Length; i++)
        {
            Assert.That(model.Predict(samples[i]), Is.EqualTo(labels[i]));
        }
    }

    [Test]
    public void PredictProbability_ThrowsOnFeatureMismatch()
    {
        // Trained on two features; querying with one must fail.
        var model = new LogisticRegression();
        double[][] features = { new double[] { 1, 2 } };
        int[] labels = { 1 };
        model.Fit(features, labels);

        Assert.That(() => model.PredictProbability(new double[] { 1 }), Throws.ArgumentException);
    }

    [Test]
    public void FeatureCount_ReturnsCorrectValue()
    {
        // FeatureCount reflects the width of the training data.
        var model = new LogisticRegression();
        double[][] features = { new double[] { 1, 2, 3 } };
        int[] labels = { 1 };
        model.Fit(features, labels);

        Assert.That(model.FeatureCount, Is.EqualTo(3));
    }
}
87 changes: 87 additions & 0 deletions Algorithms/MachineLearning/LogisticRegression.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
using System;
using System.Linq;

namespace Algorithms.MachineLearning;

/// <summary>
/// Logistic Regression for binary classification.
/// Learns a weight per feature plus a bias via batch gradient descent on the
/// log-loss, then classifies with a 0.5 probability threshold.
/// </summary>
public class LogisticRegression
{
    // One weight per feature; empty until Fit has been called.
    private double[] weights = [];

    // Intercept term learned alongside the weights.
    private double bias;

    /// <summary>
    /// Gets the number of features the model was trained on (0 before training).
    /// </summary>
    public int FeatureCount => weights.Length;

    /// <summary>
    /// Fit the model using batch gradient descent.
    /// </summary>
    /// <param name="x">2D array of features (samples x features). All rows must have the same length.</param>
    /// <param name="y">Array of labels (0 or 1), one per sample.</param>
    /// <param name="epochs">Number of passes over the full data set.</param>
    /// <param name="learningRate">Step size for each gradient update.</param>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="x"/> or <paramref name="y"/> is null.</exception>
    /// <exception cref="ArgumentException">
    /// Thrown when the input is empty, when the number of samples and labels differ,
    /// or when the feature rows are jagged (unequal lengths).
    /// </exception>
    public void Fit(double[][] x, int[] y, int epochs = 1000, double learningRate = 0.01)
    {
        ArgumentNullException.ThrowIfNull(x);
        ArgumentNullException.ThrowIfNull(y);

        if (x.Length == 0 || x[0].Length == 0)
        {
            throw new ArgumentException("Input features cannot be empty.");
        }

        if (x.Length != y.Length)
        {
            throw new ArgumentException("Number of samples and labels must match.");
        }

        int nSamples = x.Length;
        int nFeatures = x[0].Length;

        // Reject jagged input up front: previously a short row either threw
        // IndexOutOfRangeException mid-training or was silently truncated by
        // the LINQ Zip in Dot, corrupting the fit without any error.
        for (int i = 0; i < nSamples; i++)
        {
            if (x[i].Length != nFeatures)
            {
                throw new ArgumentException($"All samples must have {nFeatures} features, but sample {i} has {x[i].Length}.");
            }
        }

        weights = new double[nFeatures];
        bias = 0;

        for (int epoch = 0; epoch < epochs; epoch++)
        {
            // Accumulate the log-loss gradient over the whole batch.
            double[] dw = new double[nFeatures];
            double db = 0;
            for (int i = 0; i < nSamples; i++)
            {
                double linear = Dot(x[i], weights) + bias;
                double pred = Sigmoid(linear);

                // For sigmoid + log-loss, dL/dz simplifies to (prediction - label).
                double error = pred - y[i];
                for (int j = 0; j < nFeatures; j++)
                {
                    dw[j] += error * x[i][j];
                }

                db += error;
            }

            // Average the gradient over the batch and take one descent step.
            for (int j = 0; j < nFeatures; j++)
            {
                weights[j] -= learningRate * dw[j] / nSamples;
            }

            bias -= learningRate * db / nSamples;
        }
    }

    /// <summary>
    /// Predict the probability that a single sample belongs to class 1.
    /// </summary>
    /// <param name="x">Feature vector; its length must equal <see cref="FeatureCount"/>.</param>
    /// <returns>A probability in (0, 1).</returns>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="x"/> is null.</exception>
    /// <exception cref="ArgumentException">Thrown when the feature count does not match the trained model.</exception>
    public double PredictProbability(double[] x)
    {
        ArgumentNullException.ThrowIfNull(x);

        if (x.Length != weights.Length)
        {
            throw new ArgumentException("Feature count mismatch.");
        }

        return Sigmoid(Dot(x, weights) + bias);
    }

    /// <summary>
    /// Predict the class label (0 or 1) for a single sample using a 0.5 threshold.
    /// </summary>
    public int Predict(double[] x) => PredictProbability(x) >= 0.5 ? 1 : 0;

    // Standard logistic function; maps any real z into (0, 1).
    private static double Sigmoid(double z) => 1.0 / (1.0 + Math.Exp(-z));

    // Plain loop instead of LINQ Zip: avoids per-call tuple overhead and,
    // unlike Zip, cannot silently truncate mismatched-length vectors.
    private static double Dot(double[] a, double[] b)
    {
        double sum = 0;
        for (int i = 0; i < a.Length; i++)
        {
            sum += a[i] * b[i];
        }

        return sum;
    }
}
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ find more than one implementation for the same objective but using different alg
* [CollaborativeFiltering](./Algorithms/RecommenderSystem/CollaborativeFiltering)
* [Machine Learning](./Algorithms/MachineLearning)
* [Linear Regression](./Algorithms/MachineLearning/LinearRegression.cs)
* [Logistic Regression](./Algorithms/MachineLearning/LogisticRegression.cs)
* [Searches](./Algorithms/Search)
* [A-Star](./Algorithms/Search/AStar/)
* [Binary Search](./Algorithms/Search/BinarySearcher.cs)
Expand Down