Skip to content

Commit ee1c000

Browse files
committed
above 90
1 parent 625557d commit ee1c000

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

examples/misc/mem_bw.cu

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
#include <cooperative_groups.h>
22
#include <fmt/core.h>
3+
#include "utils.h"
34

45
using namespace cooperative_groups;
56

6-
__global__ void direct_copy_optimized(int4 *output, int4 *input, size_t n) {
7+
__global__ void direct_copy_optimized(float4 *output, float4 *input, size_t n) {
78
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
89
const size_t stride = blockDim.x * gridDim.x;
910

@@ -12,7 +13,7 @@ __global__ void direct_copy_optimized(int4 *output, int4 *input, size_t n) {
1213
}
1314
}
1415

15-
bool check_equal(int *output, int *input, int n) {
16+
bool check_equal(float *output, float *input, int n) {
1617
for (int i = 0; i < n; i++) {
1718
if (output[i] != input[i]) {
1819
fmt::print("Not equal for {}, input: {} output: {}\n", i, input[i], output[i]);
@@ -24,18 +25,18 @@ bool check_equal(int *output, int *input, int n) {
2425

2526
int main() {
2627

27-
int n = 1 << 24;
28-
int blockSize = 1024;
28+
int n = 1 << 28;
29+
int blockSize = 256;
2930
int numSMs;
3031
cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, 0);
3132
// manually chosen grid size (number of thread blocks to launch)
32-
int nBlocks_manual = 32 * numSMs;
33-
int *output, *data;
34-
cudaMallocManaged(&output, n * sizeof(int));
35-
cudaMallocManaged(&data, n * sizeof(int));
33+
float nBlocks_manual = min(1024 * numSMs, simple_cuda::ceil_div(n, blockSize));
34+
float *output, *data;
35+
cudaMallocManaged(&output, n * sizeof(float));
36+
cudaMallocManaged(&data, n * sizeof(float));
3637
std::fill_n(data, n, 1); // initialize data
3738

38-
direct_copy_optimized<<<nBlocks_manual, blockSize>>>(reinterpret_cast<int4*>(output), reinterpret_cast<int4*>(data), n);
39+
direct_copy_optimized<<<nBlocks_manual, blockSize>>>(reinterpret_cast<float4*>(output), reinterpret_cast<float4*>(data), n);
3940
cudaDeviceSynchronize();
4041

4142
auto eq = check_equal(output, data, n);

0 commit comments

Comments
 (0)