@@ -58,9 +58,9 @@ static const int kTestEventCount =
5858 (int )(sizeof (kTestEventTemplates ) / sizeof (kTestEventTemplates [0 ]));
5959
6060struct run_config {
61- int m_dim ;
62- int k_dim ;
63- int n_dim ;
61+ long m_dim ;
62+ long k_dim ;
63+ long n_dim ;
6464 int event_count ;
6565 int iterations ;
6666 useconds_t iteration_delay_us ;
@@ -211,33 +211,58 @@ static int parse_test_override(int *argc, char **argv) {
211211
212212 long dims [3 ] = {0 };
213213 for (int d = 0 ; d < 3 ; ++ d ) {
214- char * endptr = NULL ;
214+ const char * arg = argv [ i + 1 + d ] ;
215215 errno = 0 ;
216- long value = strtol (argv [i + 1 + d ], & endptr , 10 );
217- if (errno != 0 || !endptr || * endptr != '\0' ) {
216+ char * endptr = NULL ;
217+ long value = strtol (arg , & endptr , 10 );
218+ if (endptr == arg || !endptr || * endptr != '\0' ) {
218219 fprintf (stderr , "Invalid integer value for %s argument: %s\n" ,
219- request_rocblas ? "--testrocblas" : "--test" , argv [i + 1 + d ]);
220+ request_rocblas ? "--testrocblas" : "--test" , arg );
221+ return -1 ;
222+ }
223+ if (errno == ERANGE ) {
224+ fprintf (stderr , "Dimension for %s is out of range: %s\n" ,
225+ request_rocblas ? "--testrocblas" : "--test" , arg );
220226 return -1 ;
221227 }
222- if (value <= 0 || value > INT_MAX ) {
223- fprintf (stderr , "Dimension for %s must be between 1 and %d : %s\n" ,
224- request_rocblas ? "--testrocblas" : "--test" , INT_MAX , argv [ i + 1 + d ] );
228+ if (value <= 0 ) {
229+ fprintf (stderr , "Dimension for %s must be a positive integer : %s\n" ,
230+ request_rocblas ? "--testrocblas" : "--test" , arg );
225231 return -1 ;
226232 }
227233 dims [d ] = value ;
228234 }
229235
230- g_run_config .m_dim = (int )dims [0 ];
231- g_run_config .k_dim = (int )dims [1 ];
232- g_run_config .n_dim = (int )dims [2 ];
236+ if (request_rocblas ) {
237+ for (int d = 0 ; d < 3 ; ++ d ) {
238+ if (dims [d ] > INT_MAX ) {
239+ fprintf (stderr , "rocBLAS test requires dimensions to be <= %d: %s\n" ,
240+ INT_MAX , argv [i + 1 + d ]);
241+ return -1 ;
242+ }
243+ }
244+ }
245+
246+ long m = dims [0 ];
247+ long k = dims [1 ];
248+ long n = dims [2 ];
249+ if ((m > LONG_MAX / k ) || (n > LONG_MAX / k ) || (m > LONG_MAX / n )) {
250+ fprintf (stderr , "Input dimensions are too large; ensure M*K, N*K, and M*N are <= %ld.\n" ,
251+ LONG_MAX );
252+ return -1 ;
253+ }
254+
255+ g_run_config .m_dim = m ;
256+ g_run_config .k_dim = k ;
257+ g_run_config .n_dim = n ;
233258 g_run_config .event_count = kTestEventCount ;
234259 g_run_config .iterations = ITERATIONS_PER_STREAM ;
235260 g_run_config .iteration_delay_us = 5000000 ; /* 5 second pause between iterations */
236261 g_run_config .test_mode = true;
237262 g_run_config .use_rocblas = request_rocblas ;
238263 g_run_config .device_index = 0 ;
239264 snprintf (g_run_config .csv_path , sizeof (g_run_config .csv_path ),
240- "amdsmi_gemm_%d_%d_%d %s.csv" ,
265+ "amdsmi_gemm_%ld_%ld_%ld %s.csv" ,
241266 g_run_config .m_dim ,
242267 g_run_config .k_dim ,
243268 g_run_config .n_dim ,
@@ -352,35 +377,42 @@ static void *monitor_events(void *args) {
352377 * A: MxK, B: KxN, C: MxN (row-major)
353378 */
/*
 * Diff view of dgemm_kernel: lines tagged "-" are the previous int-based
 * version, lines tagged "+" are the replacement that uses long for the
 * dimensions and for all index arithmetic.
 *
 * Each thread derives one (row, col) pair from its block and thread indices.
 * When that pair lies inside the M x N output, the thread accumulates the
 * K-term sum of A[row*K + k] * B[k*N + col] and stores
 * alpha * sum + beta * C[row*N + col] back into C[row*N + col].
 * Threads whose (row, col) falls outside the output do nothing.
 */
354379__global__ void dgemm_kernel (const double * A , const double * B , double * C ,
355- int M , int N , int K , double alpha , double beta ) {
356- int row = blockIdx .y * blockDim .y + threadIdx .y ;
357- int col = blockIdx .x * blockDim .x + threadIdx .x ;
/*
 * Replacement signature and indices: the (long) cast is applied to blockIdx
 * before the multiply, so the row/col computation is promoted to long.
 * NOTE(review): this only widens the arithmetic where long is 64-bit
 * (LP64) -- confirm for every target platform.
 */
380+ long M , long N , long K , double alpha , double beta ) {
381+ long row = ( long ) blockIdx .y * blockDim .y + threadIdx .y ;
382+ long col = ( long ) blockIdx .x * blockDim .x + threadIdx .x ;
358383
359384 if (row < M && col < N ) {
360385 double sum = 0.0 ;
361- for (int k = 0 ; k < K ; k ++ ) {
362- sum += A [row * K + k ] * B [k * N + col ];
/* Replacement loop: row * K is computed once, outside the k loop. */
386+ long row_offset = row * K ;
387+ for (long k = 0 ; k < K ; ++ k ) {
388+ sum += A [row_offset + k ] * B [k * N + col ];
362389 }
364- C [row * N + col ] = alpha * sum + beta * C [row * N + col ];
/* Replacement store: the output index is computed once and reused for
 * both the read of the old C value and the write of the new one. */
390+ long c_index = row * N + col ;
391+ C [c_index ] = alpha * sum + beta * C [c_index ];
365392 }
366393}
367394
368395static int rocblas_dgemm_row_major (rocblas_handle handle ,
369396 const double * dA ,
370397 const double * dB ,
371398 double * dC ,
372- int M ,
373- int N ,
374- int K ,
399+ long M ,
400+ long N ,
401+ long K ,
375402 double alpha ,
376403 double beta ) {
377404 if (!handle ) {
378405 return -1 ;
379406 }
380407
381- const int m_prime = N ;
382- const int n_prime = M ;
383- const int k_prime = K ;
408+ if (M > INT_MAX || N > INT_MAX || K > INT_MAX ) {
409+ fprintf (stderr , "rocBLAS dgemm requires dimensions <= %d.\n" , INT_MAX );
410+ return -1 ;
411+ }
412+
413+ const int m_prime = (int )N ;
414+ const int n_prime = (int )M ;
415+ const int k_prime = (int )K ;
384416
385417 const int lda = m_prime ;
386418 const int ldb = k_prime ;
@@ -458,9 +490,9 @@ static int real_main(const HarnessOpts *opts) {
458490 SKIP ("Unable to locate the amd_smi component (PAPI built without ROCm?)" );
459491 }
460492
461- const int m_dim = g_run_config .m_dim ;
462- const int k_dim = g_run_config .k_dim ;
463- const int n_dim = g_run_config .n_dim ;
493+ const long m_dim = g_run_config .m_dim ;
494+ const long k_dim = g_run_config .k_dim ;
495+ const long n_dim = g_run_config .n_dim ;
464496 const int event_count = g_run_config .event_count ;
465497 const int iterations = g_run_config .iterations ;
466498 const useconds_t pause_us = g_run_config .iteration_delay_us ;
@@ -716,8 +748,8 @@ static int real_main(const HarnessOpts *opts) {
716748 double beta = 0.5 ;
717749
718750 dim3 blockDim (32 , 32 );
719- dim3 gridDim ((n_dim + blockDim .x - 1 ) / blockDim .x ,
720- (m_dim + blockDim .y - 1 ) / blockDim .y );
751+ dim3 gridDim ( (unsigned int )(( n_dim + blockDim .x - 1 ) / blockDim .x ) ,
752+ (unsigned int )(( m_dim + blockDim .y - 1 ) / blockDim .y ) );
721753
722754 for (int iter = 0 ; iter < iterations ; ++ iter ) {
723755 for (int s = 0 ; s < NUM_STREAMS ; s ++ ) {
0 commit comments