@@ -178,6 +178,11 @@ static ggml_cuda_device_info ggml_cuda_init() {
178178#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
179179 info.devices [id].smpbo = prop.sharedMemPerBlock ;
180180 info.devices [id].cc = 100 *prop.major + 10 *prop.minor + CC_OFFSET_AMD;
181+ #elif defined(GGML_USE_MUSA)
182+ /* * TODO: MUSA arch should match CUDA 11.4 */
183+ info.devices [id].smpbo = prop.sharedMemPerBlockOptin ;
184+ // info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_MT;
185+ info.devices [id].cc = CC_AMPERE;
181186#else
182187 info.devices [id].smpbo = prop.sharedMemPerBlockOptin ;
183188 info.devices [id].cc = 100 *prop.major + 10 *prop.minor ;
@@ -1671,9 +1676,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
16711676 }
16721677 }
16731678#else
1674- #ifdef GGML_USE_MUSA
1675- GGML_ASSERT (false );
1676- #else // !GGML_USE_MUSA
16771679 if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2 (src0) && ggml_is_contiguous_2 (src1)) {
16781680 // there is no broadcast and src0, src1 are contiguous across dims 2, 3
16791681 // use cublasGemmStridedBatchedEx
@@ -1716,7 +1718,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
17161718 cu_compute_type,
17171719 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
17181720 }
1719- #endif // GGML_USE_MUSA
17201721#endif
17211722
17221723 if (dst->op_params [0 ] == GGML_PREC_DEFAULT) {
@@ -2637,6 +2638,11 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
26372638 {
26382639 FILE *logFile = fopen (" ggml_op_perf.log" , " a" );
26392640 fprintf (logFile, " ## compute stats for each op: ##################################################\n " );
2641+ fprintf (logFile, " >> cc = %d, vmm = %d, total_vram = %u\n " ,
2642+ ggml_cuda_info ().devices [cuda_ctx->device ].cc ,
2643+ ggml_cuda_info ().devices [cuda_ctx->device ].vmm ,
2644+ ggml_cuda_info ().devices [cuda_ctx->device ].total_vram
2645+ );
26402646 float total_time = 0 , total_count = 0 ;
26412647 for (int i = 0 ; i < GGML_OP_COUNT; ++i) {
26422648 total_count += op_stats[i][OP_COUNT];
0 commit comments