Skip to content

Commit 7cf683a

Browse files
authored
Merge pull request #16 from a41-official/perf/optimize-mont-ops
perf: optimize montgomery operations
2 parents e9243ad + d6b6041 commit 7cf683a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+842
-452
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
uses: actions/checkout@v4
2525

2626
- name: Setup Bazelisk
27-
uses: bazelbuild/setup-bazelisk@v2
27+
uses: bazel-contrib/setup-[email protected]
2828

2929
- name: Mount Bazel Cache
3030
uses: actions/cache@v4

WORKSPACE.bazel

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414

1515
workspace(name = "zkir")
1616

17+
load("//bazel:zkir_deps.bzl", "zkir_deps")
18+
19+
zkir_deps()
20+
1721
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
1822

1923
LLVM_COMMIT = "f8287f6c373fcf993643dd6f0e30dde304c1be73"

bazel/zkir_deps.bzl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""
2+
This module configures dependencies for the ZKIR project.
3+
"""
4+
5+
load("//third_party/omp:omp_configure.bzl", "omp_configure")
6+
7+
def zkir_deps():
8+
omp_configure(name = "local_config_omp")

benchmark/BenchmarkUtils.h

Lines changed: 227 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,25 @@
33

44
#include <cstdint>
55
#include <cstdlib>
6+
#include <iomanip>
7+
#include <sstream>
8+
#include <string>
9+
10+
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
11+
#define HOST_IS_LITTLE_ENDIAN 1
12+
#define HOST_IS_BIG_ENDIAN 0
13+
#elif defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
14+
#define HOST_IS_LITTLE_ENDIAN 0
15+
#define HOST_IS_BIG_ENDIAN 1
16+
#elif defined(_WIN32) // Check Windows after standard macros
17+
#define HOST_IS_LITTLE_ENDIAN 1
18+
#define HOST_IS_BIG_ENDIAN 0
19+
#else
20+
#warning \
21+
"Cannot determine host endianness at compile time. Assuming little-endian."
22+
#define HOST_IS_LITTLE_ENDIAN 1
23+
#define HOST_IS_BIG_ENDIAN 0
24+
#endif
625

726
namespace zkir {
827
namespace benchmark {
@@ -13,7 +32,7 @@ template <typename T>
1332
class Memref {
1433
public:
1534
Memref(size_t h, size_t w) {
16-
allocatedPtr = reinterpret_cast<T*>(malloc(sizeof(T) * w * h));
35+
allocatedPtr = reinterpret_cast<T *>(malloc(sizeof(T) * w * h));
1736
alignedPtr = allocatedPtr;
1837

1938
offset = 0;
@@ -23,20 +42,224 @@ class Memref {
2342
strides[1] = 1;
2443
}
2544

26-
T* pget(size_t i, size_t j) const {
45+
T *pget(size_t i, size_t j) const {
2746
return &alignedPtr[offset + i * strides[0] + j * strides[1]];
2847
}
2948

3049
T get(size_t i, size_t j) const { return *pget(i, j); }
3150

3251
private:
33-
T* allocatedPtr;
34-
T* alignedPtr;
52+
T *allocatedPtr;
53+
T *alignedPtr;
3554
size_t offset;
3655
size_t sizes[2];
3756
size_t strides[2];
3857
};
3958

59+
namespace {
60+
// Helper to parse a single hex character (case-insensitive)
61+
// Throws std::invalid_argument if the character is not a valid hex digit.
62+
inline uint8_t parseHexDigit(char c) {
63+
if (c >= '0' && c <= '9') {
64+
return static_cast<uint8_t>(c - '0');
65+
}
66+
if (c >= 'a' && c <= 'f') {
67+
return static_cast<uint8_t>(c - 'a' + 10);
68+
}
69+
if (c >= 'A' && c <= 'F') {
70+
return static_cast<uint8_t>(c - 'A' + 10);
71+
}
72+
throw std::invalid_argument(
73+
"Invalid hexadecimal character encountered in string.");
74+
}
75+
} // namespace
76+
77+
// Represents a large unsigned integer using an array of 64-bit limbs.
78+
// Uses the platform's native endianness for limb storage and operations,
79+
template <size_t kLimbCount>
80+
struct BigInt {
81+
static_assert(kLimbCount > 0, "BigInt must have at least one limb.");
82+
uint64_t limbs[kLimbCount];
83+
84+
static BigInt fromHexString(std::string_view hexStr) {
85+
BigInt value;
86+
// Prepare string view - remove optional prefix "0x" or "0X"
87+
if (hexStr.length() >= 2 && hexStr[0] == '0' &&
88+
(hexStr[1] == 'x' || hexStr[1] == 'X')) {
89+
hexStr.remove_prefix(2);
90+
}
91+
92+
// Remove leading zeros
93+
size_t firstDigit = hexStr.find_first_not_of('0');
94+
if (firstDigit == std::string_view::npos) {
95+
// Value is 0
96+
value.clear();
97+
return value;
98+
}
99+
// Create view of the relevant digits
100+
std::string_view digitsView = hexStr.substr(firstDigit);
101+
const size_t numDigits = digitsView.length();
102+
103+
// Check length against capacity
104+
const size_t maxDigits = kLimbCount * 16;
105+
if (numDigits > maxDigits) {
106+
throw std::overflow_error("Hex string value exceeds BigInt capacity (" +
107+
std::to_string(numDigits) + " digits > " +
108+
std::to_string(maxDigits) + " max).");
109+
}
110+
111+
// Parse right-to-left, placing limbs based on host endianness
112+
uint64_t currentLimbValue = 0;
113+
int bitsInCurrentLimb = 0;
114+
size_t currentLimbWriteIndex = 0;
115+
116+
// Determine the starting index in the limbs array based on platform
117+
#if HOST_IS_LITTLE_ENDIAN
118+
// Start writing to limbs[0] (least significant limb)
119+
currentLimbWriteIndex = 0;
120+
#else
121+
// Calculate how many limbs will be needed based on actual digits
122+
// and start writing to the array index corresponding to the
123+
// most significant limb that will be filled.
124+
size_t numLimbsToFill = (numDigits + 15) / 16; // Ceiling division
125+
assert(numLimbsToFill <= kLimbCount &&
126+
"Logic error: numLimbsToFill exceeds kLimbCount");
127+
currentLimbWriteIndex = kLimbCount - numLimbsToFill;
128+
#endif
129+
130+
// Iterate through the relevant digits from right to left
131+
for (size_t i = 0; i < numDigits; ++i) {
132+
// Process string from right (least significant hex digit) to left
133+
char c = digitsView[numDigits - 1 - i];
134+
// parseHexDigit throws std::invalid_argument on error
135+
uint8_t digitValue = parseHexDigit(c);
136+
137+
// Add the 4 bits of the digit to the current limb value at the correct
138+
// bit position
139+
currentLimbValue |=
140+
(static_cast<uint64_t>(digitValue) << bitsInCurrentLimb);
141+
bitsInCurrentLimb += 4;
142+
143+
// If limb is full (64 bits = 16 hex digits) or it's the last digit of the
144+
// string
145+
if (bitsInCurrentLimb == 64 || i == numDigits - 1) {
146+
// Write the completed or final partial limb
147+
value.limbs[currentLimbWriteIndex] = currentLimbValue;
148+
149+
// Move to the next limb index slot (index increases for both LE/BE
150+
// write sequences)
151+
currentLimbWriteIndex++;
152+
153+
// Reset for next limb
154+
currentLimbValue = 0;
155+
bitsInCurrentLimb = 0;
156+
}
157+
}
158+
return value;
159+
}
160+
161+
static BigInt randomLT(const BigInt &upper_bound, std::mt19937_64 &rng,
162+
std::uniform_int_distribution<uint64_t> &dist) {
163+
// Generate a random number less than the given upper bound.
164+
BigInt candidate;
165+
for (size_t j = 0; j < kLimbCount;) {
166+
candidate.limbs[j] = dist(rng);
167+
if (candidate.limbs[j] < upper_bound.limbs[j]) {
168+
j++;
169+
}
170+
}
171+
return candidate;
172+
}
173+
174+
static constexpr size_t getLimbCount() { return kLimbCount; }
175+
176+
bool operator<(const BigInt &other) const {
177+
#if HOST_IS_LITTLE_ENDIAN
178+
// Little-Endian: Compare from MOST significant limb (highest index) down
179+
for (int i = kLimbCount - 1; i >= 0; --i) {
180+
if (limbs[i] < other.limbs[i]) {
181+
return true;
182+
}
183+
if (limbs[i] > other.limbs[i]) {
184+
return false;
185+
}
186+
}
187+
#else // HOST_IS_BIG_ENDIAN
188+
// Big-Endian: Compare from MOST significant limb (lowest index) up
189+
for (size_t i = 0; i < kLimbCount; ++i) {
190+
if (limbs[i] < other.limbs[i]) {
191+
return true;
192+
}
193+
if (limbs[i] > other.limbs[i]) {
194+
return false;
195+
}
196+
}
197+
#endif
198+
// Numbers are equal.
199+
return false;
200+
}
201+
202+
bool operator==(const BigInt &other) const {
203+
for (size_t i = 0; i < kLimbCount; ++i) {
204+
if (limbs[i] != other.limbs[i]) {
205+
return false;
206+
}
207+
}
208+
return true;
209+
}
210+
211+
bool operator!=(const BigInt &other) const { return !(*this == other); }
212+
bool operator>(const BigInt &other) const { return other < *this; }
213+
bool operator<=(const BigInt &other) const { return !(other < *this); }
214+
bool operator>=(const BigInt &other) const { return !(*this < other); }
215+
216+
void clear() { std::fill(limbs, limbs + kLimbCount, 0); }
217+
bool isZero() const {
218+
for (size_t i = 0; i < kLimbCount; ++i) {
219+
if (limbs[i] != 0) {
220+
return false;
221+
}
222+
}
223+
return true;
224+
}
225+
226+
bool isOne() const {
227+
for (size_t i = 1; i < kLimbCount - 1; ++i) {
228+
if (limbs[i] != 0) {
229+
return false;
230+
}
231+
}
232+
#if HOST_IS_LITTLE_ENDIAN
233+
return limbs[0] == 1 && limbs[kLimbCount - 1] == 0;
234+
#else
235+
return limbs[0] == 0 && limbs[kLimbCount - 1] == 1;
236+
#endif
237+
}
238+
239+
std::string printHex() const {
240+
std::stringstream s;
241+
s << "0x";
242+
bool leadingZeros = true;
243+
244+
#if HOST_IS_BIG_ENDIAN
245+
for (size_t i = 0; i < kLimbCount; ++i) {
246+
if (leadingZeros && limbs[i] == 0 && i < kLimbCount - 1) continue;
247+
leadingZeros = false;
248+
s << std::hex << std::setw(16) << std::setfill('0') << limbs[i];
249+
}
250+
#else // HOST_IS_LITTLE_ENDIAN
251+
for (int i = kLimbCount - 1; i >= 0; --i) {
252+
if (leadingZeros && limbs[i] == 0 && i > 0) continue;
253+
leadingZeros = false;
254+
s << std::hex << std::setw(16) << std::setfill('0') << limbs[i];
255+
}
256+
#endif
257+
// Handle case where value is exactly 0
258+
if (leadingZeros) s << "0";
259+
return s.str();
260+
}
261+
};
262+
40263
} // namespace benchmark
41264
} // namespace zkir
42265

benchmark/benchmark.bzl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,11 @@ def zkir_benchmark_test(name, mlir_src, test_src, zkir_opt_flags = [], data = []
146146
":" + import_name,
147147
"@google_benchmark//:benchmark_main",
148148
"@googletest//:gtest",
149+
"@llvm-project//mlir:mlir_runner_utils",
150+
"@local_config_omp//:omp",
149151
],
152+
copts = ["-Xclang -fopenmp"],
153+
linkopts = ["-Xclang -fopenmp"],
150154
tags = tags,
151155
data = data + [generated_obj_name],
152156
**kwargs

benchmark/ntt/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ zkir_benchmark_test(
55
mlir_src = "ntt_benchmark.mlir",
66
tags = ["manual"],
77
test_src = ["ntt_benchmark_test.cc"],
8-
zkir_opt_flags = ["-poly-to-llvm"],
8+
zkir_opt_flags = ["-poly-to-omp"],
99
deps = [
1010
"//benchmark:BenchmarkUtils",
1111
],

0 commit comments

Comments
 (0)