Skip to content

Commit a9d356f

Browse files
valarLip, amd-ruitang3, and Copilot
authored
fix tuner (ROCm#1701)
* fix tuner

* Update gradlib/gradlib/GemmTuner.py

Co-authored-by: Copilot <[email protected]>

---------

Co-authored-by: amd-ruitang3 <[email protected]>
Co-authored-by: Copilot <[email protected]>
1 parent c6965e6 commit a9d356f

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

csrc/py_itfs_cu/asm_gemm_a16w16.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ torch::Tensor gemm_a16w16_asm(torch::Tensor& A,
254254
int gdy = (Mdim + SUBM - 1) / SUBM;
255255
int gdz = selectedksplit;
256256

257-
TORCH_CHECK(gdx <= 16, __func__, " gdx (", gdx, ") must be <= 16"); // 16 = 512/32
257+
TORCH_CHECK(gdy <= 16, __func__, " gdy (", gdy, ") must be <= 16"); // 16 = 512/32
258258

259259
// semaphore.fill_(selectedksplit);
260260
args.ptr_semaphore = (void*)semaphore.data_ptr<uint32_t>();

gradlib/gradlib/GemmTuner.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
import torch.nn.functional as F
2525

2626
import aiter
27-
from aiter import dtypes, logger
27+
from aiter import dtypes, get_semaphore_workspace, logger
2828
from aiter.jit.core import AITER_CONFIG_GEMM_BF16, get_asm_dir
2929
from aiter.jit.utils.chip_info import get_cu_num, get_gfx
3030
from aiter.ops.shuffle import shuffle_weight
31+
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 as triton_gemm_a16w16
3132
from aiter.utility.base_tuner import GemmCommonTuner
3233
from aiter.utility.mp_tuner import mp_tuner
33-
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 as triton_gemm_a16w16
3434

3535
aiter.hipb_create_extension()
3636

@@ -59,10 +59,12 @@ def call_hipb_mm(
5959
def run_gemm_bf16_asm(
6060
inp, w, out, bias=None, splitK=None, kernelName=None, bpreshuffle=False
6161
):
62+
sema = get_semaphore_workspace(inp.device)
6263
return aiter.gemm_a16w16_asm(
6364
inp,
6465
w,
6566
out,
67+
sema,
6668
bias=bias,
6769
splitK=splitK,
6870
kernelName=kernelName,

0 commit comments

Comments
 (0)