17 | 17 | #include <c10/core/ScalarType.h> |
18 | 18 |
19 | 19 | #ifdef USE_ROCM |
| 20 | +#include <c10/cuda/CUDAStream.h> |
20 | 21 | #include <hipblaslt/hipblaslt-ext.hpp> |
21 | 22 | // until hipblas has an API to accept flags, we must use rocblas here |
22 | 23 | #include <hipblas/hipblas.h> |
@@ -185,6 +186,64 @@ uint32_t _getAlignment(uintptr_t address) { |
185 | 186 | } |
186 | 187 | #endif |
187 | 188 |
| 189 | +#ifdef USE_ROCM |
| 190 | +static c10::cuda::CUDAStream _getCarveoutStream(int32_t value) { |
| 191 | + // 0 is default value, meaning full CUs i.e. no mask |
| 192 | + if (value == 0) { |
| 193 | + return at::cuda::getCurrentCUDAStream(); |
| 194 | + } |
| 195 | + static int32_t last_value = 0; |
| 196 | + static hipStream_t stream; |
| 197 | + if (last_value == 0) { |
| 198 | + // first request, do nothing for this case |
| 199 | + } |
| 200 | + else if (last_value == value) { |
| 201 | + // stream was created previously and value hasn't changed |
| 202 | + return c10::cuda::getStreamFromExternal(stream, c10::cuda::current_device()); |
| 203 | + } |
| 204 | + else { |
| 205 | + // need a new stream and a previous stream exists, delete it |
| 206 | + AT_CUDA_CHECK(hipStreamDestroy(stream)); |
| 207 | + } |
| 208 | + |
| 209 | + // if we got here, we need to create a new stream |
| 210 | + int32_t CUs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; |
| 211 | + // how many uint32_t do we need to cover all CUs, fill bitmask with 1 |
| 212 | + uint32_t mask_size = static_cast<uint32_t>((CUs + 32 - 1) / 32); |
| 213 | + std::vector<uint32_t> mask(mask_size, uint32_t{0xffffffff}); |
| 214 | + // starting from lowest order bits, in 32-bit chunks |
| 215 | + // set bits to 0 based on how many CUs to carve out |
| 216 | + int32_t full_shifts = value / 32; |
| 217 | + int32_t remainder = value % 32; |
| 218 | + for (int32_t i = 0; i < full_shifts; i++) { |
| 219 | + mask[i] = uint32_t{0x00000000}; |
| 220 | + } |
| 221 | + mask[full_shifts] = uint32_t{0xffffffff} << remainder; |
| 222 | + |
| 223 | + // finally, create masked stream |
| 224 | + AT_CUDA_CHECK(hipExtStreamCreateWithCUMask(&stream, mask_size, &mask[0])); |
| 225 | + |
| 226 | + last_value = value; |
| 227 | + return c10::cuda::getStreamFromExternal(stream, c10::cuda::current_device()); |
| 228 | +} |
| 229 | + |
| 230 | +static void _syncCurrentWithCarveoutStream(hipStream_t stream, bool presync) { |
| 231 | + hipEvent_t event; |
| 232 | + AT_CUDA_CHECK(hipEventCreateWithFlags(&event, hipEventDisableTiming)); |
| 233 | + |
| 234 | + auto current_stream = at::cuda::getCurrentCUDAStream(); |
| 235 | + |
| 236 | + if (presync) { |
| 237 | + AT_CUDA_CHECK(hipEventRecord(event, current_stream)); |
| 238 | + AT_CUDA_CHECK(hipStreamWaitEvent(stream, event, 0)); |
| 239 | + } |
| 240 | + else { |
| 241 | + AT_CUDA_CHECK(hipEventRecord(event, stream)); |
| 242 | + AT_CUDA_CHECK(hipStreamWaitEvent(current_stream, event, 0)); |
| 243 | + } |
| 244 | +} |
| 245 | +#endif |
| 246 | + |
188 | 247 | struct CublasLtWorkspace { |
189 | 248 | CublasLtWorkspace() { |
190 | 249 | size = at::cuda::getCUDABlasLtWorkspaceSize(); |
@@ -360,13 +419,20 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { |
360 | 419 | CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType); |
361 | 420 | computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, opa); |
362 | 421 | computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, opb); |
| 422 | + auto stream = at::cuda::getCurrentCUDAStream(); |
363 | 423 | #ifndef USE_ROCM |
364 | 424 | if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { |
365 | 425 | computeDesc.setAttribute<int32_t>( |
366 | 426 | CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, |
367 | 427 | at::cuda::getCurrentDeviceProperties()->multiProcessorCount - |
368 | 428 | at::globalContext()._SMCarveout_EXPERIMENTAL().value()); |
369 | 429 | } |
| 430 | +#else |
| 431 | + if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { |
| 432 | + stream = _getCarveoutStream( |
| 433 | + at::globalContext()._SMCarveout_EXPERIMENTAL().value()); |
| 434 | + _syncCurrentWithCarveoutStream(stream, true); |
| 435 | + } |
370 | 436 | #endif |
371 | 437 | CuBlasLtMatrixLayout Adesc(abcType, m, k, lda, opa == CUBLAS_OP_T); |
372 | 438 | CuBlasLtMatrixLayout Bdesc(abcType, k, n, ldb, opb == CUBLAS_OP_T); |
@@ -430,7 +496,12 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { |
430 | 496 | &heuristicResult.algo, |
431 | 497 | ltworkspace.ptr, |
432 | 498 | ltworkspace.size, |
433 | | - at::cuda::getCurrentCUDAStream()); |
| 499 | + stream); |
| 500 | +#ifdef USE_ROCM |
| 501 | + if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { |
| 502 | + _syncCurrentWithCarveoutStream(stream, false); |
| 503 | + } |
| 504 | +#endif |
434 | 505 | TORCH_CHECK( |
435 | 506 | cublasStatus == CUBLAS_STATUS_SUCCESS, |
436 | 507 | "CUDA error: ", |
@@ -1295,13 +1366,20 @@ void gemm_and_bias( |
1295 | 1366 | computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa); |
1296 | 1367 | cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N; |
1297 | 1368 | computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb); |
| 1369 | + auto stream = at::cuda::getCurrentCUDAStream(); |
1298 | 1370 | #ifndef USE_ROCM |
1299 | 1371 | if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { |
1300 | 1372 | computeDesc.setAttribute<int32_t>( |
1301 | 1373 | CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, |
1302 | 1374 | at::cuda::getCurrentDeviceProperties()->multiProcessorCount - |
1303 | 1375 | at::globalContext()._SMCarveout_EXPERIMENTAL().value()); |
1304 | 1376 | } |
| 1377 | +#else |
| 1378 | + if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { |
| 1379 | + stream = _getCarveoutStream( |
| 1380 | + at::globalContext()._SMCarveout_EXPERIMENTAL().value()); |
| 1381 | + _syncCurrentWithCarveoutStream(stream, true); |
| 1382 | + } |
1305 | 1383 | #endif |
1306 | 1384 | cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS; |
1307 | 1385 | if (activation == GEMMAndBiasActivationEpilogue::RELU) { |
@@ -1370,7 +1448,12 @@ void gemm_and_bias( |
1370 | 1448 | &heuristicResult.algo, |
1371 | 1449 | ltworkspace.ptr, |
1372 | 1450 | ltworkspace.size, |
1373 | | - at::cuda::getCurrentCUDAStream()); |
| 1451 | + stream); |
| 1452 | +#ifdef USE_ROCM |
| 1453 | + if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { |
| 1454 | + _syncCurrentWithCarveoutStream(stream, false); |
| 1455 | + } |
| 1456 | +#endif |
1374 | 1457 | TORCH_CHECK( |
1375 | 1458 | cublasStatus == CUBLAS_STATUS_SUCCESS, |
1376 | 1459 | "CUDA error: ", |
@@ -1525,13 +1608,20 @@ void scaled_gemm( |
1525 | 1608 | if (result_scale_ptr != nullptr) { |
1526 | 1609 | computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); |
1527 | 1610 | } |
| 1611 | + auto stream = at::cuda::getCurrentCUDAStream(); |
1528 | 1612 | #ifndef USE_ROCM |
1529 | 1613 | if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { |
1530 | 1614 | computeDesc.setAttribute<int32_t>( |
1531 | 1615 | CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, |
1532 | 1616 | at::cuda::getCurrentDeviceProperties()->multiProcessorCount - |
1533 | 1617 | at::globalContext()._SMCarveout_EXPERIMENTAL().value()); |
1534 | 1618 | } |
| 1619 | +#else |
| 1620 | + if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { |
| 1621 | + stream = _getCarveoutStream( |
| 1622 | + at::globalContext()._SMCarveout_EXPERIMENTAL().value()); |
| 1623 | + _syncCurrentWithCarveoutStream(stream, true); |
| 1624 | + } |
1535 | 1625 | #endif |
1536 | 1626 | #ifndef USE_ROCM |
1537 | 1627 | const int8_t fastAccuMode = use_fast_accum ? 1 : 0; |
@@ -1570,7 +1660,6 @@ void scaled_gemm( |
1570 | 1660 | #endif // if CUDA_VERSION >= 12090 |
1571 | 1661 | } |
1572 | 1662 |
1573 | | - auto stream = c10::cuda::getCurrentCUDAStream(); |
1574 | 1663 | CuBlasLtMatmulPreference preference; |
1575 | 1664 | auto ltworkspace = CublasLtWorkspace(); |
1576 | 1665 | preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size); |
@@ -1657,6 +1746,11 @@ void scaled_gemm( |
1657 | 1746 | ltworkspace.ptr, |
1658 | 1747 | ltworkspace.size, |
1659 | 1748 | stream); |
| 1749 | +#ifdef USE_ROCM |
| 1750 | + if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { |
| 1751 | + _syncCurrentWithCarveoutStream(stream, false); |
| 1752 | + } |
| 1753 | +#endif |
1660 | 1754 | TORCH_CHECK( |
1661 | 1755 | cublasStatus == CUBLAS_STATUS_SUCCESS, |
1662 | 1756 | "CUDA error: ", |
@@ -1710,13 +1804,20 @@ void int8_gemm( |
1710 | 1804 | computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa); |
1711 | 1805 | cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N; |
1712 | 1806 | computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb); |
| 1807 | + auto stream = at::cuda::getCurrentCUDAStream(); |
1713 | 1808 | #ifndef USE_ROCM |
1714 | 1809 | if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { |
1715 | 1810 | computeDesc.setAttribute<int32_t>( |
1716 | 1811 | CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, |
1717 | 1812 | at::cuda::getCurrentDeviceProperties()->multiProcessorCount - |
1718 | 1813 | at::globalContext()._SMCarveout_EXPERIMENTAL().value()); |
1719 | 1814 | } |
| 1815 | +#else |
| 1816 | + if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { |
| 1817 | + stream = _getCarveoutStream( |
| 1818 | + at::globalContext()._SMCarveout_EXPERIMENTAL().value()); |
| 1819 | + _syncCurrentWithCarveoutStream(stream, true); |
| 1820 | + } |
1720 | 1821 | #endif |
1721 | 1822 |
|
1722 | 1823 | CuBlasLtMatrixLayout Adesc(abType, m, k, mat1_ld, transpose_mat1); |
@@ -1778,7 +1879,7 @@ void int8_gemm( |
1778 | 1879 | #else |
1779 | 1880 | 0, |
1780 | 1881 | #endif |
1781 | | - at::cuda::getCurrentCUDAStream()); |
| 1882 | + stream); |
1782 | 1883 | TORCH_CHECK( |
1783 | 1884 | cublasStatus == CUBLAS_STATUS_SUCCESS, |
1784 | 1885 | "CUDA error: ", |
@@ -1807,6 +1908,11 @@ void int8_gemm( |
1807 | 1908 | computeType, |
1808 | 1909 | " scaleType ", |
1809 | 1910 | scaleType); |
| 1911 | +#ifdef USE_ROCM |
| 1912 | + if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { |
| 1913 | + _syncCurrentWithCarveoutStream(stream, false); |
| 1914 | + } |
| 1915 | +#endif |
1810 | 1916 | } |
1811 | 1917 |
1812 | 1918 | template <> |
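To make the CU-mask arithmetic in `_getCarveoutStream` concrete, here is a small host-only sketch (not part of the patch) that reproduces the same full_shifts/remainder logic. The helper name `buildCuMask`, the assumed 304-CU device, the carveout of 40, and the remainder-zero guard are illustrative additions; only the bit manipulation mirrors the code above.

```cpp
// Host-only sketch of the CU-mask construction used by _getCarveoutStream.
// Assumes 304 CUs and a carveout of 40 CUs, purely for illustration.
#include <cstdint>
#include <cstdio>
#include <vector>

// buildCuMask is a hypothetical helper; the patch inlines this logic.
std::vector<uint32_t> buildCuMask(int32_t total_cus, int32_t carveout) {
  // One bit per CU, packed into 32-bit words; start with every CU enabled.
  uint32_t mask_size = static_cast<uint32_t>((total_cus + 32 - 1) / 32);
  std::vector<uint32_t> mask(mask_size, uint32_t{0xffffffff});
  // Disable `carveout` CUs starting from the lowest-order bits.
  int32_t full_shifts = carveout / 32;  // whole words cleared
  int32_t remainder = carveout % 32;    // extra bits cleared in the next word
  for (int32_t i = 0; i < full_shifts; i++) {
    mask[i] = uint32_t{0x00000000};
  }
  if (remainder != 0) {  // guard added here for the remainder == 0 case
    mask[full_shifts] = uint32_t{0xffffffff} << remainder;
  }
  return mask;
}

int main() {
  // carveout = 40 -> full_shifts = 1, remainder = 8:
  // mask[0] = 0x00000000, mask[1] = 0xffffff00, mask[2..9] = 0xffffffff,
  // i.e. 304 - 40 = 264 CUs remain available to the masked stream.
  for (uint32_t w : buildCuMask(304, 40)) {
    std::printf("0x%08x\n", w);
  }
  return 0;
}
```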
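The call sites all follow the same pattern: pick up the masked stream, fence it against the current stream before the GEMM (`presync == true`), run the GEMM on the masked stream, then fence the current stream against it afterwards. On the CUDA path the same effect comes from the `CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET` attribute; the ROCm path instead constrains the stream itself, which is why the extra event synchronization is needed. The self-contained HIP sketch below (again not part of the patch) shows that pattern around a placeholder `hipMemsetAsync` with an illustrative carveout of 8 CUs; unlike the cached stream in `_getCarveoutStream`, it creates and destroys everything locally, and it destroys the sync events once the waits are enqueued, which is legal in HIP and avoids holding one event per call.

```cpp
// Rough HIP sketch of the carveout pattern: run work on a CU-masked stream,
// ordered after and before the surrounding "current" stream.
#include <hip/hip_runtime.h>
#include <cstdint>
#include <cstdlib>
#include <vector>

#define HIP_CHECK(expr)                              \
  do {                                               \
    hipError_t _e = (expr);                          \
    if (_e != hipSuccess) std::abort();              \
  } while (0)

int main() {
  hipDeviceProp_t props;
  HIP_CHECK(hipGetDeviceProperties(&props, 0));
  int32_t cus = props.multiProcessorCount;

  // Carve 8 CUs out of the mask (illustrative value).
  int32_t carveout = 8;
  uint32_t mask_size = static_cast<uint32_t>((cus + 31) / 32);
  std::vector<uint32_t> mask(mask_size, 0xffffffffu);
  for (int32_t i = 0; i < carveout / 32; i++) mask[i] = 0u;
  if (carveout % 32 != 0) mask[carveout / 32] = 0xffffffffu << (carveout % 32);

  hipStream_t current = nullptr;  // stand-in for the current ATen stream
  hipStream_t masked = nullptr;
  HIP_CHECK(hipStreamCreate(&current));
  HIP_CHECK(hipExtStreamCreateWithCUMask(&masked, mask_size, mask.data()));

  void* buf = nullptr;
  HIP_CHECK(hipMalloc(&buf, 1 << 20));

  // Pre-sync: the masked stream waits for work already queued on `current`.
  hipEvent_t pre, post;
  HIP_CHECK(hipEventCreateWithFlags(&pre, hipEventDisableTiming));
  HIP_CHECK(hipEventRecord(pre, current));
  HIP_CHECK(hipStreamWaitEvent(masked, pre, 0));

  // The GEMM would be launched here; a memset stands in for it.
  HIP_CHECK(hipMemsetAsync(buf, 0, 1 << 20, masked));

  // Post-sync: `current` waits for the masked work before continuing.
  HIP_CHECK(hipEventCreateWithFlags(&post, hipEventDisableTiming));
  HIP_CHECK(hipEventRecord(post, masked));
  HIP_CHECK(hipStreamWaitEvent(current, post, 0));

  HIP_CHECK(hipEventDestroy(pre));
  HIP_CHECK(hipEventDestroy(post));
  HIP_CHECK(hipStreamSynchronize(current));
  HIP_CHECK(hipFree(buf));
  HIP_CHECK(hipStreamDestroy(masked));
  HIP_CHECK(hipStreamDestroy(current));
  return 0;
}
```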