Skip to content

Commit 3de094a

Browse files
committed
re-export n_choose_k
1 parent 9f86916 commit 3de094a

File tree

4 files changed

+562
-0
lines changed

4 files changed

+562
-0
lines changed

ortools/algorithms/BUILD.bazel

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,3 +534,39 @@ cc_test(
534534
"//ortools/base:gmock_main",
535535
],
536536
)
537+
538+
# Library target built from n_choose_k.cc / n_choose_k.h.
cc_library(
    name = "n_choose_k",
    srcs = ["n_choose_k.cc"],
    hdrs = ["n_choose_k.h"],
    deps = [
        ":binary_search",
        "//ortools/base:mathutil",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/numeric:int128",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/time",
    ],
)
554+
555+
# Test target for :n_choose_k, built from n_choose_k_test.cc.
cc_test(
    name = "n_choose_k_test",
    srcs = ["n_choose_k_test.cc"],
    deps = [
        ":n_choose_k",
        "//ortools/base:dump_vars",
        "//ortools/base:fuzztest",
        "//ortools/base:gmock_main",
        "//ortools/base:mathutil",
        "//ortools/util:flat_matrix",
        "@com_google_absl//absl/numeric:int128",
        "@com_google_absl//absl/random",
        "@com_google_absl//absl/random:distributions",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_benchmark//:benchmark",
    ],
)

ortools/algorithms/n_choose_k.cc

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
// Copyright 2010-2024 Google LLC
2+
// Licensed under the Apache License, Version 2.0 (the "License");
3+
// you may not use this file except in compliance with the License.
4+
// You may obtain a copy of the License at
5+
//
6+
// http://www.apache.org/licenses/LICENSE-2.0
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
#include "ortools/algorithms/n_choose_k.h"
15+
16+
#include <cmath>
17+
#include <cstdint>
18+
#include <limits>
19+
#include <vector>
20+
21+
#include "absl/log/check.h"
22+
#include "absl/numeric/int128.h"
23+
#include "absl/status/status.h"
24+
#include "absl/status/statusor.h"
25+
#include "absl/strings/str_format.h"
26+
#include "absl/time/clock.h"
27+
#include "absl/time/time.h"
28+
#include "ortools/algorithms/binary_search.h"
29+
#include "ortools/base/logging.h"
30+
#include "ortools/base/mathutil.h"
31+
32+
namespace operations_research {
33+
namespace {
34+
// Core evaluation of the binomial coefficient (n choose k), in O(k) steps.
// Preconditions (DCHECKed): k > 0 and k <= n - k.
template <typename Int>
Int InternalChoose(Int n, Int k) {
  DCHECK_LE(k, n - k);
  // Since k > 0, the accumulator can be seeded with n and the loop can begin
  // directly at the second factor (small optimization).
  DCHECK_GT(k, 0);
  // The quantity n * (n-1) * ... * (n-k+1) / k! is accumulated one factor at a
  // time, alternating a multiplication and a division. This ordering keeps
  // every intermediate value an exact integer while limiting overflow risk.
  // It is not perfect: the final step divides by k, so we may overflow (by a
  // factor of up to k) even though the final result would fit.
  Int binomial = n;
  for (Int step = 2; step <= k; ++step) {
    // Exact division: a product of `step` consecutive integers is a multiple
    // of step!.
    binomial = binomial * (n + 1 - step) / step;
  }
  return binomial;
}
50+
51+
// This function precomputes the maximum N such that (N choose K) doesn't
52+
// overflow, for all K.
53+
// When `overflows_intermediate_computation` is true, "overflow" means
54+
// "some overflow happens inside InternalChoose<int64_t>()", and when it's false
55+
// it simply means "the result doesn't fit in an int64_t".
56+
// This is only used in contexts where K ≤ N-K, which implies N ≥ 2K, thus we
57+
// can stop when (2K Choose K) overflows, because at and beyond such K,
58+
// (N Choose K) will always overflow. In practice that happens for K=31 or 34
59+
// depending on `overflows_intermediate_computation`.
60+
template <class Int>
61+
std::vector<Int> LastNThatDoesNotOverflowForAllK(
62+
bool overflows_intermediate_computation) {
63+
absl::Time start_time = absl::Now();
64+
// Given the algorithm used in InternalChoose(), it's not hard to
65+
// find out when (N choose K) overflows an int64_t during its internal
66+
// computation: that's when (N choose K) > MAX_INT / k.
67+
68+
// For K ≤ 2, we hardcode the values of the maximum N. That's because
69+
// the binary search done below uses MathUtil::LogCombinations, which only
70+
// works on int32_t, and that's problematic for the max N we get for K=2.
71+
//
72+
// For K=2, we want N(N-1) ≤ 2^num_digits, or N(N-1)/2 ≤ 2^num_digits if
73+
// !overflows_intermediate_computation, i.e. N(N-1) ≤ 2^(num_digits+1).
74+
// Then, when d is even, N(N-1) ≤ 2^d ⇔ N ≤ 2^(d/2), which is simple.
75+
// When d is odd, it's harder: N(N-1)≈(N-0.5)² and thus we get the bound
76+
// N ≤ pow(2.0, d/2)+0.5.
77+
const int bound_digits = std::numeric_limits<Int>::digits +
78+
(overflows_intermediate_computation ? 0 : 1);
79+
std::vector<Int> result = {
80+
std::numeric_limits<Int>::max(), // K=0
81+
std::numeric_limits<Int>::max(), // K=1
82+
bound_digits % 2 == 0
83+
? Int{1} << (bound_digits / 2)
84+
: static_cast<Int>(
85+
0.5 + std::pow(2.0, 0.5 * std::numeric_limits<Int>::digits)),
86+
};
87+
// We find the last N with binary search, for all K. We stop growing K
88+
// when (2*K Choose K) overflows.
89+
for (Int k = 3;; ++k) {
90+
const double max_log_comb =
91+
overflows_intermediate_computation
92+
? std::numeric_limits<Int>::digits * std::log(2) - std::log(k)
93+
: std::numeric_limits<Int>::digits * std::log(2);
94+
result.push_back(BinarySearch<Int>(
95+
/*x_true*/ k,
96+
// x_false=X, X needs to be large enough so that X choose 3 overflows:
97+
// (X choose 3)≈(X-1)³/6, so we pick X = 2+6*2^(num_digits/3+1).
98+
/*x_false=*/
99+
(static_cast<Int>(
100+
2 + 6 * std::pow(2.0, std::numeric_limits<Int>::digits / 3 + 1))),
101+
[k, max_log_comb](Int n) {
102+
return MathUtil::LogCombinations(n, k) <= max_log_comb;
103+
}));
104+
if (result.back() < 2 * k) {
105+
result.pop_back();
106+
break;
107+
}
108+
}
109+
// Some DCHECKs for int64_t, which should validate the general formulaes.
110+
if constexpr (std::numeric_limits<Int>::digits == 63) {
111+
DCHECK_EQ(result.size(),
112+
overflows_intermediate_computation
113+
? 31 // 60 Choose 30 < 2^63/30 but 62 Choose 31 > 2^63/31.
114+
: 34); // 66 Choose 33 < 2^63 but 68 Choose 34 > 2^63.
115+
}
116+
VLOG(1) << "LastNThatDoesNotOverflowForAllK(): " << absl::Now() - start_time;
117+
return result;
118+
}
119+
120+
template <typename Int>
121+
bool NChooseKIntermediateComputationOverflowsInt(Int n, Int k) {
122+
DCHECK_LE(k, n - k);
123+
static const auto* const result =
124+
new std::vector<Int>(LastNThatDoesNotOverflowForAllK<Int>(
125+
/*overflows_intermediate_computation=*/true));
126+
return k < result->size() ? n > (*result)[k] : true;
127+
}
128+
129+
template <typename Int>
130+
bool NChooseKResultOverflowsInt(Int n, Int k) {
131+
DCHECK_LE(k, n - k);
132+
static const auto* const result =
133+
new std::vector<Int>(LastNThatDoesNotOverflowForAllK<Int>(
134+
/*overflows_intermediate_computation=*/false));
135+
return k < result->size() ? n > (*result)[k] : true;
136+
}
137+
} // namespace
138+
139+
// NOTE(user): If performance ever matters, we could simply precompute and
140+
// store all (N choose K) that don't overflow, there aren't that many of them:
141+
// only a few tens of thousands, after removing simple cases like k ≤ 5.
142+
absl::StatusOr<int64_t> NChooseK(int64_t n, int64_t k) {
143+
if (n < 0) {
144+
return absl::InvalidArgumentError(absl::StrFormat("n is negative (%d)", n));
145+
}
146+
if (k < 0) {
147+
return absl::InvalidArgumentError(absl::StrFormat("k is negative (%d)", k));
148+
}
149+
if (k > n) {
150+
return absl::InvalidArgumentError(
151+
absl::StrFormat("k=%d is greater than n=%d", k, n));
152+
}
153+
if (k > n / 2) k = n - k;
154+
if (k == 0) return 1;
155+
if (n < std::numeric_limits<uint32_t>::max() &&
156+
!NChooseKIntermediateComputationOverflowsInt<uint32_t>(n, k)) {
157+
return static_cast<int64_t>(InternalChoose<uint32_t>(n, k));
158+
}
159+
if (!NChooseKIntermediateComputationOverflowsInt<int64_t>(n, k)) {
160+
return InternalChoose<uint64_t>(n, k);
161+
}
162+
if (NChooseKResultOverflowsInt<int64_t>(n, k)) {
163+
return absl::InvalidArgumentError(
164+
absl::StrFormat("(%d choose %d) overflows int64", n, k));
165+
}
166+
return static_cast<int64_t>(InternalChoose<absl::uint128>(n, k));
167+
}
168+
169+
} // namespace operations_research

ortools/algorithms/n_choose_k.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright 2010-2024 Google LLC
2+
// Licensed under the Apache License, Version 2.0 (the "License");
3+
// you may not use this file except in compliance with the License.
4+
// You may obtain a copy of the License at
5+
//
6+
// http://www.apache.org/licenses/LICENSE-2.0
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
#ifndef OR_TOOLS_ALGORITHMS_N_CHOOSE_K_H_
15+
#define OR_TOOLS_ALGORITHMS_N_CHOOSE_K_H_
16+
17+
#include <cstdint>
18+
19+
#include "absl/status/statusor.h"
20+
21+
namespace operations_research {
22+
// Returns the number of ways to choose k elements among n, ignoring the order,
23+
// i.e., the binomial coefficient (n, k).
24+
// This is like std::exp(MathUtil::LogCombinations(n, k)), but faster, with
25+
// perfect accuracy, and returning an error iff the result would overflow an
26+
// int64_t or if an argument is invalid (i.e., n < 0, k < 0, or k > n).
27+
//
28+
// NOTE(user): If you need a variation of this, ask the authors: it's very easy
29+
// to add. E.g., other int types, other behaviors (e.g., return 0 if k > n, or
30+
// std::numeric_limits<int64_t>::max() on overflow, etc).
31+
absl::StatusOr<int64_t> NChooseK(int64_t n, int64_t k);
32+
} // namespace operations_research
33+
34+
#endif // OR_TOOLS_ALGORITHMS_N_CHOOSE_K_H_

0 commit comments

Comments
 (0)