Skip to content

Commit d7eab70

Browse files
authored
fix: limit mulh kernel threads (#2337)
- limit threads launched by mulh cuda kernel - use `inline constexpr` for compile time constants instead of `static const` in cuda kernels these are unrelated changes picked from #2338
1 parent 4c7c5d9 commit d7eab70

File tree

6 files changed

+72
-72
lines changed

6 files changed

+72
-72
lines changed

crates/circuits/primitives/cuda/include/primitives/constants.h

Lines changed: 49 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,91 +3,91 @@
33
#include <cstddef>
44

55
namespace riscv {
6-
static const size_t RV32_REGISTER_NUM_LIMBS = 4;
7-
static const size_t RV32_CELL_BITS = 8;
8-
static const size_t RV_J_TYPE_IMM_BITS = 21;
6+
inline constexpr size_t RV32_REGISTER_NUM_LIMBS = 4;
7+
inline constexpr size_t RV32_CELL_BITS = 8;
8+
inline constexpr size_t RV_J_TYPE_IMM_BITS = 21;
99

10-
static const size_t RV32_IMM_AS = 0;
10+
inline constexpr size_t RV32_IMM_AS = 0;
1111
} // namespace riscv
1212

1313
namespace program {
14-
static const size_t PC_BITS = 30;
15-
static const size_t DEFAULT_PC_STEP = 4;
14+
inline constexpr size_t PC_BITS = 30;
15+
inline constexpr size_t DEFAULT_PC_STEP = 4;
1616
} // namespace program
1717

1818
namespace native {
19-
static const size_t AS_IMMEDIATE = 0;
20-
static const size_t AS_NATIVE = 4;
21-
static const size_t EXT_DEG = 4;
22-
static const size_t BETA = 11;
19+
inline constexpr size_t AS_IMMEDIATE = 0;
20+
inline constexpr size_t AS_NATIVE = 4;
21+
inline constexpr size_t EXT_DEG = 4;
22+
inline constexpr size_t BETA = 11;
2323
} // namespace native
2424

2525
namespace poseidon2 {
26-
static const size_t CHUNK = 8;
26+
inline constexpr size_t CHUNK = 8;
2727
} // namespace poseidon2
2828

2929
namespace p3_keccak_air {
30-
static const size_t NUM_ROUNDS = 24;
31-
static const size_t BITS_PER_LIMB = 16;
32-
static const size_t U64_LIMBS = 64 / BITS_PER_LIMB;
33-
static const size_t RATE_BITS = 1088;
34-
static const size_t RATE_LIMBS = RATE_BITS / BITS_PER_LIMB;
30+
inline constexpr size_t NUM_ROUNDS = 24;
31+
inline constexpr size_t BITS_PER_LIMB = 16;
32+
inline constexpr size_t U64_LIMBS = 64 / BITS_PER_LIMB;
33+
inline constexpr size_t RATE_BITS = 1088;
34+
inline constexpr size_t RATE_LIMBS = RATE_BITS / BITS_PER_LIMB;
3535
} // namespace p3_keccak_air
3636

3737
namespace keccak256 {
3838
/// Total number of sponge bytes: number of rate bytes + number of capacity bytes.
39-
static const size_t KECCAK_WIDTH_BYTES = 200;
39+
inline constexpr size_t KECCAK_WIDTH_BYTES = 200;
4040
/// Total number of 16-bit limbs in the sponge.
41-
static const size_t KECCAK_WIDTH_U16S = KECCAK_WIDTH_BYTES / 2;
41+
inline constexpr size_t KECCAK_WIDTH_U16S = KECCAK_WIDTH_BYTES / 2;
4242
/// Number of rate bytes.
43-
static const size_t KECCAK_RATE_BYTES = 136;
43+
inline constexpr size_t KECCAK_RATE_BYTES = 136;
4444
/// Number of 16-bit rate limbs.
45-
static const size_t KECCAK_RATE_U16S = KECCAK_RATE_BYTES / 2;
45+
inline constexpr size_t KECCAK_RATE_U16S = KECCAK_RATE_BYTES / 2;
4646
/// Number of absorb rounds, equal to rate in u64s.
47-
static const size_t NUM_ABSORB_ROUNDS = KECCAK_RATE_BYTES / 8;
47+
inline constexpr size_t NUM_ABSORB_ROUNDS = KECCAK_RATE_BYTES / 8;
4848
/// Number of capacity bytes.
49-
static const size_t KECCAK_CAPACITY_BYTES = 64;
49+
inline constexpr size_t KECCAK_CAPACITY_BYTES = 64;
5050
/// Number of 16-bit capacity limbs.
51-
static const size_t KECCAK_CAPACITY_U16S = KECCAK_CAPACITY_BYTES / 2;
51+
inline constexpr size_t KECCAK_CAPACITY_U16S = KECCAK_CAPACITY_BYTES / 2;
5252
/// Number of output digest bytes used during the squeezing phase.
53-
static const size_t KECCAK_DIGEST_BYTES = 32;
53+
inline constexpr size_t KECCAK_DIGEST_BYTES = 32;
5454
/// Number of 64-bit digest limbs.
55-
static const size_t KECCAK_DIGEST_U64S = KECCAK_DIGEST_BYTES / 8;
55+
inline constexpr size_t KECCAK_DIGEST_U64S = KECCAK_DIGEST_BYTES / 8;
5656

5757
// ==== Constants for register/memory adapter ====
5858
/// Register reads to get dst, src, len
59-
static const size_t KECCAK_REGISTER_READS = 3;
59+
inline constexpr size_t KECCAK_REGISTER_READS = 3;
6060
/// Number of cells to read/write in a single memory access
61-
static const size_t KECCAK_WORD_SIZE = 4;
61+
inline constexpr size_t KECCAK_WORD_SIZE = 4;
6262
/// Memory reads for absorb per row
63-
static const size_t KECCAK_ABSORB_READS = KECCAK_RATE_BYTES / KECCAK_WORD_SIZE;
63+
inline constexpr size_t KECCAK_ABSORB_READS = KECCAK_RATE_BYTES / KECCAK_WORD_SIZE;
6464
/// Memory writes for digest per row
65-
static const size_t KECCAK_DIGEST_WRITES = KECCAK_DIGEST_BYTES / KECCAK_WORD_SIZE;
65+
inline constexpr size_t KECCAK_DIGEST_WRITES = KECCAK_DIGEST_BYTES / KECCAK_WORD_SIZE;
6666
/// keccakf parameters
67-
static const size_t KECCAK_ROUND = 24;
68-
static const size_t KECCAK_STATE_SIZE = 25;
69-
static const size_t KECCAK_Q_SIZE = 192;
67+
inline constexpr size_t KECCAK_ROUND = 24;
68+
inline constexpr size_t KECCAK_STATE_SIZE = 25;
69+
inline constexpr size_t KECCAK_Q_SIZE = 192;
7070
/// From memory config
71-
static const size_t KECCAK_POINTER_MAX_BITS = 29;
71+
inline constexpr size_t KECCAK_POINTER_MAX_BITS = 29;
7272
} // namespace keccak256
7373

7474
namespace mod_builder {
75-
static const size_t MAX_LIMBS = 97;
75+
inline constexpr size_t MAX_LIMBS = 97;
7676
} // namespace mod_builder
7777

7878
namespace sha256 {
79-
static const size_t SHA256_BLOCK_BITS = 512;
80-
static const size_t SHA256_BLOCK_U8S = 64;
81-
static const size_t SHA256_BLOCK_WORDS = 16;
82-
static const size_t SHA256_WORD_U8S = 4;
83-
static const size_t SHA256_WORD_BITS = 32;
84-
static const size_t SHA256_WORD_U16S = 2;
85-
static const size_t SHA256_HASH_WORDS = 8;
86-
static const size_t SHA256_NUM_READ_ROWS = 4;
87-
static const size_t SHA256_ROWS_PER_BLOCK = 17;
88-
static const size_t SHA256_ROUNDS_PER_ROW = 4;
89-
static const size_t SHA256_ROW_VAR_CNT = 5;
90-
static const size_t SHA256_REGISTER_READS = 3;
91-
static const size_t SHA256_READ_SIZE = 16;
92-
static const size_t SHA256_WRITE_SIZE = 32;
93-
} // namespace sha256
79+
inline constexpr size_t SHA256_BLOCK_BITS = 512;
80+
inline constexpr size_t SHA256_BLOCK_U8S = 64;
81+
inline constexpr size_t SHA256_BLOCK_WORDS = 16;
82+
inline constexpr size_t SHA256_WORD_U8S = 4;
83+
inline constexpr size_t SHA256_WORD_BITS = 32;
84+
inline constexpr size_t SHA256_WORD_U16S = 2;
85+
inline constexpr size_t SHA256_HASH_WORDS = 8;
86+
inline constexpr size_t SHA256_NUM_READ_ROWS = 4;
87+
inline constexpr size_t SHA256_ROWS_PER_BLOCK = 17;
88+
inline constexpr size_t SHA256_ROUNDS_PER_ROW = 4;
89+
inline constexpr size_t SHA256_ROW_VAR_CNT = 5;
90+
inline constexpr size_t SHA256_REGISTER_READS = 3;
91+
inline constexpr size_t SHA256_READ_SIZE = 16;
92+
inline constexpr size_t SHA256_WRITE_SIZE = 32;
93+
} // namespace sha256

crates/circuits/primitives/cuda/include/primitives/less_than.cuh

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include "histogram.cuh"
44
#include "fp_array.cuh"
55

6-
static const size_t AUX_LEN = 2;
6+
inline constexpr size_t AUX_LEN = 2;
77

88
template <typename T, size_t AUX_LEN = AUX_LEN> struct LessThanAuxCols {
99
T lower_decomp[AUX_LEN];
@@ -19,14 +19,14 @@ template <typename T, size_t NUM, size_t AUX_LEN = AUX_LEN> struct LessThanArray
1919
namespace AssertLessThan {
2020
/**
2121
* @brief Generates columns needed to constrain that x < y
22-
*
22+
*
2323
* @section Trace Context Parameters
2424
* @param rc Range checker histogram reference
2525
* @param max_bits Maximum number of bits the respresntation of x and y can be
2626
* @param x First value to compare (must be strictly less than y)
2727
* @param y Second value to compare
2828
* @param lower_decomp_len Number of columns needed to constrain x < y
29-
*
29+
*
3030
* @section Mutable Column Parameters
3131
* @param lower_decomp Columns used to constrain x < y
3232
*/
@@ -45,14 +45,14 @@ __device__ __forceinline__ void generate_subrow(
4545
namespace IsLessThan {
4646
/**
4747
* @brief Generates columns needed to constrain that out_flag == (x < y)
48-
*
48+
*
4949
* @section Trace Context Parameters
5050
* @param rc Range checker histogram reference
5151
* @param max_bits Maximum number of bits the respresntation of x and y can be
5252
* @param x First value to compare
5353
* @param y Second value to compare
5454
* @param lower_decomp_len Number of columns needed to constrain out_flag == (x < y)
55-
*
55+
*
5656
* @section Mutable Column Parameters
5757
* @param lower_decomp Columns used to constrain out_flag == (x < y)
5858
* @param out_flag Boolean value equal to x < y
@@ -77,15 +77,15 @@ namespace IsLessThanArray {
7777
/**
7878
* @brief Generates columns needed to constrain that out_flag == (x < y),
7979
* where x and y are represented by array_len limbs.
80-
*
80+
*
8181
* @section Trace Context Parameters
8282
* @param rc Range checker histogram reference
8383
* @param max_bits Maximum number of bits each limb of x and y can be
8484
* @param x First value to compare
8585
* @param y Second value to compare
8686
* @param array_len Number of limbs to represent x and y
8787
* @param aux_len Number of additional columns needed to constrain outflag == (x < y)
88-
*
88+
*
8989
* @section Mutable Column Parameters
9090
* @param diff_marker Array that marks the most significant limb difference in x and y
9191
* @param diff_inv Field inverse of the first differing y[i] - x[i], or 0

crates/vm/cuda/src/system/boundary.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
#include "primitives/trace_access.h"
66
#include <cassert>
77

8-
static const size_t PERSISTENT_CHUNK = 8;
9-
static const size_t VOLATILE_CHUNK = 1;
8+
inline constexpr size_t PERSISTENT_CHUNK = 8;
9+
inline constexpr size_t VOLATILE_CHUNK = 1;
1010

1111
template <size_t CHUNK> struct BoundaryRecord {
1212
uint32_t address_space;
@@ -24,8 +24,8 @@ template <typename T> struct PersistentBoundaryCols {
2424
T timestamp;
2525
};
2626

27-
static const size_t ADDR_ELTS = 2;
28-
static const size_t NUM_AS_LIMBS = 1;
27+
inline constexpr size_t ADDR_ELTS = 2;
28+
inline constexpr size_t NUM_AS_LIMBS = 1;
2929

3030
template <typename T> struct VolatileBoundaryCols {
3131
T address_space_limbs[NUM_AS_LIMBS];
@@ -61,7 +61,7 @@ __global__ void cukernel_persistent_boundary_tracegen(
6161
// TODO better address space handling
6262
FpArray<8> init_values;
6363
if (initial_mem[record.address_space - 1]) {
64-
init_values =
64+
init_values =
6565
record.address_space == 4
6666
? FpArray<8>::from_raw_array(
6767
reinterpret_cast<uint32_t const *>(

extensions/native/circuit/cuda/include/native/fri.cuh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ template <typename T> struct Instruction2Cols {
103103
};
104104

105105
// Size constants for the three column types
106-
static const size_t WORKLOAD_SIZE = sizeof(WorkloadCols<uint8_t>);
107-
static const size_t INSN1_SIZE = sizeof(Instruction1Cols<uint8_t>);
108-
static const size_t INSN2_SIZE = sizeof(Instruction2Cols<uint8_t>);
109-
static const size_t OVERALL_SIZE = std::max({WORKLOAD_SIZE, INSN1_SIZE, INSN2_SIZE});
106+
inline constexpr size_t WORKLOAD_SIZE = sizeof(WorkloadCols<uint8_t>);
107+
inline constexpr size_t INSN1_SIZE = sizeof(Instruction1Cols<uint8_t>);
108+
inline constexpr size_t INSN2_SIZE = sizeof(Instruction2Cols<uint8_t>);
109+
inline constexpr size_t OVERALL_SIZE = std::max({WORKLOAD_SIZE, INSN1_SIZE, INSN2_SIZE});

extensions/native/circuit/cuda/src/poseidon2.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77

88
using namespace poseidon2;
99

10-
static const size_t WIDTH = 16;
11-
static const size_t SBOX_DEGREE = Poseidon2DefaultParams::SBOX_DEGREE;
12-
static const size_t HALF_FULL_ROUNDS = Poseidon2DefaultParams::HALF_FULL_ROUNDS;
13-
static const size_t PARTIAL_ROUNDS = Poseidon2DefaultParams::PARTIAL_ROUNDS;
10+
inline constexpr size_t WIDTH = 16;
11+
inline constexpr size_t SBOX_DEGREE = Poseidon2DefaultParams::SBOX_DEGREE;
12+
inline constexpr size_t HALF_FULL_ROUNDS = Poseidon2DefaultParams::HALF_FULL_ROUNDS;
13+
inline constexpr size_t PARTIAL_ROUNDS = Poseidon2DefaultParams::PARTIAL_ROUNDS;
1414

15-
static const uint32_t NUM_INITIAL_READS = 6;
16-
// static const uint32_t NUM_SIMPLE_ACCESSES = 7;
15+
inline constexpr uint32_t NUM_INITIAL_READS = 6;
16+
// inline constexpr uint32_t NUM_SIMPLE_ACCESSES = 7;
1717

1818
template <typename T, size_t SBOX_REGISTERS> struct NativePoseidon2Cols {
1919
Poseidon2SubCols<T, WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS, PARTIAL_ROUNDS> inner;

extensions/rv32im/circuit/cuda/src/mulh.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ extern "C" int _mulh_tracegen(
203203
assert(height >= d_records.len());
204204
assert(width == sizeof(MulHCols<uint8_t>));
205205

206-
auto [grid, block] = kernel_launch_params(height);
206+
auto [grid, block] = kernel_launch_params(height, 512);
207207

208208
mulh_tracegen<<<grid, block>>>(
209209
d_trace,

0 commit comments

Comments
 (0)