Skip to content

Commit 59e3943

Browse files
Add more tests
1 parent a4634b8 commit 59e3943

File tree

1 file changed

+32
-10
lines changed

1 file changed

+32
-10
lines changed

probability/exponential_dist.cpp

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,19 @@
1212
* \f$\lambda\f$ : rate parameter
1313
*/
1414

15-
#include <cassert> // For assert
16-
#include <cmath> // For std::pow
17-
#include <iostream> // For I/O operation
15+
#include <cassert> // For assert
16+
#include <cmath> // For std::pow
17+
#include <iostream> // For I/O operation
18+
#include <stdexcept> // For std::invalid_argument
19+
#include <string> // For std::string
1820

1921
/**
2022
* @brief the expected value of the exponential distribution
2123
* @returns \f[\mu = \frac{1}{\lambda}\f]
2224
*/
2325
double exponential_expected(double lambda) {
2426
if (lambda <= 0) {
25-
std::cout << "Error: Lambda must be greater than 0." << '\n';
26-
assert(lambda > 0);
27+
throw std::invalid_argument("lambda must be greater than 0");
2728
}
2829
return 1 / lambda;
2930
}
@@ -34,8 +35,7 @@ double exponential_expected(double lambda) {
3435
*/
3536
double exponential_var(double lambda) {
3637
if (lambda <= 0) {
37-
std::cout << "Error: Lambda must be greater than 0." << '\n';
38-
assert(lambda > 0);
38+
throw std::invalid_argument("lambda must be greater than 0");
3939
}
4040
return 1 / pow(lambda, 2);
4141
}
@@ -46,8 +46,7 @@ double exponential_var(double lambda) {
4646
*/
4747
double exponential_std(double lambda) {
4848
if (lambda <= 0) {
49-
std::cout << "Error: Lambda must be greater than 0." << '\n';
50-
assert(lambda > 0);
49+
throw std::invalid_argument("lambda must be greater than 0");
5150
}
5251
return 1 / lambda;
5352
}
@@ -72,6 +71,9 @@ static void test() {
7271
double var_3 = 0.111111;
7372
double std_3 = 0.333333;
7473

74+
double lambda_4 = 0; // Test 0
75+
double lambda_5 = -2.3; // Test negative value
76+
7577
const float threshold = 1e-3f;
7678

7779
std::cout << "Test for lambda = 1 \n";
@@ -90,7 +92,27 @@ static void test() {
9092
assert(std::abs(expected_3 - exponential_expected(lambda_3)) < threshold);
9193
assert(std::abs(var_3 - exponential_var(lambda_3)) < threshold);
9294
assert(std::abs(std_3 - exponential_std(lambda_3)) < threshold);
93-
std::cout << "ALL TEST PASSED\n";
95+
std::cout << "ALL TEST PASSED\n\n";
96+
97+
std::cout << "Test for lambda = 0 \n";
98+
try {
99+
exponential_expected(lambda_4);
100+
exponential_var(lambda_4);
101+
exponential_std(lambda_4);
102+
} catch (std::invalid_argument& err) {
103+
assert(std::string(err.what()) == "lambda must be greater than 0");
104+
}
105+
std::cout << "ALL TEST PASSED\n\n";
106+
107+
std::cout << "Test for lambda = -2.3 \n";
108+
try {
109+
exponential_expected(lambda_5);
110+
exponential_var(lambda_5);
111+
exponential_std(lambda_5);
112+
} catch (std::invalid_argument& err) {
113+
assert(std::string(err.what()) == "lambda must be greater than 0");
114+
}
115+
std::cout << "ALL TEST PASSED\n\n";
94116
}
95117

96118
/**

0 commit comments

Comments
 (0)