Skip to content

Commit 07b91b4

Browse files
committed
Add Simple Linear Regression Algorithm (TheAlgorithms#538)
1 parent 71a7733 commit 07b91b4

File tree

5 files changed

+191
-6
lines changed

5 files changed

+191
-6
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
using Algorithms.MachineLearning;
2+
3+
namespace Algorithms.Tests.MachineLearning;
4+
5+
/// <summary>
6+
/// Unit tests for the LinearRegression class.
7+
/// </summary>
8+
public class LinearRegressionTests
9+
{
10+
[Test]
11+
public void Fit_ThrowsException_WhenInputIsNull()
12+
{
13+
var lr = new LinearRegression();
14+
Assert.Throws<ArgumentException>(() => lr.Fit(null!, new List<double> { 1 }));
15+
Assert.Throws<ArgumentException>(() => lr.Fit(new List<double> { 1 }, null!));
16+
}
17+
18+
[Test]
19+
public void Fit_ThrowsException_WhenInputIsEmpty()
20+
{
21+
var lr = new LinearRegression();
22+
Assert.Throws<ArgumentException>(() => lr.Fit(new List<double>(), new List<double>()));
23+
}
24+
25+
[Test]
26+
public void Fit_ThrowsException_WhenInputLengthsDiffer()
27+
{
28+
var lr = new LinearRegression();
29+
Assert.Throws<ArgumentException>(() => lr.Fit(new List<double> { 1 }, new List<double> { 2, 3 }));
30+
}
31+
32+
[Test]
33+
public void Fit_ThrowsException_WhenXVarianceIsZero()
34+
{
35+
var lr = new LinearRegression();
36+
Assert.Throws<ArgumentException>(() => lr.Fit(new List<double> { 1, 1, 1 }, new List<double> { 2, 3, 4 }));
37+
}
38+
39+
[Test]
40+
public void Predict_ThrowsException_IfNotFitted()
41+
{
42+
var lr = new LinearRegression();
43+
Assert.Throws<InvalidOperationException>(() => lr.Predict(1.0));
44+
Assert.Throws<InvalidOperationException>(() => lr.Predict(new List<double> { 1.0 }));
45+
}
46+
47+
[Test]
48+
public void FitAndPredict_WorksForSimpleData()
49+
{
50+
// y = 2x + 1
51+
var x = new List<double> { 1, 2, 3, 4 };
52+
var y = new List<double> { 3, 5, 7, 9 };
53+
var lr = new LinearRegression();
54+
lr.Fit(x, y);
55+
Assert.That(lr.IsFitted, Is.True);
56+
Assert.That(lr.Intercept, Is.EqualTo(1.0).Within(1e-6));
57+
Assert.That(lr.Slope, Is.EqualTo(2.0).Within(1e-6));
58+
Assert.That(lr.Predict(5), Is.EqualTo(11.0).Within(1e-6));
59+
}
60+
61+
[Test]
62+
public void FitAndPredict_WorksForNegativeSlope()
63+
{
64+
// y = -3x + 4
65+
var x = new List<double> { 0, 1, 2 };
66+
var y = new List<double> { 4, 1, -2 };
67+
var lr = new LinearRegression();
68+
lr.Fit(x, y);
69+
Assert.That(lr.Intercept, Is.EqualTo(4.0).Within(1e-6));
70+
Assert.That(lr.Slope, Is.EqualTo(-3.0).Within(1e-6));
71+
Assert.That(lr.Predict(3), Is.EqualTo(-5.0).Within(1e-6));
72+
}
73+
74+
[Test]
75+
public void Predict_List_WorksCorrectly()
76+
{
77+
var x = new List<double> { 1, 2, 3 };
78+
var y = new List<double> { 2, 4, 6 };
79+
var lr = new LinearRegression();
80+
lr.Fit(x, y); // y = 2x
81+
var predictions = lr.Predict(new List<double> { 4, 5 });
82+
Assert.That(predictions[0], Is.EqualTo(8.0).Within(1e-6));
83+
Assert.That(predictions[1], Is.EqualTo(10.0).Within(1e-6));
84+
}
85+
}
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
5+
namespace Algorithms.MachineLearning;
6+
7+
/// <summary>
8+
/// Implements simple linear regression for one independent variable (univariate).
9+
/// Linear regression is a supervised learning algorithm used to model the relationship
10+
/// between a scalar dependent variable (Y) and an independent variable (X).
11+
/// The model fits a line: Y = a + bX, where 'a' is the intercept and 'b' is the slope.
12+
/// </summary>
13+
public class LinearRegression
14+
{
15+
// Intercept (a) and slope (b) of the fitted line
16+
public double Intercept { get; private set; }
17+
18+
public double Slope { get; private set; }
19+
20+
public bool IsFitted { get; private set; }
21+
22+
/// <summary>
23+
/// Fits the linear regression model to the provided data.
24+
/// </summary>
25+
/// <param name="x">List of independent variable values.</param>
26+
/// <param name="y">List of dependent variable values.</param>
27+
/// <exception cref="ArgumentException">Thrown if input lists are null, empty, or of different lengths.</exception>
28+
public void Fit(IList<double> x, IList<double> y)
29+
{
30+
if (x == null || y == null)
31+
{
32+
throw new ArgumentException("Input data cannot be null.");
33+
}
34+
35+
if (x.Count == 0 || y.Count == 0)
36+
{
37+
throw new ArgumentException("Input data cannot be empty.");
38+
}
39+
40+
if (x.Count != y.Count)
41+
{
42+
throw new ArgumentException("Input lists must have the same length.");
43+
}
44+
45+
// Calculate means
46+
double xMean = x.Average();
47+
double yMean = y.Average();
48+
49+
// Calculate slope (b) and intercept (a)
50+
double numerator = 0.0;
51+
double denominator = 0.0;
52+
for (int i = 0; i < x.Count; i++)
53+
{
54+
numerator += (x[i] - xMean) * (y[i] - yMean);
55+
denominator += (x[i] - xMean) * (x[i] - xMean);
56+
}
57+
58+
const double epsilon = 1e-12;
59+
if (Math.Abs(denominator) < epsilon)
60+
{
61+
throw new ArgumentException("Variance of X must not be zero.");
62+
}
63+
64+
Slope = numerator / denominator;
65+
Intercept = yMean - Slope * xMean;
66+
IsFitted = true;
67+
}
68+
69+
/// <summary>
70+
/// Predicts the output value for a given input using the fitted model.
71+
/// </summary>
72+
/// <param name="x">Input value.</param>
73+
/// <returns>Predicted output value.</returns>
74+
/// <exception cref="InvalidOperationException">Thrown if the model is not fitted.</exception>
75+
public double Predict(double x)
76+
{
77+
if (!IsFitted)
78+
{
79+
throw new InvalidOperationException("Model must be fitted before prediction.");
80+
}
81+
82+
return Intercept + Slope * x;
83+
}
84+
85+
/// <summary>
86+
/// Predicts output values for a list of inputs using the fitted model.
87+
/// </summary>
88+
/// <param name="xValues">List of input values.</param>
89+
/// <returns>List of predicted output values.</returns>
90+
/// <exception cref="InvalidOperationException">Thrown if the model is not fitted.</exception>
91+
public IList<double> Predict(IList<double> xValues)
92+
{
93+
if (!IsFitted)
94+
{
95+
throw new InvalidOperationException("Model must be fitted before prediction.");
96+
}
97+
98+
return xValues.Select(Predict).ToList();
99+
}
100+
}

Algorithms/Problems/JobScheduling/IntervalSchedulingSolver.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
15
namespace Algorithms.Problems.JobScheduling;
26

37
/// <summary>

Algorithms/Problems/JobScheduling/Job.cs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
using System;
2-
using System.Collections.Generic;
3-
using System.Linq;
4-
51
namespace Algorithms.Problems.JobScheduling;
62

73
/// <summary>

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ find more than one implementation for the same objective but using different alg
106106
* [SoftMax Function](./Algorithms/Numeric/SoftMax.cs)
107107
* [RecommenderSystem](./Algorithms/RecommenderSystem)
108108
* [CollaborativeFiltering](./Algorithms/RecommenderSystem/CollaborativeFiltering)
109+
* [Machine Learning](./Algorithms/MachineLearning)
110+
* [Linear Regression](./Algorithms/MachineLearning/LinearRegression.cs)
109111
* [Searches](./Algorithms/Search)
110112
* [A-Star](./Algorithms/Search/AStar/)
111113
* [Binary Search](./Algorithms/Search/BinarySearcher.cs)
@@ -247,8 +249,6 @@ find more than one implementation for the same objective but using different alg
247249
* [Levenshtein Distance](./Algorithms/Problems/DynamicProgramming/LevenshteinDistance/LevenshteinDistance.cs)
248250
* [Job Scheduling](./Algorithms/Problems/JobScheduling)
249251
* [Interval Scheduling (Greedy)](./Algorithms/Problems/JobScheduling/IntervalSchedulingSolver.cs)
250-
* [Weighted Interval Scheduling (DP)](./Algorithms/Problems/JobScheduling/WeightedIntervalSchedulingSolver.cs)
251-
* **Applications:** Work schedule management, room booking, production optimization, and more.
252252

253253
* [Data Structures](./DataStructures)
254254
* [Bag](./DataStructures/Bag)

0 commit comments

Comments
 (0)