Skip to content

Commit 55192f2

Browse files
authored
Fix pos (#45)
* Attempt to fix PoS issue, changing to uint32 * Change more variables to 32 bits * Script to check plots
1 parent cd266cf commit 55192f2

File tree

4 files changed

+131
-31
lines changed

4 files changed

+131
-31
lines changed

lib/chiapos/src/bits.hpp

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525

2626

2727
#define kBufSize 5
28-
#define kMaxSizeBits 65535
28+
29+
// 128 * 2^16. 2^16 values, each value can store 128 bits.
30+
#define kMaxSizeBits 8388608
2931

3032
// A stack vector of length 5, having the functions of std::vector needed for Bits.
3133
struct SmallVector {
@@ -63,17 +65,17 @@ struct SmallVector {
6365

6466

6567
// A stack vector of length 1024, having the functions of std::vector needed for Bits.
66-
68+
// The max number of Bits that can be stored is 1024 * 128
6769
struct ParkVector {
6870
ParkVector() {
6971
count_ = 0;
7072
}
7173

72-
uint128_t& operator[] (const uint16_t index) {
74+
uint128_t& operator[] (const uint32_t index) {
7375
return v_[index];
7476
}
7577

76-
uint128_t operator[] (const uint16_t index) const {
78+
uint128_t operator[] (const uint32_t index) const {
7779
return v_[index];
7880
}
7981

@@ -83,18 +85,18 @@ struct ParkVector {
8385

8486
ParkVector& operator = (const ParkVector& other) {
8587
count_ = other.count_;
86-
for (uint16_t i = 0; i < other.count_; i++)
88+
for (uint32_t i = 0; i < other.count_; i++)
8789
v_[i] = other.v_[i];
8890
return (*this);
8991
}
9092

91-
uint16_t size() const {
93+
uint32_t size() const {
9294
return count_;
9395
}
9496

9597
private:
9698
uint128_t v_[1024];
97-
uint16_t count_;
99+
uint32_t count_;
98100
};
99101

100102
/*
@@ -121,7 +123,7 @@ template <class T> class BitsGeneric {
121123

122124
// Converts from unit128_t to Bits. If the number of bits of value is smaller than size, adds 0 bits at the beginning.
123125
// i.e. Bits(5, 10) = 0000000101
124-
BitsGeneric<T>(uint128_t value, uint16_t size) {
126+
BitsGeneric<T>(uint128_t value, uint32_t size) {
125127
// TODO(mariano) remove
126128
if (size < 128 && value > ((uint128_t)1 << size)) {
127129
std::cout << "TOO BIG FOR BITS" << std::endl;
@@ -130,7 +132,7 @@ template <class T> class BitsGeneric {
130132
this->last_size_ = 0;
131133
if (size > 128) {
132134
// Get number of extra 0s added at the beginning.
133-
uint16_t zeros = size - Util::GetSizeBits(value);
135+
uint32_t zeros = size - Util::GetSizeBits(value);
134136
// Add a full group of 0s (length 128)
135137
while (zeros > 128) {
136138
AppendValue(0, 128);
@@ -147,12 +149,12 @@ template <class T> class BitsGeneric {
147149

148150
// Copy the content of another Bits object. If the size of the other Bits object is smaller
149151
// than 'size', adds 0 bits at the beginning.
150-
BitsGeneric<T>(const BitsGeneric<T>& other, uint16_t size) {
151-
uint16_t total_size = other.GetSize();
152+
BitsGeneric<T>(const BitsGeneric<T>& other, uint32_t size) {
153+
uint32_t total_size = other.GetSize();
152154
this->last_size_ = 0;
153155
assert(size >= total_size);
154156
// Add the extra 0 bits at the beginning.
155-
uint16_t extra_space = size - total_size;
157+
uint32_t extra_space = size - total_size;
156158
while (extra_space >= 128) {
157159
AppendValue(0, 128);
158160
extra_space -= 128;
@@ -161,14 +163,14 @@ template <class T> class BitsGeneric {
161163
AppendValue(0, extra_space);
162164
// Copy the Bits object element by element, and append it to the current Bits object.
163165
if (other.values_.size() > 0) {
164-
for (uint8_t i = 0; i < other.values_.size() - 1; i++)
166+
for (uint32_t i = 0; i < other.values_.size() - 1; i++)
165167
AppendValue(other.values_[i], 128);
166168
AppendValue(other.values_[other.values_.size() - 1], other.last_size_);
167169
}
168170
}
169171

170172
// Converts bytes to bits.
171-
BitsGeneric<T>(const uint8_t* big_endian_bytes, uint32_t num_bytes, uint16_t size_bits) {
173+
BitsGeneric<T>(const uint8_t* big_endian_bytes, uint32_t num_bytes, uint32_t size_bits) {
172174
this->last_size_ = 0;
173175
uint32_t extra_space = size_bits - num_bytes * 8;
174176
// Add the extra 0 bits at the beginning.
@@ -210,12 +212,12 @@ template <class T> class BitsGeneric {
210212
}
211213
BitsGeneric<T> result;
212214
if (values_.size() > 0) {
213-
for (uint8_t i = 0; i < values_.size() - 1; i++)
215+
for (uint32_t i = 0; i < values_.size() - 1; i++)
214216
result.AppendValue(values_[i], 128);
215217
result.AppendValue(values_[values_.size() - 1], last_size_);
216218
}
217219
if (b.values_.size() > 0) {
218-
for (uint8_t i = 0; i < b.values_.size() - 1; i++)
220+
for (uint32_t i = 0; i < b.values_.size() - 1; i++)
219221
result.AppendValue(b.values_[i], 128);
220222
result.AppendValue(b.values_[b.values_.size() - 1], b.last_size_);
221223
}
@@ -226,7 +228,7 @@ template <class T> class BitsGeneric {
226228
template <class T2>
227229
BitsGeneric<T>& operator += (const BitsGeneric<T2>& b) {
228230
if (b.values_.size() > 0) {
229-
for (uint8_t i = 0; i < b.values_.size() - 1; i++)
231+
for (uint32_t i = 0; i < b.values_.size() - 1; i++)
230232
this->AppendValue(b.values_[i], 128);
231233
this->AppendValue(b.values_[b.values_.size() - 1], b.last_size_);
232234
}
@@ -251,7 +253,7 @@ template <class T> class BitsGeneric {
251253
values_[i]++;
252254
// Buckets that were full of 1 bits turn all to 0 bits.
253255
// (i.e. 10011111 + 1 = 10100000)
254-
for (uint16_t j = i + 1; j < values_.size(); j++)
256+
for (uint32_t j = i + 1; j < values_.size(); j++)
255257
values_[j] = 0;
256258
break;
257259
}
@@ -286,7 +288,7 @@ template <class T> class BitsGeneric {
286288
(uint128_t)std::numeric_limits<uint64_t> :: max();
287289
// All buckets that were previously 0, now become full of 1s.
288290
// (i.e. 1010000 - 1 = 1001111)
289-
for (uint16_t j = i + 1; j < values_.size() - 1; j++)
291+
for (uint32_t j = i + 1; j < values_.size() - 1; j++)
290292
values_[j] = limit;
291293
values_[values_.size() - 1] = (last_size_ == 128) ? limit :
292294
((static_cast<uint128_t>(1) << last_size_) - 1);
@@ -309,7 +311,7 @@ template <class T> class BitsGeneric {
309311
assert(GetSize() == other.GetSize());
310312
BitsGeneric<T> res;
311313
// Xoring individual bits is the same as xor-ing chunks of bits.
312-
for (uint16_t i = 0; i < values_.size(); i++)
314+
for (uint32_t i = 0; i < values_.size(); i++)
313315
res.values_.push_back(values_[i] ^ other.values_[i]);
314316
res.last_size_ = last_size_;
315317
return res;
@@ -361,7 +363,7 @@ template <class T> class BitsGeneric {
361363
}
362364

363365
// Same as 'Slice', but result fits into an uint64_t. Used for memory optimization.
364-
uint64_t SliceBitsToInt(int16_t start_index, int16_t end_index) const {
366+
uint64_t SliceBitsToInt(int32_t start_index, int32_t end_index) const {
365367
/*if (end_index > GetSize()) {
366368
end_index = GetSize();
367369
}
@@ -396,7 +398,7 @@ template <class T> class BitsGeneric {
396398
// Append 0s to complete the last byte.
397399
uint8_t shift = Util::ByteAlign(last_size_) - last_size_;
398400
uint128_t val = values_[values_.size() - 1] << (shift);
399-
uint16_t cnt = 0;
401+
uint32_t cnt = 0;
400402
// Extract byte-by-byte from the last bucket.
401403
uint8_t iterations = last_size_ / 8;
402404
if (last_size_ % 8)
@@ -407,7 +409,7 @@ template <class T> class BitsGeneric {
407409
}
408410
// Extract the full buckets, byte by byte.
409411
if (values_.size() >= 2) {
410-
for (int16_t i = values_.size() - 2; i >= 0; i--) {
412+
for (int32_t i = values_.size() - 2; i >= 0; i--) {
411413
uint128_t val = values_[i];
412414
for (uint8_t j = 0; j < 16; j++) {
413415
buffer[cnt++] = (val & 0xff);
@@ -419,7 +421,7 @@ template <class T> class BitsGeneric {
419421
if(cnt<=1)return; // No need to reverse anything
420422

421423
// Since we extracted from end to beginning, bytes are in reversed order. Reverse everything.
422-
uint16_t left = 0, right = cnt - 1;
424+
uint32_t left = 0, right = cnt - 1;
423425
while (left < right) {
424426
std::swap(buffer[left], buffer[right]);
425427
left++;
@@ -429,9 +431,9 @@ template <class T> class BitsGeneric {
429431

430432
std::string ToString() const {
431433
std::string str = "";
432-
for (uint16_t i = 0; i < values_.size(); i++) {
434+
for (uint32_t i = 0; i < values_.size(); i++) {
433435
uint128_t val = values_[i];
434-
uint16_t size = (i == values_.size() - 1) ? last_size_ : 128;
436+
uint32_t size = (i == values_.size() - 1) ? last_size_ : 128;
435437
std::string str_bucket = "";
436438
for (int i = 0; i < size; i++) {
437439
if (val % 2)
@@ -455,10 +457,10 @@ template <class T> class BitsGeneric {
455457
return values_[0];
456458
}
457459

458-
uint16_t GetSize() const {
460+
uint32_t GetSize() const {
459461
if (values_.size() == 0) return 0;
460462
// Full buckets contain each 128 bits, last one contains only 'last_size_' bits.
461-
return (values_.size() - 1) * 128 + last_size_;
463+
return ((uint32_t)values_.size() - 1) * 128 + last_size_;
462464
}
463465

464466
void AppendValue(uint128_t value, uint8_t length) {
@@ -527,7 +529,7 @@ bool operator==(const BitsGeneric<T>& lhs, const BitsGeneric<T>& rhs) {
527529
if (lhs.GetSize() != rhs.GetSize()) {
528530
return false;
529531
}
530-
for (uint16_t i = 0; i < lhs.values_.size(); i++) {
532+
for (uint32_t i = 0; i < lhs.values_.size(); i++) {
531533
if (lhs.values_[i] != rhs.values_[i]) {
532534
return false;
533535
}
@@ -539,7 +541,7 @@ template <class T>
539541
bool operator<(const BitsGeneric<T>& lhs, const BitsGeneric<T>& rhs) {
540542
if (lhs.GetSize() != rhs.GetSize())
541543
throw std::string("Different sizes!");
542-
for (uint16_t i = 0; i < lhs.values_.size(); i++) {
544+
for (uint32_t i = 0; i < lhs.values_.size(); i++) {
543545
if (lhs.values_[i] < rhs.values_[i])
544546
return true;
545547
if (lhs.values_[i] > rhs.values_[i])
@@ -552,7 +554,7 @@ template <class T>
552554
bool operator>(const BitsGeneric<T>& lhs, const BitsGeneric<T>& rhs) {
553555
if (lhs.GetSize() != rhs.GetSize())
554556
throw std::string("Different sizes!");
555-
for (uint16_t i = 0; i < lhs.values_.size(); i++) {
557+
for (uint32_t i = 0; i < lhs.values_.size(); i++) {
556558
if (lhs.values_[i] > rhs.values_[i])
557559
return true;
558560
if (lhs.values_[i] < rhs.values_[i])

lib/chiapos/src/util.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef SRC_CPP_UTIL_HPP_
1616
#define SRC_CPP_UTIL_HPP_
1717

18+
#include <random>
1819
#include <iostream>
1920
#include <fstream>
2021
#include <iomanip>
@@ -206,6 +207,15 @@ class Util {
206207
return sum;
207208
}
208209

210+
static void GetRandomBytes(uint8_t* buf, uint32_t num_bytes) {
211+
std::random_device rd;
212+
std::mt19937 mt(rd());
213+
std::uniform_real_distribution<double> dist(0, 256);
214+
for (uint32_t i = 0; i < num_bytes; i++) {
215+
buf[i] = static_cast<uint32_t>(floor(dist(mt))) % 256; // Mod in case we generate the random number 256:
216+
}
217+
}
218+
209219
static uint64_t find_islands(std::vector<std::pair<uint64_t, uint64_t> > edges) {
210220
std::map<uint64_t, std::vector<uint64_t> > edge_indeces;
211221
for (uint64_t edge_index = 0; edge_index < edges.size(); edge_index++) {

lib/chiapos/tests/test.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,33 @@ TEST_CASE("Bits") {
141141
m.ToBytes(buf);
142142
REQUIRE(buf[0] == (5 << 5));
143143
}
144+
SECTION("Park Bits") {
145+
uint32_t num_bytes = 16000;
146+
uint8_t* buf = new uint8_t[num_bytes];
147+
uint8_t* buf_2 = new uint8_t[num_bytes];
148+
Util::GetRandomBytes(buf, num_bytes);
149+
ParkBits my_bits = ParkBits(buf, num_bytes, num_bytes*8);
150+
my_bits.ToBytes(buf_2);
151+
for (uint32_t i = 0; i < num_bytes; i++) {
152+
REQUIRE(buf[i] == buf_2[i]);
153+
}
154+
delete[] buf;
155+
delete[] buf_2;
156+
}
157+
158+
SECTION("Large Bits") {
159+
uint32_t num_bytes = 200000;
160+
uint8_t* buf = new uint8_t[num_bytes];
161+
uint8_t* buf_2 = new uint8_t[num_bytes];
162+
Util::GetRandomBytes(buf, num_bytes);
163+
LargeBits my_bits = LargeBits(buf, num_bytes, num_bytes*8);
164+
my_bits.ToBytes(buf_2);
165+
for (uint32_t i = 0; i < num_bytes; i++) {
166+
REQUIRE(buf[i] == buf_2[i]);
167+
}
168+
delete[] buf;
169+
delete[] buf_2;
170+
}
144171
}
145172

146173
class FakeDisk : public Disk {

scripts/check_plots.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import argparse
2+
import os
3+
from hashlib import sha256
4+
5+
from blspy import PrivateKey, PublicKey
6+
from yaml import safe_load
7+
8+
from chiapos import DiskProver, Verifier
9+
from definitions import ROOT_DIR
10+
from src.types.proof_of_space import ProofOfSpace
11+
from src.types.sized_bytes import bytes32
12+
13+
plot_root = os.path.join(ROOT_DIR, "plots")
14+
plot_config_filename = os.path.join(ROOT_DIR, "config", "plots.yaml")
15+
16+
17+
def main():
18+
"""
19+
Script for checking all plots in the plots.yaml file. Specify a number of challenge to test for each plot.
20+
"""
21+
22+
parser = argparse.ArgumentParser(description="Chia plot checking script.")
23+
parser.add_argument("-n", "--num", help="Number of challenges", type=int, default=1000)
24+
args = parser.parse_args()
25+
26+
v = Verifier()
27+
if os.path.isfile(plot_config_filename):
28+
plot_config = safe_load(open(plot_config_filename, "r"))
29+
for plot_filename, plot_info in plot_config["plots"].items():
30+
plot_seed: bytes32 = ProofOfSpace.calculate_plot_seed(
31+
PublicKey.from_bytes(bytes.fromhex(plot_info["pool_pk"])),
32+
PrivateKey.from_bytes(bytes.fromhex(plot_info["sk"])).get_public_key()
33+
)
34+
# Tries relative path
35+
full_path: str = os.path.join(plot_root, plot_filename)
36+
if not os.path.isfile(full_path):
37+
# Tries absolute path
38+
full_path: str = plot_filename
39+
if not os.path.isfile(full_path):
40+
print(f"Plot file {full_path} not found.")
41+
continue
42+
pr = DiskProver(full_path)
43+
44+
total_proofs = 0
45+
try:
46+
for i in range(args.num):
47+
challenge = sha256(i.to_bytes(32, "big")).digest()
48+
for index, quality in enumerate(pr.get_qualities_for_challenge(challenge)):
49+
proof = pr.get_full_proof(challenge, index)
50+
total_proofs += 1
51+
ver_quality = v.validate_proof(plot_seed, pr.get_size(), challenge, proof)
52+
assert(quality == ver_quality)
53+
except BaseException as e:
54+
print(f"{type(e)}: {e} error in proving/verifying for plot {plot_filename}")
55+
print(f"{plot_filename}: Proofs {total_proofs} / {args.num}, {round(total_proofs/float(args.num), 4)}")
56+
else:
57+
print(f"Not plot file found at {plot_config_filename}")
58+
59+
60+
if __name__ == "__main__":
61+
main()

0 commit comments

Comments
 (0)