Skip to content

Commit e4a6dd2

Browse files
committed
implement uint128 on msvc
1 parent 58f7d10 commit e4a6dd2

File tree

5 files changed

+449
-49
lines changed

5 files changed

+449
-49
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ jrl_target_headers(proxsuite INTERFACE
323323
include/proxsuite/proxqp/sparse/wrapper.hpp
324324
include/proxsuite/proxqp/utils/prints.hpp
325325
include/proxsuite/proxqp/utils/random_qp_problems.hpp
326+
include/proxsuite/proxqp/utils/uint128_msvc.hpp
326327
include/proxsuite/proxqp/results.hpp
327328
include/proxsuite/proxqp/settings.hpp
328329
include/proxsuite/proxqp/status.hpp

include/proxsuite/proxqp/utils/random_qp_problems.hpp

Lines changed: 7 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
#include <map>
1414
#include <random>
1515

16+
#if defined(_MSC_VER)
17+
#include <proxsuite/proxqp/utils/uint128_msvc.hpp>
18+
#endif
19+
1620
namespace proxsuite {
1721
namespace proxqp {
1822
namespace utils {
@@ -67,56 +71,11 @@ namespace rand {
6771
using proxqp::u32;
6872
using proxqp::u64;
6973

70-
#ifdef _MSC_VER
71-
/* Using the MSCV compiler on Windows causes problems because the type uint128
72-
is not available. Therefore, we use a random number generator from the stdlib
73-
instead of our custom Lehmer random number generator. The necessary lehmer
74-
functions used in in our code are remplaced with calls to the stdlib.*/
75-
inline auto
76-
get_gen() -> std::mt19937&
77-
{
78-
static std::mt19937 gen(1234);
79-
return gen;
80-
}
81-
inline auto
82-
get_uniform_dist() -> std::uniform_real_distribution<>&
83-
{
84-
static std::uniform_real_distribution<> uniform_dist(0.0, 1.0);
85-
return uniform_dist;
86-
}
87-
inline auto
88-
get_normal_dist() -> std::normal_distribution<double>&
89-
{
90-
static std::normal_distribution<double> normal_dist;
91-
return normal_dist;
92-
}
93-
using u128 = u64;
94-
inline auto
95-
uniform_rand() -> double
96-
{
97-
double output = double(get_uniform_dist()(get_gen()));
98-
return output;
99-
}
100-
inline auto
101-
lehmer_global() -> u128&
102-
{
103-
static u64 output = get_gen()();
104-
return output;
105-
}
106-
107-
inline void
108-
set_seed(u64 seed)
109-
{
110-
get_gen().seed(seed);
111-
}
112-
113-
inline auto
114-
normal_rand() -> double
115-
{
116-
return get_normal_dist()(get_gen());
117-
}
74+
#if defined(_MSC_VER)
75+
using u128 = uint128_t;
11876
#else
11977
using u128 = __uint128_t;
78+
#endif
12079

12180
constexpr u128 lehmer64_constant(0xda942042e4dd58b5);
12281
inline auto
@@ -160,7 +119,6 @@ normal_rand() -> double
160119

161120
return sqrt * std::cos(pi2 * u2);
162121
}
163-
#endif
164122

165123
template<typename Scalar>
166124
auto
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
#pragma once
2+
3+
#if !defined(_MSC_VER)
4+
#error "This file is only compatible with the MSVC compiler"
5+
#endif
6+
7+
#include <cstdint>
8+
#include <immintrin.h>
9+
10+
class uint128_t
11+
{
12+
public:
13+
uint64_t low;
14+
uint64_t high;
15+
16+
// --- Constructors ---
17+
constexpr uint128_t()
18+
: low(0)
19+
, high(0)
20+
{
21+
}
22+
constexpr uint128_t(uint64_t l)
23+
: low(l)
24+
, high(0)
25+
{
26+
}
27+
constexpr uint128_t(uint64_t l, uint64_t h)
28+
: low(l)
29+
, high(h)
30+
{
31+
}
32+
33+
// --- Type Conversions ---
34+
explicit operator bool() const { return low || high; }
35+
explicit operator uint64_t() const { return low; }
36+
explicit operator int64_t() const { return static_cast<int64_t>(low); }
37+
38+
// --- Arithmetic Operators ---
39+
40+
// Addition
41+
uint128_t operator+(const uint128_t& rhs) const
42+
{
43+
uint128_t result;
44+
unsigned char carry = _addcarry_u64(0, low, rhs.low, &result.low);
45+
_addcarry_u64(carry, high, rhs.high, &result.high);
46+
return result;
47+
}
48+
49+
uint128_t& operator+=(const uint128_t& rhs)
50+
{
51+
*this = *this + rhs;
52+
return *this;
53+
}
54+
55+
// Subtraction
56+
uint128_t operator-(const uint128_t& rhs) const
57+
{
58+
uint128_t result;
59+
unsigned char borrow = _subborrow_u64(0, low, rhs.low, &result.low);
60+
_subborrow_u64(borrow, high, rhs.high, &result.high);
61+
return result;
62+
}
63+
64+
uint128_t& operator-=(const uint128_t& rhs)
65+
{
66+
*this = *this - rhs;
67+
return *this;
68+
}
69+
70+
// Multiplication
71+
uint128_t operator*(const uint128_t& rhs) const
72+
{
73+
uint64_t product_high;
74+
uint64_t product_low = _umul128(low, rhs.low, &product_high);
75+
76+
// The total high part is the high part of (low * rhs.low)
77+
// plus the cross terms (low * rhs.high) and (high * rhs.low)
78+
product_high += (low * rhs.high) + (high * rhs.low);
79+
80+
return uint128_t(product_low, product_high);
81+
}
82+
83+
uint128_t& operator*=(const uint128_t& rhs)
84+
{
85+
*this = *this * rhs;
86+
return *this;
87+
}
88+
89+
// Division (Note: Full 128-bit division is complex to implement purely with
90+
// intrinsics if the divisor is > 64 bits. This is a simplified version
91+
// handling common cases). For production-grade full 128/128 division, usage
92+
// of a library like Boost is strongly advised. However, if divisor fits in 64
93+
// bits, we can use _udiv128.
94+
uint128_t operator/(const uint128_t& rhs) const
95+
{
96+
if (rhs.high == 0) {
97+
// Optimization for 64-bit divisor
98+
uint64_t remainder;
99+
uint64_t quotient_high = 0; // High part of result
100+
uint64_t quotient_low;
101+
102+
// If our high part is distinct, we divide the high part first
103+
if (high > 0) {
104+
// This is slightly tricky with _udiv128 directly as it does 128/64
105+
// -> 64. Standard long division algorithm is safer here for the general
106+
// implementation. For simplicity in this snippet, we will fallback to a
107+
// naive loop or simple approximation OR promote strictly the 64-bit
108+
// divisor case which is most common:
109+
110+
quotient_high = high / rhs.low;
111+
uint64_t r_high = high % rhs.low;
112+
113+
quotient_low = _udiv128(r_high, low, rhs.low, &remainder);
114+
return uint128_t(quotient_low, quotient_high);
115+
} else {
116+
return uint128_t(low / rhs.low, 0);
117+
}
118+
}
119+
// Fallback for full 128-bit divisor: Very slow basic binary long division
120+
if (rhs > *this)
121+
return uint128_t(0);
122+
if (rhs == *this)
123+
return uint128_t(1);
124+
125+
uint128_t temp = *this;
126+
uint128_t quot = 0;
127+
uint128_t one = 1;
128+
129+
// This is slow O(N) division, acceptable for simple utility, bad for heavy
130+
// math
131+
while (temp >= rhs) {
132+
// Find shift
133+
uint128_t shift_rhs = rhs;
134+
uint128_t shift_count = 1;
135+
while ((shift_rhs.high & 0x8000000000000000) == 0 &&
136+
(shift_rhs << 1) <= temp) {
137+
shift_rhs <<= 1;
138+
shift_count <<= 1;
139+
}
140+
temp -= shift_rhs;
141+
quot += shift_count;
142+
}
143+
return quot;
144+
}
145+
146+
// Modulus
147+
uint128_t operator%(const uint128_t& rhs) const
148+
{
149+
return *this - (*this / rhs) * rhs;
150+
}
151+
152+
uint128_t& operator%=(const uint128_t& rhs)
153+
{
154+
*this = *this % rhs;
155+
return *this;
156+
}
157+
158+
// --- Bitwise Operators ---
159+
uint128_t operator<<(int shift) const
160+
{
161+
shift &= 127; // Mask the shift amount to imitate native hardware behavior
162+
// (modulo 128)
163+
if (shift == 0)
164+
return *this;
165+
if (shift >= 64) {
166+
return uint128_t(0, low << (shift - 64));
167+
}
168+
return uint128_t((low << shift), (high << shift) | (low >> (64 - shift)));
169+
}
170+
171+
uint128_t operator>>(int shift) const
172+
{
173+
shift &= 127; // Mask the shift amount to imitate native hardware behavior
174+
// (modulo 128)
175+
if (shift == 0)
176+
return *this;
177+
if (shift >= 64) {
178+
return uint128_t(high >> (shift - 64), 0);
179+
}
180+
return uint128_t((low >> shift) | (high << (64 - shift)), (high >> shift));
181+
}
182+
183+
// --- Shift by uint128_t Overloads ---
184+
uint128_t operator>>(const uint128_t& shift) const
185+
{
186+
// If shift amount is >= 128, the result behavior mimics hardware (modulo
187+
// 128)
188+
return *this >> static_cast<int>(shift.low);
189+
}
190+
191+
uint128_t operator<<(const uint128_t& shift) const
192+
{
193+
// If shift amount is >= 128, the result behavior mimics hardware (modulo
194+
// 128)
195+
return *this << static_cast<int>(shift.low);
196+
}
197+
198+
uint128_t& operator<<=(int shift)
199+
{
200+
*this = *this << shift;
201+
return *this;
202+
}
203+
uint128_t& operator>>=(int shift)
204+
{
205+
*this = *this >> shift;
206+
return *this;
207+
}
208+
209+
uint128_t operator|(const uint128_t& rhs) const
210+
{
211+
return uint128_t(low | rhs.low, high | rhs.high);
212+
}
213+
uint128_t operator&(const uint128_t& rhs) const
214+
{
215+
return uint128_t(low & rhs.low, high & rhs.high);
216+
}
217+
uint128_t operator^(const uint128_t& rhs) const
218+
{
219+
return uint128_t(low ^ rhs.low, high ^ rhs.high);
220+
}
221+
uint128_t operator~() const { return uint128_t(~low, ~high); }
222+
223+
// --- Comparison Operators ---
224+
bool operator==(const uint128_t& rhs) const
225+
{
226+
return low == rhs.low && high == rhs.high;
227+
}
228+
bool operator!=(const uint128_t& rhs) const { return !(*this == rhs); }
229+
bool operator<(const uint128_t& rhs) const
230+
{
231+
return high < rhs.high || (high == rhs.high && low < rhs.low);
232+
}
233+
bool operator>(const uint128_t& rhs) const { return rhs < *this; }
234+
bool operator<=(const uint128_t& rhs) const { return !(*this > rhs); }
235+
bool operator>=(const uint128_t& rhs) const { return !(*this < rhs); }
236+
};

test/cpp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ function(proxsuite_add_test name)
5151
endif()
5252
endfunction()
5353

54+
proxsuite_add_test(uint128)
5455
proxsuite_add_test(cvxpy)
5556
proxsuite_add_test(dense_backward)
5657
proxsuite_add_test(dense_qp_eq)

0 commit comments

Comments
 (0)