Commit c2114ee

[release/2.7] support experimental CU carveout (#2700)
This includes both the initial carveout PR plus its later fix:
- initial PR pytorch#149466
- fix PR pytorch#164303
1 parent e8c4b1c commit c2114ee

File tree

2 files changed: +142 -14 lines changed

  aten/src/ATen/cuda/CUDABlas.cpp
  test/test_matmul_cuda.py

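For orientation before the diff: on ROCm the patch implements carveout by sending the Lt GEMM to a side stream created with hipExtStreamCreateWithCUMask, whose bitmask selects the CUs that stream may use. The sketch below is not part of the commit; it only reproduces the word-and-bit packing that _getCarveoutStream performs when building that mask. The helper name buildCuMask, the hard-coded 304-CU device, and the remainder guard are assumptions added for illustration; the carveout value of 66 is the one exercised by the unit test.

// Standalone sketch of the CU-mask packing used in this patch (illustrative only).
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical helper mirroring the arithmetic in _getCarveoutStream:
// whole 32-bit words are filled first, then the last touched word is
// filled from its high-order end.
static std::vector<uint32_t> buildCuMask(int32_t CUs, int32_t value) {
  uint32_t mask_size = static_cast<uint32_t>((CUs + 32 - 1) / 32);  // one bit per CU
  std::vector<uint32_t> mask(mask_size, uint32_t{0x00000000});
  int32_t full_shifts = value / 32;
  int32_t remainder = value % 32;
  for (int32_t i = 0; i < full_shifts; i++) {
    mask[i] = uint32_t{0xffffffff};
  }
  if (remainder != 0) {  // guard added for the sketch; the patch shifts unconditionally
    mask[full_shifts] = uint32_t{0xffffffff} << (32 - remainder);
  }
  return mask;
}

int main() {
  const int32_t CUs = 304;   // assumed CU count (MI300X-class device)
  const int32_t value = 66;  // carveout value used by the unit test below
  auto mask = buildCuMask(CUs, value);
  int set_bits = 0;
  for (uint32_t word : mask) {
    set_bits += __builtin_popcount(word);  // GCC/Clang builtin
  }
  std::printf("mask words: %zu, bits set: %d\n", mask.size(), set_bits);  // 10 words, 66 bits
  return 0;
}

In the patch itself the mask is handed to hipExtStreamCreateWithCUMask and the resulting stream is cached in function-local statics keyed on the last requested carveout value.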

aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 110 additions & 4 deletions
@@ -17,6 +17,7 @@
 #include <c10/core/ScalarType.h>
 
 #ifdef USE_ROCM
+#include <c10/cuda/CUDAStream.h>
 #include <hipblaslt/hipblaslt-ext.hpp>
 // until hipblas has an API to accept flags, we must use rocblas here
 #include <hipblas/hipblas.h>
@@ -185,6 +186,64 @@ uint32_t _getAlignment(uintptr_t address) {
 }
 #endif
 
+#ifdef USE_ROCM
+static c10::cuda::CUDAStream _getCarveoutStream(int32_t value) {
+  // 0 is default value, meaning full CUs i.e. no mask
+  if (value == 0) {
+    return at::cuda::getCurrentCUDAStream();
+  }
+  static int32_t last_value = 0;
+  static hipStream_t stream;
+  if (last_value == 0) {
+    // first request, do nothing for this case
+  }
+  else if (last_value == value) {
+    // stream was created previously and value hasn't changed
+    return c10::cuda::getStreamFromExternal(stream, c10::cuda::current_device());
+  }
+  else {
+    // need a new stream and a previous stream exists, delete it
+    AT_CUDA_CHECK(hipStreamDestroy(stream));
+  }
+
+  // if we got here, we need to create a new stream
+  int32_t CUs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+  // how many uint32_t do we need to cover all CUs, fill bitmask with 1
+  uint32_t mask_size = static_cast<uint32_t>((CUs + 32 - 1) / 32);
+  std::vector<uint32_t> mask(mask_size, uint32_t{0x00000000});
+  // starting from lowest order bits, in 32-bit chunks
+  // set bits to 0 based on how many CUs to carve out
+  int32_t full_shifts = value / 32;
+  int32_t remainder = value % 32;
+  for (int32_t i = 0; i < full_shifts; i++) {
+    mask[i] = uint32_t{0xffffffff};
+  }
+  mask[full_shifts] = uint32_t{0xffffffff} << (32 - remainder);
+
+  // finally, create masked stream
+  AT_CUDA_CHECK(hipExtStreamCreateWithCUMask(&stream, mask_size, &mask[0]));
+
+  last_value = value;
+  return c10::cuda::getStreamFromExternal(stream, c10::cuda::current_device());
+}
+
+static void _syncCurrentWithCarveoutStream(hipStream_t stream, bool presync) {
+  hipEvent_t event;
+  AT_CUDA_CHECK(hipEventCreateWithFlags(&event, hipEventDisableTiming));
+
+  auto current_stream = at::cuda::getCurrentCUDAStream();
+
+  if (presync) {
+    AT_CUDA_CHECK(hipEventRecord(event, current_stream));
+    AT_CUDA_CHECK(hipStreamWaitEvent(stream, event, 0));
+  }
+  else {
+    AT_CUDA_CHECK(hipEventRecord(event, stream));
+    AT_CUDA_CHECK(hipStreamWaitEvent(current_stream, event, 0));
+  }
+}
+#endif
+
 struct CublasLtWorkspace {
   CublasLtWorkspace() {
     size = at::cuda::getCUDABlasLtWorkspaceSize();
@@ -360,13 +419,20 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
   CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, opa);
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, opb);
+  auto stream = at::cuda::getCurrentCUDAStream();
 #ifndef USE_ROCM
   if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
     computeDesc.setAttribute<int32_t>(
         CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
         at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
             at::globalContext()._SMCarveout_EXPERIMENTAL().value());
   }
+#else
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    stream = _getCarveoutStream(
+        at::globalContext()._SMCarveout_EXPERIMENTAL().value());
+    _syncCurrentWithCarveoutStream(stream, true);
+  }
 #endif
   CuBlasLtMatrixLayout Adesc(abcType, m, k, lda, opa == CUBLAS_OP_T);
   CuBlasLtMatrixLayout Bdesc(abcType, k, n, ldb, opb == CUBLAS_OP_T);
@@ -430,7 +496,12 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
       &heuristicResult.algo,
       ltworkspace.ptr,
       ltworkspace.size,
-      at::cuda::getCurrentCUDAStream());
+      stream);
+#ifdef USE_ROCM
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    _syncCurrentWithCarveoutStream(stream, false);
+  }
+#endif
   TORCH_CHECK(
       cublasStatus == CUBLAS_STATUS_SUCCESS,
       "CUDA error: ",
@@ -1295,13 +1366,20 @@ void gemm_and_bias(
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
   cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
+  auto stream = at::cuda::getCurrentCUDAStream();
 #ifndef USE_ROCM
   if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
     computeDesc.setAttribute<int32_t>(
         CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
         at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
             at::globalContext()._SMCarveout_EXPERIMENTAL().value());
   }
+#else
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    stream = _getCarveoutStream(
+        at::globalContext()._SMCarveout_EXPERIMENTAL().value());
+    _syncCurrentWithCarveoutStream(stream, true);
+  }
 #endif
   cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
   if (activation == GEMMAndBiasActivationEpilogue::RELU) {
@@ -1370,7 +1448,12 @@ void gemm_and_bias(
       &heuristicResult.algo,
       ltworkspace.ptr,
      ltworkspace.size,
-      at::cuda::getCurrentCUDAStream());
+      stream);
+#ifdef USE_ROCM
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    _syncCurrentWithCarveoutStream(stream, false);
+  }
+#endif
   TORCH_CHECK(
       cublasStatus == CUBLAS_STATUS_SUCCESS,
       "CUDA error: ",
@@ -1525,13 +1608,20 @@ void scaled_gemm(
   if (result_scale_ptr != nullptr) {
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
   }
+  auto stream = at::cuda::getCurrentCUDAStream();
 #ifndef USE_ROCM
   if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
     computeDesc.setAttribute<int32_t>(
         CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
         at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
             at::globalContext()._SMCarveout_EXPERIMENTAL().value());
   }
+#else
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    stream = _getCarveoutStream(
+        at::globalContext()._SMCarveout_EXPERIMENTAL().value());
+    _syncCurrentWithCarveoutStream(stream, true);
+  }
 #endif
 #ifndef USE_ROCM
   const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
@@ -1570,7 +1660,6 @@ void scaled_gemm(
 #endif // if CUDA_VERSION >= 12090
   }
 
-  auto stream = c10::cuda::getCurrentCUDAStream();
   CuBlasLtMatmulPreference preference;
   auto ltworkspace = CublasLtWorkspace();
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size);
@@ -1657,6 +1746,11 @@ void scaled_gemm(
       ltworkspace.ptr,
       ltworkspace.size,
       stream);
+#ifdef USE_ROCM
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    _syncCurrentWithCarveoutStream(stream, false);
+  }
+#endif
   TORCH_CHECK(
       cublasStatus == CUBLAS_STATUS_SUCCESS,
       "CUDA error: ",
@@ -1710,13 +1804,20 @@ void int8_gemm(
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
   cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
+  auto stream = at::cuda::getCurrentCUDAStream();
 #ifndef USE_ROCM
   if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
     computeDesc.setAttribute<int32_t>(
         CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
         at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
             at::globalContext()._SMCarveout_EXPERIMENTAL().value());
   }
+#else
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    stream = _getCarveoutStream(
+        at::globalContext()._SMCarveout_EXPERIMENTAL().value());
+    _syncCurrentWithCarveoutStream(stream, true);
+  }
 #endif
 
   CuBlasLtMatrixLayout Adesc(abType, m, k, mat1_ld, transpose_mat1);
@@ -1778,7 +1879,7 @@ void int8_gemm(
 #else
       0,
 #endif
-      at::cuda::getCurrentCUDAStream());
+      stream);
   TORCH_CHECK(
       cublasStatus == CUBLAS_STATUS_SUCCESS,
       "CUDA error: ",
@@ -1807,6 +1908,11 @@ void int8_gemm(
       computeType,
       " scaleType ",
       scaleType);
+#ifdef USE_ROCM
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    _syncCurrentWithCarveoutStream(stream, false);
+  }
+#endif
 }
 
 template <>
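Rather than setting CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET as the CUDA branch does, each ROCm call site above redirects the matmul onto the carveout stream and brackets it with _syncCurrentWithCarveoutStream, so the caller's current stream still observes the work in order. The sketch below is not from the commit; it only illustrates that event-based handoff in isolation. The placeholder kernel and the plain hipStreamCreate (standing in for the CU-masked stream) are assumptions made to keep it self-contained.

// Standalone sketch of the presync/postsync pattern used around each ROCm
// GEMM call site above (illustrative only; plain HIP runtime calls).
#include <hip/hip_runtime.h>
#include <cstdio>
#include <cstdlib>

#define HIP_CHECK(expr)                                             \
  do {                                                              \
    hipError_t err_ = (expr);                                       \
    if (err_ != hipSuccess) {                                       \
      std::printf("HIP error: %s\n", hipGetErrorString(err_));      \
      std::exit(1);                                                 \
    }                                                               \
  } while (0)

__global__ void placeholderGemm() {}  // stands in for the hipBLASLt matmul

int main() {
  hipStream_t current_stream, carveout_stream;
  HIP_CHECK(hipStreamCreate(&current_stream));
  // In the patch the carveout stream comes from hipExtStreamCreateWithCUMask;
  // an ordinary stream is used here to keep the sketch self-contained.
  HIP_CHECK(hipStreamCreate(&carveout_stream));

  hipEvent_t pre, post;
  HIP_CHECK(hipEventCreateWithFlags(&pre, hipEventDisableTiming));
  HIP_CHECK(hipEventCreateWithFlags(&post, hipEventDisableTiming));

  // presync == true: carveout stream waits for prior work on the current stream
  HIP_CHECK(hipEventRecord(pre, current_stream));
  HIP_CHECK(hipStreamWaitEvent(carveout_stream, pre, 0));

  // the GEMM runs on the carveout stream
  hipLaunchKernelGGL(placeholderGemm, dim3(1), dim3(1), 0, carveout_stream);

  // presync == false: current stream waits for the GEMM before continuing
  HIP_CHECK(hipEventRecord(post, carveout_stream));
  HIP_CHECK(hipStreamWaitEvent(current_stream, post, 0));

  HIP_CHECK(hipStreamSynchronize(current_stream));
  HIP_CHECK(hipEventDestroy(pre));
  HIP_CHECK(hipEventDestroy(post));
  HIP_CHECK(hipStreamDestroy(carveout_stream));
  HIP_CHECK(hipStreamDestroy(current_stream));
  std::printf("stream handoff complete\n");
  return 0;
}

The same three-step shape (redirect the launch stream, presync, postsync) appears in bgemm_internal_cublaslt, gemm_and_bias, scaled_gemm, and int8_gemm in the diff above.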

test/test_matmul_cuda.py

Lines changed: 32 additions & 10 deletions
@@ -849,7 +849,6 @@ def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None:
849849
self.assertEqual(out_dtype, out_fp8.dtype)
850850
self.assertEqual(out_fp32, out_fp8.to(torch.float))
851851

852-
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support sm carveout")
853852
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support row-wise scaling")
854853
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
855854
@unittest.skipIf(not SM90OrLater, "sm89 kernel isn't opted into carveout yet")
@@ -878,15 +877,38 @@ def test_honor_sm_carveout(self) -> None:
878877
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
879878

880879
prof.export_chrome_trace(f.name)
881-
no_carveout, carveout_0, carveout_66, no_carveout_again = [
882-
math.prod(evt.get("args", {}).get("grid", []))
883-
for evt in json.load(open(f.name))["traceEvents"]
884-
if evt.get("cat", "") == "kernel"
885-
]
886-
887-
self.assertEqual(no_carveout, no_carveout_again)
888-
self.assertNotEqual(no_carveout, carveout_66)
889-
self.assertNotEqual(carveout_66, carveout_0)
880+
if torch.version.hip:
881+
events = [evt for evt in json.load(open(f.name))["traceEvents"] if evt.get("cat", "") == "kernel"]
882+
# events were returned out of order; need to be sorted on "ts" timestamp
883+
events = sorted(events, key=lambda x: x['ts'])
884+
# ROCm carveout is invisible except for kernels running slower on fewer CUs
885+
no_carveout, carveout_0, carveout_66, no_carveout_again = [float(evt.get("dur", "0.0")) for evt in events]
886+
if True or not (no_carveout < carveout_66 and carveout_0 < carveout_66 and no_carveout_again < carveout_66):
887+
# something went wrong, print more info to help debug flaky test
888+
print("ROCm debug info for test_honor_sm_carveout")
889+
print("no_carveout", no_carveout)
890+
print("carveout_0", carveout_0)
891+
print("carveout_66", carveout_66)
892+
print("no_carveout_again", no_carveout_again)
893+
self.assertTrue(no_carveout < carveout_66)
894+
self.assertTrue(carveout_0 < carveout_66)
895+
self.assertTrue(no_carveout_again < carveout_66)
896+
# ROCm carveout will create new streams when enabled, and go back to the original stream when disabled
897+
no_carveout, carveout_0, carveout_66, no_carveout_again = [int(evt.get("tid", "0")) for evt in events]
898+
self.assertTrue(no_carveout == no_carveout_again)
899+
self.assertTrue(no_carveout == carveout_0)
900+
self.assertTrue(no_carveout != carveout_66)
901+
self.assertTrue(carveout_0 != carveout_66)
902+
else:
903+
no_carveout, carveout_0, carveout_66, no_carveout_again = [
904+
math.prod(evt.get("args", {}).get("grid", []))
905+
for evt in json.load(open(f.name))["traceEvents"]
906+
if evt.get("cat", "") == "kernel"
907+
]
908+
909+
self.assertEqual(no_carveout, no_carveout_again)
910+
self.assertNotEqual(no_carveout, carveout_66)
911+
self.assertNotEqual(carveout_66, carveout_0)
890912

891913
@unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
892914
@parametrize("test_case_name", [
