@@ -116,15 +116,12 @@ class NCRModuloP {
116
116
return 1 ;
117
117
}
118
118
// fac is a global array with fac[r] = (r! % p)
119
- int64_t denominator = utils::modInverse (fac[r], p);
120
- if (denominator < 0 ) { // modular inverse doesn't exist
119
+ const auto denominator = (fac[r] * fac[n - r]) % p;
120
+ const auto denominator_inv = utils::modInverse (denominator, p);
121
+ if (denominator_inv < 0 ) { // modular inverse doesn't exist
121
122
return -1 ;
122
123
}
123
- denominator = (denominator * utils::modInverse (fac[n - r], p)) % p;
124
- if (denominator < 0 ) { // modular inverse doesn't exist
125
- return -1 ;
126
- }
127
- return (fac[n] * denominator) % p;
124
+ return (fac[n] * denominator_inv) % p;
128
125
}
129
126
};
130
127
} // namespace ncr_modulo_p
@@ -156,7 +153,8 @@ static void tests() {
156
153
TestCase (20 , 17 , 1 , 10 , 0 ),
157
154
TestCase (45 , 19 , 23 , 1 , 23 % 19 ),
158
155
TestCase (45 , 19 , 23 , 0 , 1 ),
159
- TestCase (45 , 19 , 23 , 23 , 1 )};
156
+ TestCase (45 , 19 , 23 , 23 , 1 ),
157
+ TestCase (20 , 9 , 10 , 2 , -1 )};
160
158
for (const auto & tc : test_cases) {
161
159
assert (math::ncr_modulo_p::NCRModuloP (tc.size , tc.p ).ncr (tc.n , tc.r ) ==
162
160
tc.expected );
0 commit comments