Skip to content

Commit 11ecd68

Browse files
author
Dong Jun Woun
committed
Use long for m, n, k; require M*K, N*K, and M*N <= LONG_MAX
1 parent 582731d commit 11ecd68

File tree

1 file changed

+63
-31
lines changed

1 file changed

+63
-31
lines changed

src/components/amd_smi/tests/amdsmi_gemm.c

Lines changed: 63 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -58,9 +58,9 @@ static const int kTestEventCount =
5858
(int)(sizeof(kTestEventTemplates) / sizeof(kTestEventTemplates[0]));
5959

6060
struct run_config {
61-
int m_dim;
62-
int k_dim;
63-
int n_dim;
61+
long m_dim;
62+
long k_dim;
63+
long n_dim;
6464
int event_count;
6565
int iterations;
6666
useconds_t iteration_delay_us;
@@ -211,33 +211,58 @@ static int parse_test_override(int *argc, char **argv) {
211211

212212
long dims[3] = {0};
213213
for (int d = 0; d < 3; ++d) {
214-
char *endptr = NULL;
214+
const char *arg = argv[i + 1 + d];
215215
errno = 0;
216-
long value = strtol(argv[i + 1 + d], &endptr, 10);
217-
if (errno != 0 || !endptr || *endptr != '\0') {
216+
char *endptr = NULL;
217+
long value = strtol(arg, &endptr, 10);
218+
if (endptr == arg || !endptr || *endptr != '\0') {
218219
fprintf(stderr, "Invalid integer value for %s argument: %s\n",
219-
request_rocblas ? "--testrocblas" : "--test", argv[i + 1 + d]);
220+
request_rocblas ? "--testrocblas" : "--test", arg);
221+
return -1;
222+
}
223+
if (errno == ERANGE) {
224+
fprintf(stderr, "Dimension for %s is out of range: %s\n",
225+
request_rocblas ? "--testrocblas" : "--test", arg);
220226
return -1;
221227
}
222-
if (value <= 0 || value > INT_MAX) {
223-
fprintf(stderr, "Dimension for %s must be between 1 and %d: %s\n",
224-
request_rocblas ? "--testrocblas" : "--test", INT_MAX, argv[i + 1 + d]);
228+
if (value <= 0) {
229+
fprintf(stderr, "Dimension for %s must be a positive integer: %s\n",
230+
request_rocblas ? "--testrocblas" : "--test", arg);
225231
return -1;
226232
}
227233
dims[d] = value;
228234
}
229235

230-
g_run_config.m_dim = (int)dims[0];
231-
g_run_config.k_dim = (int)dims[1];
232-
g_run_config.n_dim = (int)dims[2];
236+
if (request_rocblas) {
237+
for (int d = 0; d < 3; ++d) {
238+
if (dims[d] > INT_MAX) {
239+
fprintf(stderr, "rocBLAS test requires dimensions to be <= %d: %s\n",
240+
INT_MAX, argv[i + 1 + d]);
241+
return -1;
242+
}
243+
}
244+
}
245+
246+
long m = dims[0];
247+
long k = dims[1];
248+
long n = dims[2];
249+
if ((m > LONG_MAX / k) || (n > LONG_MAX / k) || (m > LONG_MAX / n)) {
250+
fprintf(stderr, "Input dimensions are too large; ensure M*K, N*K, and M*N are <= %ld.\n",
251+
LONG_MAX);
252+
return -1;
253+
}
254+
255+
g_run_config.m_dim = m;
256+
g_run_config.k_dim = k;
257+
g_run_config.n_dim = n;
233258
g_run_config.event_count = kTestEventCount;
234259
g_run_config.iterations = ITERATIONS_PER_STREAM;
235260
g_run_config.iteration_delay_us = 5000000; /* 5 second pause between iterations */
236261
g_run_config.test_mode = true;
237262
g_run_config.use_rocblas = request_rocblas;
238263
g_run_config.device_index = 0;
239264
snprintf(g_run_config.csv_path, sizeof(g_run_config.csv_path),
240-
"amdsmi_gemm_%d_%d_%d%s.csv",
265+
"amdsmi_gemm_%ld_%ld_%ld%s.csv",
241266
g_run_config.m_dim,
242267
g_run_config.k_dim,
243268
g_run_config.n_dim,
@@ -352,35 +377,42 @@ static void *monitor_events(void *args) {
352377
* A: MxK, B: KxN, C: MxN (row-major)
353378
*/
354379
__global__ void dgemm_kernel(const double *A, const double *B, double *C,
355-
int M, int N, int K, double alpha, double beta) {
356-
int row = blockIdx.y * blockDim.y + threadIdx.y;
357-
int col = blockIdx.x * blockDim.x + threadIdx.x;
380+
long M, long N, long K, double alpha, double beta) {
381+
long row = (long)blockIdx.y * blockDim.y + threadIdx.y;
382+
long col = (long)blockIdx.x * blockDim.x + threadIdx.x;
358383

359384
if (row < M && col < N) {
360385
double sum = 0.0;
361-
for (int k = 0; k < K; k++) {
362-
sum += A[row * K + k] * B[k * N + col];
386+
long row_offset = row * K;
387+
for (long k = 0; k < K; ++k) {
388+
sum += A[row_offset + k] * B[k * N + col];
363389
}
364-
C[row * N + col] = alpha * sum + beta * C[row * N + col];
390+
long c_index = row * N + col;
391+
C[c_index] = alpha * sum + beta * C[c_index];
365392
}
366393
}
367394

368395
static int rocblas_dgemm_row_major(rocblas_handle handle,
369396
const double *dA,
370397
const double *dB,
371398
double *dC,
372-
int M,
373-
int N,
374-
int K,
399+
long M,
400+
long N,
401+
long K,
375402
double alpha,
376403
double beta) {
377404
if (!handle) {
378405
return -1;
379406
}
380407

381-
const int m_prime = N;
382-
const int n_prime = M;
383-
const int k_prime = K;
408+
if (M > INT_MAX || N > INT_MAX || K > INT_MAX) {
409+
fprintf(stderr, "rocBLAS dgemm requires dimensions <= %d.\n", INT_MAX);
410+
return -1;
411+
}
412+
413+
const int m_prime = (int)N;
414+
const int n_prime = (int)M;
415+
const int k_prime = (int)K;
384416

385417
const int lda = m_prime;
386418
const int ldb = k_prime;
@@ -458,9 +490,9 @@ static int real_main(const HarnessOpts *opts) {
458490
SKIP("Unable to locate the amd_smi component (PAPI built without ROCm?)");
459491
}
460492

461-
const int m_dim = g_run_config.m_dim;
462-
const int k_dim = g_run_config.k_dim;
463-
const int n_dim = g_run_config.n_dim;
493+
const long m_dim = g_run_config.m_dim;
494+
const long k_dim = g_run_config.k_dim;
495+
const long n_dim = g_run_config.n_dim;
464496
const int event_count = g_run_config.event_count;
465497
const int iterations = g_run_config.iterations;
466498
const useconds_t pause_us = g_run_config.iteration_delay_us;
@@ -716,8 +748,8 @@ static int real_main(const HarnessOpts *opts) {
716748
double beta = 0.5;
717749

718750
dim3 blockDim(32, 32);
719-
dim3 gridDim((n_dim + blockDim.x - 1) / blockDim.x,
720-
(m_dim + blockDim.y - 1) / blockDim.y);
751+
dim3 gridDim((unsigned int)((n_dim + blockDim.x - 1) / blockDim.x),
752+
(unsigned int)((m_dim + blockDim.y - 1) / blockDim.y));
721753

722754
for (int iter = 0; iter < iterations; ++iter) {
723755
for (int s = 0; s < NUM_STREAMS; s++) {

0 commit comments

Comments
 (0)