|
| 1 | +// Copyright 2010-2024 Google LLC |
| 2 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +// you may not use this file except in compliance with the License. |
| 4 | +// You may obtain a copy of the License at |
| 5 | +// |
| 6 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 7 | +// |
| 8 | +// Unless required by applicable law or agreed to in writing, software |
| 9 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 10 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 11 | +// See the License for the specific language governing permissions and |
| 12 | +// limitations under the License. |
| 13 | + |
| 14 | +#include "ortools/algorithms/n_choose_k.h" |
| 15 | + |
| 16 | +#include <cmath> |
| 17 | +#include <cstdint> |
| 18 | +#include <limits> |
| 19 | +#include <vector> |
| 20 | + |
| 21 | +#include "absl/log/check.h" |
| 22 | +#include "absl/numeric/int128.h" |
| 23 | +#include "absl/status/status.h" |
| 24 | +#include "absl/status/statusor.h" |
| 25 | +#include "absl/strings/str_format.h" |
| 26 | +#include "absl/time/clock.h" |
| 27 | +#include "absl/time/time.h" |
| 28 | +#include "ortools/algorithms/binary_search.h" |
| 29 | +#include "ortools/base/logging.h" |
| 30 | +#include "ortools/base/mathutil.h" |
| 31 | + |
| 32 | +namespace operations_research { |
| 33 | +namespace { |
// Core evaluation of (n choose k), using O(k) multiplications and divisions.
// Preconditions (DCHECKed): 0 < k ≤ n - k.
template <typename Int>
Int InternalChoose(Int n, Int k) {
  DCHECK_LE(k, n - k);
  DCHECK_GT(k, 0);  // With k>0 the loop can start at 2 (small optimization).
  // n * (n-1) * ... * (n-k+1) / k! is built incrementally: after the step for
  // `divisor`, `binomial` holds exactly (n choose divisor). Each step
  // multiplies before it divides, which keeps every intermediate value an
  // exact integer. The price is that the computation ends with a division by
  // k, so an intermediate product may overflow even if the final result
  // doesn't (by a factor of up to k).
  Int binomial = n;
  Int factor = n;
  for (Int divisor = 2; divisor <= k; ++divisor) {
    --factor;  // Now factor == n + 1 - divisor.
    binomial *= factor;
    // Exact division: the product of `divisor` consecutive numbers is
    // divisible by divisor!.
    binomial /= divisor;
  }
  return binomial;
}
| 50 | + |
| 51 | +// This function precomputes the maximum N such that (N choose K) doesn't |
| 52 | +// overflow, for all K. |
| 53 | +// When `overflows_intermediate_computation` is true, "overflow" means |
| 54 | +// "some overflow happens inside InternalChoose<int64_t>()", and when it's false |
| 55 | +// it simply means "the result doesn't fit in an int64_t". |
| 56 | +// This is only used in contexts where K ≤ N-K, which implies N ≥ 2K, thus we |
| 57 | +// can stop when (2K Choose K) overflows, because at and beyond such K, |
| 58 | +// (N Choose K) will always overflow. In practice that happens for K=31 or 34 |
| 59 | +// depending on `overflows_intermediate_computation`. |
| 60 | +template <class Int> |
| 61 | +std::vector<Int> LastNThatDoesNotOverflowForAllK( |
| 62 | + bool overflows_intermediate_computation) { |
| 63 | + absl::Time start_time = absl::Now(); |
| 64 | + // Given the algorithm used in InternalChoose(), it's not hard to |
| 65 | + // find out when (N choose K) overflows an int64_t during its internal |
| 66 | + // computation: that's when (N choose K) > MAX_INT / k. |
| 67 | + |
| 68 | + // For K ≤ 2, we hardcode the values of the maximum N. That's because |
| 69 | + // the binary search done below uses MathUtil::LogCombinations, which only |
| 70 | + // works on int32_t, and that's problematic for the max N we get for K=2. |
| 71 | + // |
| 72 | + // For K=2, we want N(N-1) ≤ 2^num_digits, or N(N-1)/2 ≤ 2^num_digits if |
| 73 | + // !overflows_intermediate_computation, i.e. N(N-1) ≤ 2^(num_digits+1). |
| 74 | + // Then, when d is even, N(N-1) ≤ 2^d ⇔ N ≤ 2^(d/2), which is simple. |
| 75 | + // When d is odd, it's harder: N(N-1)≈(N-0.5)² and thus we get the bound |
| 76 | + // N ≤ pow(2.0, d/2)+0.5. |
| 77 | + const int bound_digits = std::numeric_limits<Int>::digits + |
| 78 | + (overflows_intermediate_computation ? 0 : 1); |
| 79 | + std::vector<Int> result = { |
| 80 | + std::numeric_limits<Int>::max(), // K=0 |
| 81 | + std::numeric_limits<Int>::max(), // K=1 |
| 82 | + bound_digits % 2 == 0 |
| 83 | + ? Int{1} << (bound_digits / 2) |
| 84 | + : static_cast<Int>( |
| 85 | + 0.5 + std::pow(2.0, 0.5 * std::numeric_limits<Int>::digits)), |
| 86 | + }; |
| 87 | + // We find the last N with binary search, for all K. We stop growing K |
| 88 | + // when (2*K Choose K) overflows. |
| 89 | + for (Int k = 3;; ++k) { |
| 90 | + const double max_log_comb = |
| 91 | + overflows_intermediate_computation |
| 92 | + ? std::numeric_limits<Int>::digits * std::log(2) - std::log(k) |
| 93 | + : std::numeric_limits<Int>::digits * std::log(2); |
| 94 | + result.push_back(BinarySearch<Int>( |
| 95 | + /*x_true*/ k, |
| 96 | + // x_false=X, X needs to be large enough so that X choose 3 overflows: |
| 97 | + // (X choose 3)≈(X-1)³/6, so we pick X = 2+6*2^(num_digits/3+1). |
| 98 | + /*x_false=*/ |
| 99 | + (static_cast<Int>( |
| 100 | + 2 + 6 * std::pow(2.0, std::numeric_limits<Int>::digits / 3 + 1))), |
| 101 | + [k, max_log_comb](Int n) { |
| 102 | + return MathUtil::LogCombinations(n, k) <= max_log_comb; |
| 103 | + })); |
| 104 | + if (result.back() < 2 * k) { |
| 105 | + result.pop_back(); |
| 106 | + break; |
| 107 | + } |
| 108 | + } |
| 109 | + // Some DCHECKs for int64_t, which should validate the general formulaes. |
| 110 | + if constexpr (std::numeric_limits<Int>::digits == 63) { |
| 111 | + DCHECK_EQ(result.size(), |
| 112 | + overflows_intermediate_computation |
| 113 | + ? 31 // 60 Choose 30 < 2^63/30 but 62 Choose 31 > 2^63/31. |
| 114 | + : 34); // 66 Choose 33 < 2^63 but 68 Choose 34 > 2^63. |
| 115 | + } |
| 116 | + VLOG(1) << "LastNThatDoesNotOverflowForAllK(): " << absl::Now() - start_time; |
| 117 | + return result; |
| 118 | +} |
| 119 | + |
| 120 | +template <typename Int> |
| 121 | +bool NChooseKIntermediateComputationOverflowsInt(Int n, Int k) { |
| 122 | + DCHECK_LE(k, n - k); |
| 123 | + static const auto* const result = |
| 124 | + new std::vector<Int>(LastNThatDoesNotOverflowForAllK<Int>( |
| 125 | + /*overflows_intermediate_computation=*/true)); |
| 126 | + return k < result->size() ? n > (*result)[k] : true; |
| 127 | +} |
| 128 | + |
| 129 | +template <typename Int> |
| 130 | +bool NChooseKResultOverflowsInt(Int n, Int k) { |
| 131 | + DCHECK_LE(k, n - k); |
| 132 | + static const auto* const result = |
| 133 | + new std::vector<Int>(LastNThatDoesNotOverflowForAllK<Int>( |
| 134 | + /*overflows_intermediate_computation=*/false)); |
| 135 | + return k < result->size() ? n > (*result)[k] : true; |
| 136 | +} |
| 137 | +} // namespace |
| 138 | + |
| 139 | +// NOTE(user): If performance ever matters, we could simply precompute and |
| 140 | +// store all (N choose K) that don't overflow, there aren't that many of them: |
| 141 | +// only a few tens of thousands, after removing simple cases like k ≤ 5. |
| 142 | +absl::StatusOr<int64_t> NChooseK(int64_t n, int64_t k) { |
| 143 | + if (n < 0) { |
| 144 | + return absl::InvalidArgumentError(absl::StrFormat("n is negative (%d)", n)); |
| 145 | + } |
| 146 | + if (k < 0) { |
| 147 | + return absl::InvalidArgumentError(absl::StrFormat("k is negative (%d)", k)); |
| 148 | + } |
| 149 | + if (k > n) { |
| 150 | + return absl::InvalidArgumentError( |
| 151 | + absl::StrFormat("k=%d is greater than n=%d", k, n)); |
| 152 | + } |
| 153 | + if (k > n / 2) k = n - k; |
| 154 | + if (k == 0) return 1; |
| 155 | + if (n < std::numeric_limits<uint32_t>::max() && |
| 156 | + !NChooseKIntermediateComputationOverflowsInt<uint32_t>(n, k)) { |
| 157 | + return static_cast<int64_t>(InternalChoose<uint32_t>(n, k)); |
| 158 | + } |
| 159 | + if (!NChooseKIntermediateComputationOverflowsInt<int64_t>(n, k)) { |
| 160 | + return InternalChoose<uint64_t>(n, k); |
| 161 | + } |
| 162 | + if (NChooseKResultOverflowsInt<int64_t>(n, k)) { |
| 163 | + return absl::InvalidArgumentError( |
| 164 | + absl::StrFormat("(%d choose %d) overflows int64", n, k)); |
| 165 | + } |
| 166 | + return static_cast<int64_t>(InternalChoose<absl::uint128>(n, k)); |
| 167 | +} |
| 168 | + |
| 169 | +} // namespace operations_research |
0 commit comments