diff --git a/.github/workflows/integration-tests-amd.yml b/.github/workflows/integration-tests-amd.yml index 7098c20c53c3..92f228b273b3 100644 --- a/.github/workflows/integration-tests-amd.yml +++ b/.github/workflows/integration-tests-amd.yml @@ -109,7 +109,7 @@ jobs: echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 fi pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py - pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice.py + pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice_concat_op.py TRITON_ALWAYS_COMPILE=1 pytest --capture=tee-sys -rfs third_party/amd/python/test/test_scalarize_packed_fops.py cd python/test/unit pytest --capture=tee-sys -rfs -n 12 language runtime \ diff --git a/cmake/llvm-hash.txt b/cmake/llvm-hash.txt index 4f839b752cab..594fab8f36a1 100644 --- a/cmake/llvm-hash.txt +++ b/cmake/llvm-hash.txt @@ -1 +1 @@ -3c709802d31b5bc5ed3af8284b40593ff39b9eec +092b6e73e651469527662443b592f98f442ece72 diff --git a/fa/flash-attention.py b/fa/flash-attention.py new file mode 100644 index 000000000000..de81cbf8f9a2 --- /dev/null +++ b/fa/flash-attention.py @@ -0,0 +1,2139 @@ +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm +See https://tridao.me/publications/flash2/flash2.pdf + +Credits: +AMD Triton kernels team +OpenAI kernel team + +Currently only the forward kernel is supported, and contains these features: + +1) Fwd with causal masking +2) Arbitrary Q and KV sequence lengths +3) Arbitrary head sizes +4) Multi and grouped query attention +5) Variable sequence lengths +6) ALiBi and matrix bias + +""" + +import argparse +import subprocess +import pytest +import sys +import torch + +import triton +import triton.language as tl +from utils.benchmark_utils import get_available_models, get_model_configs + + +class MetaData(): + cu_seqlens_q = None + cu_seqlens_k = None + max_seqlens_q = 0 + max_seqlens_k = 0 + bias = None + alibi_slopes = None + causal = False + persistent = None + num_contexts = 0 + varlen = False + int8 = False + layout = None + dropout_p, return_encoded_softmax = 0.0, False + + def __init__(self, sm_scale=1.0): + self.sm_scale = sm_scale + + def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): + self.varlen = True + self.layout = 'thd' + self.cu_seqlens_q = cu_seqlens_q + self.cu_seqlens_k = cu_seqlens_k + # Without "varlen", there should still be one sequence. 
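+        # Illustrative example (hypothetical values): cu_seqlens_q = [0, 3, 7, 12] packs three
+        # sequences of lengths 3, 4 and 5, so num_contexts = len(cu_seqlens_q) - 1 = 3 and
+        # max_seqlens_q = 5; the loop below recovers each length as the difference of
+        # consecutive offsets.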
+ assert len(cu_seqlens_q) >= 2 + assert len(cu_seqlens_q) == len(cu_seqlens_k) + self.num_contexts = len(cu_seqlens_q) - 1 + for i in range(0, self.num_contexts): + self.max_seqlens_q = max(cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(), self.max_seqlens_q) + self.max_seqlens_k = max(cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(), self.max_seqlens_k) + + def set_persistent(self, persistent): + self.persistent = persistent + + def set_int8_params(self, q_descale, k_descale, v_descale, p_scale, p_descale): + self.int8 = True + self.q_descale = q_descale + self.k_descale = k_descale + self.v_descale = v_descale + self.p_scale = p_scale + self.p_descale = p_descale + self.use_p_scale = (p_scale is not None) and (p_descale is not None) and (v_descale is not None) + self.int8_kv = (q_descale is None) and (k_descale is not None) and (v_descale is not None) + + def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k): + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.shape[0] == 1 + assert bias.shape[2:] == (seqlen_q, seqlen_k) + self.bias = bias + + def need_alibi(self, alibi_slopes, batch, nheads): + assert alibi_slopes.is_cuda + assert alibi_slopes.dim() == 2 + assert alibi_slopes.shape[0] == batch + assert alibi_slopes.shape[1] == nheads + self.alibi_slopes = alibi_slopes + + def need_causal(self): + self.causal = True + + def need_dropout(self, dropout_p, return_encoded_softmax): + self.dropout_p = dropout_p + self.return_encoded_softmax = return_encoded_softmax + + def check_args(self, q, k, v, o): + assert q.dim() == k.dim() and q.dim() == v.dim() + + batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, self) + if self.varlen: + assert q.dim() == 3 + assert self.cu_seqlens_q is not None + assert self.cu_seqlens_k is not None + assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) + # TODO: Remove once bias is supported with varlen + assert self.bias is None + # TODO:Remove once dropout is supported with varlen + assert self.dropout_p == 0.0 + assert not self.return_encoded_softmax + else: + assert q.dim() == 4 + assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 + assert self.cu_seqlens_q is None and self.cu_seqlens_k is None + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + if self.int8: + if self.int8_kv: + assert v.dtype == k.dtype and k.dtype == torch.int8 + assert q.dtype != k.dtype + assert (self.v_descale is not None) and (self.k_descale is not None) + else: + assert q.dtype == k.dtype and q.dtype == v.dtype and q.dtype == torch.int8 + assert (self.q_descale is not None) and (self.k_descale is not None) and (self.v_descale is not None) + if self.use_p_scale: + assert (self.p_scale is not None) and (self.p_descale is not None) + else: + assert q.dtype == k.dtype and q.dtype == v.dtype + assert head_size <= 256 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + assert self.layout is not None + assert self.layout == 'thd' or not self.varlen + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + + +@triton.jit +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, 
stride).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) + rng_keep = rng_output > dropout_p + return rng_keep + + +# Convenience function to load with optional boundary checks. +# "First" is the major dim, "second" is the minor dim. +@triton.jit +def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): + if offset_first is not None and offset_second is not None: + mask = (offset_first[:, None] < boundary_first) & \ + (offset_second[None, :] < boundary_second) + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_first is not None: + mask = offset_first[:, None] < boundary_first + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_second is not None: + mask = offset_second[None, :] < boundary_second + tensor = tl.load(ptrs, mask=mask, other=0.0) + else: + tensor = tl.load(ptrs) + return tensor + + +@triton.jit +def print_gpu(prefix, val=None): + if (tl.program_id(0) == 0) and ((tl.program_id(1) == 0) and (tl.program_id(2) == 0)): + if val is not None: + tl.device_print(prefix, val) + else: + tl.device_print(prefix) + + +@triton.jit +def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False): + # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix + # for casual mask we want something like this where (1 is kept and 0 is masked) + # seqlen_q = 2 and seqlen_k = 5 + # 1 1 1 1 0 + # 1 1 1 1 1 + # seqlen_q = 5 and seqlen_k = 2 + # 0 0 + # 0 0 + # 0 0 + # 1 0 + # 1 1 + # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal + # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False + # 1. offs_m[:,None] = [[0], + # [1], + # 2. offs_m[:,None] + seqlen_k = [[5], + # [6], + # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], + # [4], + # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], + # [4], [ 4, 3, 2, 1, 0]] + # 5. 
-1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], + # [ -4, -3, -2, -1, 0]], + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + if transpose: + return alibi_block.T + else: + return alibi_block + + +def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k): + q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) + k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K) + relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K) + return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) + + +@triton.jit +def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m, + actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, batch_philox_offset, encoded_sm_ptrs, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, q_descale, + k_descale, v_descale, p_scale, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, + QK_SCALE: tl.constexpr, INT8_GEMM: tl.constexpr, USE_P_SCALE: tl.constexpr, INT8_KV: tl.constexpr, + ENABLE_PIPELINING: tl.constexpr): + # loop over k, v, and update accumulator + num_stages: tl.constexpr = None if ENABLE_PIPELINING else 1 # Set num_stages==1 if we want to disable pipelining + for start_n in tl.range(block_min, block_max, BLOCK_N, num_stages=num_stages): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + if MASK_STEPS: + k_offs_n = start_n + tl.arange(0, BLOCK_N) + else: + k_offs_n = None + k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) + k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k) + if PRE_LOAD_V: + # We can use the same offsets as k, just with dims transposed. + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. + # last step might get wasted but that is okay. check if this masking works For + # that case. 
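+            # Illustrative example: with actual_seqlen_k = 100 and BLOCK_N = 64 the caller sets
+            # n_extra_tokens = 100 % 64 = 36, so only the final K block (start_n = 64, which hits
+            # block_max) masks columns with index >= 100 to -inf below; every earlier block is
+            # fully in range.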
+ if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- + if INT8_GEMM: + qk += ((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale + else: + if INT8_KV: + k = (k * k_descale).to(q.type.element_ty) + qk += tl.dot(q, k) + + if bias_ptrs is not None: + bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None + bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) + # While bias is added after multiplying qk with sm_scale, + # our optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. + qk += (bias * 1.44269504089 / QK_SCALE) + + if alibi_slope is not None: + # Compute the global position of each token within the sequence + global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + global_n_positions = start_n + tl.arange(0, BLOCK_N) + alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, + global_n_positions) + qk += (alibi_block * 1.44269504089 / QK_SCALE) # scale factor of log2(e) + + # softmax + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + m_ij_scaled = m_ij * QK_SCALE + qk = qk * QK_SCALE - m_ij_scaled[:, None] + p = tl.math.exp2(qk) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N + keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) + if RETURN_ENCODED_SOFTMAX: + tl.store(encoded_sm_ptrs, tl.where(keep, p, -p).to(encoded_sm_ptrs.type.element_ty)) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty)) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i * QK_SCALE - m_ij_scaled) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + + if INT8_GEMM: + if USE_P_SCALE: + p = (p * p_scale).to(tl.int8) + # They are all int8 + acc += tl.dot(p, v) + else: + # v is in int8 but p is not, we want the gemm in p's type + acc += tl.dot(p, v.to(p.type.element_ty)) + else: + if INT8_KV: + v = (v * v_descale).to(p.type.element_ty) + acc += tl.dot(p.to(v.type.element_ty), v) + + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + if bias_ptrs is not None: + bias_ptrs += BLOCK_N * stride_bn + if RETURN_ENCODED_SOFTMAX: + encoded_sm_ptrs += BLOCK_N + return acc, l_i, m_i + + +def get_gfx_version(): + try: + # Run the rocminfo command + result = subprocess.run(['rocminfo'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + output = result.stdout + + # Parse the output to find the gfx version + for line in output.splitlines(): + line = line.strip() + if line.startswith("Name: gfx"): + gfx_version = line.split("Name:")[1].strip() + return gfx_version + except Exception as e: + print(f"Error: {e}") + return None + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +def 
is_cdna(): + return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx950', 'gfx940', 'gfx941', + 'gfx942', 'gfx90a', 'gfx908') + + +def is_rdna(): + return is_hip() and triton.runtime.driver.active.get_current_target().arch in ("gfx1030", "gfx1100", "gfx1101", + "gfx1102", "gfx1200", "gfx1201") + + +def get_cdna_autotune_configs(): + return [ + # triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 0, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + # num_stages=2, num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + num_stages=4, num_warps=8), + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + # num_stages=2, num_warps=4), + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + # num_stages=2, num_warps=4), + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + # num_stages=2, num_warps=4), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] + + +def get_rdna_autotune_configs(): + return [ + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + num_stages=1, num_warps=2), + # Fall-back config. 
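+        # For reference, the knobs used in every config here: BLOCK_M/BLOCK_N set the tile shape,
+        # PRE_LOAD_V loads the V tile together with K in _attn_fwd_inner instead of after the
+        # softmax, waves_per_eu is an AMD-specific occupancy hint, and GRID_CU_MULTIP scales how
+        # many workgroups are launched per CU when the kernel runs in persistent mode.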
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + num_stages=1, num_warps=2), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] + + +def get_autotune_configs(): + if is_rdna(): + return get_rdna_autotune_configs() + elif is_cdna(): + return get_cdna_autotune_configs() + else: + raise ValueError("Unknown Device Type") + + +autotune_configs, autotune_keys = get_autotune_configs() + + +@triton.autotune( + configs=autotune_configs, + key=autotune_keys, + use_cuda_graph=True, +) +@triton.jit +def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, + stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, + stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, Q_descale, + K_descale, P_scale, P_descale, V_descale, cu_seqlens_q, cu_seqlens_k, dropout_p, philox_seed, + PERSISTENT: tl.constexpr, PERSISTENT_DYNAMIC: tl.constexpr, atomic_counter, NUM_CU: tl.constexpr, + GRID_CU_MULTIP: tl.constexpr, B: tl.constexpr, philox_offset_base, encoded_softmax, alibi_slopes, + HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr, + INT8: tl.constexpr, USE_P_SCALE: tl.constexpr, INT8_KV: tl.constexpr): + + tl.assume(stride_qz >= 0) + tl.assume(stride_qh >= 0) + tl.assume(stride_qm >= 0) + tl.assume(stride_qk >= 0) + tl.assume(stride_kz >= 0) + tl.assume(stride_kh >= 0) + tl.assume(stride_kn >= 0) + tl.assume(stride_kk >= 0) + tl.assume(stride_bz >= 0) + tl.assume(stride_bh >= 0) + tl.assume(stride_bm >= 0) + tl.assume(stride_bn >= 0) + tl.assume(stride_vz >= 0) + tl.assume(stride_vh >= 0) + tl.assume(stride_vk >= 0) + tl.assume(stride_vn >= 0) + tl.assume(stride_oz >= 0) + tl.assume(stride_oh >= 0) + tl.assume(stride_om >= 0) + tl.assume(stride_on >= 0) + + if PERSISTENT: # if persistent, kernel loops over multiple tiles + NUM_WG = NUM_CU * GRID_CU_MULTIP # number of workgroups launched + num_tiles_per_head = tl.cdiv(MAX_SEQLENS_Q, BLOCK_M) # the number of work units (tiles) of a single head + num_tiles_per_sample = num_tiles_per_head * HQ # times the number of heads + num_tiles_total = num_tiles_per_sample * B # times the number of samples + if PERSISTENT_DYNAMIC: + tile_id = atomic_counter.atomic_add(1) # retuns the value BEFORE the atomic operation + else: + tile_id = tl.program_id(0) + else: # standard, kernel processes only one tile + tile_id = 0 + num_tiles_total = 1 + + while tile_id < num_tiles_total: # loops more than once only if PERSISTENT + if PERSISTENT: + # tile id basically tells us the Q block we are handling + off_z = tile_id // num_tiles_per_sample # at which batch sample are we + off_h_q = tile_id % num_tiles_per_sample // num_tiles_per_head # at which head are we inside the sample + start_m = tile_id % num_tiles_per_sample % num_tiles_per_head # at which tile are we inside the head + else: + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + 
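+        # Worked example of the persistent tile decomposition above: with B = 2 samples, HQ = 4
+        # heads and num_tiles_per_head = 3, num_tiles_per_sample = 12 and num_tiles_total = 24;
+        # tile_id = 17 maps to off_z = 17 // 12 = 1, off_h_q = (17 % 12) // 3 = 1 and
+        # start_m = (17 % 12) % 3 = 2.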
+ continue_condition = True # as we can't have return statements inside while loop in Triton + + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + continue_condition = False + # return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + if continue_condition: + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if (IS_CAUSAL): + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn matrix + n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is part of + # the blocks that are all 0. We exit early. + if n_blocks <= 0: + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + o_ptrs_mask = (offs_m[:, None] < seqlen_q).broadcast_to([BLOCK_M, BLOCK_DMODEL]) + # We still need to write 0s to the result + tl.store(o_ptrs, acc, mask=o_ptrs_mask) + # The tensor allocated for L is based on MAX_SEQLENS_Q as that is + # statically known. + l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # We store inf to LSE, not -inf because in the bwd pass, we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. + l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + l_ptrs_mask = offs_m < MAX_SEQLENS_Q + tl.store(l_ptrs, l, mask=l_ptrs_mask) + # TODO: Should dropout and return encoded softmax be handled here too? + continue_condition = False + # return + + if continue_condition: + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + if GROUP_SIZE != 1: + off_h_k = off_h_q // GROUP_SIZE + else: + off_h_k = off_h_q + + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + + # Compute pointers for all the tensors used in this kernel. 
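+                # Layout of the pointer blocks built below: q_ptrs and v_ptrs address
+                # [BLOCK_M, BLOCK_DMODEL] and [BLOCK_N, BLOCK_DMODEL] tiles, while k_ptrs is
+                # addressed as [BLOCK_DMODEL, BLOCK_N] (head dim major), so tl.dot(q, k) in the
+                # inner loop yields a [BLOCK_M, BLOCK_N] tile of Q @ K^T without an explicit
+                # transpose.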
+ q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn + v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn + # Compute pointers for all the scale tensors used in this kernel. + + INT8_GEMM: tl.constexpr = INT8 & (not INT8_KV) + if INT8: + k_descale_ptrs = K_descale + off_h_k + v_descale_ptrs = V_descale + off_h_k + if not INT8_KV: + q_descale_ptrs = Q_descale + off_h_q + if USE_P_SCALE: + p_scale_ptrs = P_scale + off_h_q + p_descale_ptrs = P_descale + off_h_q + + if USE_BIAS: + # Note: this might get large enough to overflow on some configs + bias_offset = off_h_q * stride_bh + bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn + else: + bias_ptrs = None + + if USE_ALIBI: + a_offset = off_z * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(alibi_slopes + a_offset) + else: + alibi_slope = None + + if ENABLE_DROPOUT: + off_hz = off_z * HQ + off_h_q + batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. In + # this case, we return an invalid pointer so indicate the mask is not valid. + if RETURN_ENCODED_SOFTMAX: + encoded_sm_base = encoded_softmax + off_h_q * seqlen_q * seqlen_k + encoded_sm_ptrs = encoded_sm_base + offs_m[:, None] * seqlen_k + offs_n[None, :] + else: + encoded_sm_ptrs = None + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. + QK_SCALE: tl.constexpr = SM_SCALE * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q_ptrs_mask = offs_m[:, None] < seqlen_q + if PADDED_HEAD: + q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + + if INT8: + k_descale = tl.load(k_descale_ptrs) + v_descale = tl.load(v_descale_ptrs) + if not INT8_KV: + q_descale = tl.load(q_descale_ptrs) + else: + q_descale = None + if USE_P_SCALE: + p_scale = tl.load(p_scale_ptrs) + p_descale = tl.load(p_descale_ptrs) + else: + p_scale = None + p_descale = None + else: + q_descale = None + k_descale = None + v_descale = None + p_scale = None + p_descale = None + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. + # In this case we might exceed n_blocks so pick the min. 
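+                # Illustrative example: BLOCK_M = 128, BLOCK_N = 64, causal, seqlen_q = seqlen_k = 512.
+                # For the Q block with start_m = 3, n_blocks = 8 and is_modulo_mn is True, so
+                # masked_blocks = 128 // 64 = 2: six K blocks run the unmasked fast path and only
+                # the two blocks crossing the diagonal take the masked path.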
+ masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its actual + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, + stride_bn, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, + batch_philox_offset, encoded_sm_ptrs, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, block_max, 0, 0, 0, alibi_slope, q_descale, k_descale, + v_descale, p_scale, + # IS_CAUSAL, .... + False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, + PADDED_HEAD, ACTUAL_BLOCK_DMODEL, QK_SCALE, INT8_GEMM, USE_P_SCALE, + INT8_KV, True) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if (masked_blocks > 0): + if IS_CAUSAL: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) + else: + offs_n_causal = 0 + k_ptrs += n_full_blocks * BLOCK_N * stride_kn + v_ptrs += n_full_blocks * BLOCK_N * stride_vk + if USE_BIAS: + bias_ptrs += n_full_blocks * BLOCK_N * stride_bn + if RETURN_ENCODED_SOFTMAX: + encoded_sm_ptrs += n_full_blocks * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m, seqlen_k, + seqlen_q, dropout_p, philox_seed, batch_philox_offset, encoded_sm_ptrs, block_min, block_max, + offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, q_descale, k_descale, v_descale, + p_scale, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, ACTUAL_BLOCK_DMODEL, + QK_SCALE, INT8_GEMM, USE_P_SCALE, INT8_KV, False) + + if INT8 and not INT8_KV: + if USE_P_SCALE: + acc *= p_descale + acc *= v_descale + + # epilogue + # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. + # This is only true for the last M block. 
For others, overflow_size will be -ve + overflow_size = end_m_idx - seqlen_q + if overflow_size > 0: + boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) + l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary + tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + else: + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) + if overflow_size > 0: + o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) + if PADDED_HEAD: + o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) + + if PERSISTENT: + if PERSISTENT_DYNAMIC: + tile_id = atomic_counter.atomic_add(1) + else: + tile_id += NUM_WG + else: + tile_id = num_tiles_total # break after single tile + + +@triton.jit +def _attn_bwd_preprocess( + Out, + DO, + Delta, + stride_oz, + stride_oh, + stride_om, + stride_on, + stride_doz, + stride_doh, + stride_dom, + stride_don, + seqlen_q, + head_dim, + BLOCK_M: tl.constexpr, + D_HEAD: tl.constexpr, +): + # off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + # off_n = tl.arange(0, D_HEAD) + off_m = tl.program_id(0) * BLOCK_M + off_h = tl.program_id(1) # head index + off_z = tl.program_id(2) # batch index + num_h = tl.num_programs(1) + o_offset = off_h * stride_oh + off_z * stride_oz + O_block_ptr = tl.make_block_ptr(base=Out + o_offset, shape=(seqlen_q, head_dim), strides=(stride_om, stride_on), + offsets=(off_m, 0), block_shape=(BLOCK_M, D_HEAD), order=(1, 0)) + do_offset = off_h * stride_doh + off_z * stride_doz + DO_block_ptr = tl.make_block_ptr(base=DO + do_offset, shape=(seqlen_q, head_dim), strides=(stride_dom, stride_don), + offsets=(off_m, 0), block_shape=(BLOCK_M, D_HEAD), order=(1, 0)) + # load + # o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + # do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + o = tl.load(O_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + do = tl.load(DO_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + # compute + delta = tl.sum(o * do, axis=1) + # write-back, shape (q.shape[0] * q.shape[1], q.shape[2]) + off_zh = off_z * num_h + off_h * 1 + # Check for OOB accesses + delta_ptrs = Delta + off_zh * seqlen_q + off_m + tl.arange(0, BLOCK_M) + overflow = off_m + BLOCK_M - seqlen_q + if overflow > 0: + boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow, dtype=tl.int32) + mask = boundary > tl.arange(0, BLOCK_M) + tl.store(delta_ptrs, delta, mask=mask) + else: + tl.store(delta_ptrs, delta) + + +@triton.jit +def _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, + # shared by Q/K/V/DO. + stride_tok, stride_d, H, N_CTX, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. 
+ start_n, start_m, num_steps, MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M1) + offs_n = start_n + tl.arange(0, BLOCK_N1) + # offs_k = tl.arange(0, BLOCK_DMODEL) + QT_block_ptr = tl.make_block_ptr(base=Q, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), + offsets=(0, start_m), block_shape=(BLOCK_DMODEL, BLOCK_M1), order=(0, 1)) + DO_block_ptr = tl.make_block_ptr(base=DO, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_m, 0), block_shape=(BLOCK_M1, BLOCK_DMODEL), order=(1, 0)) + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + curr_m = start_m + step_m = BLOCK_M1 + for blk_idx in range(num_steps): + qT = tl.load(QT_block_ptr) + # Load m before computing qk to reduce pipeline stall. + offs_m = curr_m + tl.arange(0, BLOCK_M1) + m = tl.load(M + offs_m) + kqT = tl.dot(k, qT) + if alibi_slope is not None: + alibi_block = compute_alibi_block(alibi_slope, N_CTX, N_CTX, offs_m, offs_n, True) + kqT += alibi_block * 1.44269504089 + + pT = tl.math.exp2(kqT - m[None, :]) + # Autoregressive masking. + if MASK: + mask = (offs_m[None, :] >= offs_n[:, None]) + pT = tl.where(mask, pT, 0.0) + do = tl.load(DO_block_ptr) + # Compute dV. + ppT = pT + ppT = ppT.to(tl.float16) + dv += tl.dot(ppT, do) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)) + dsT = pT * (dpT - Di[None, :]) + dsT = dsT.to(tl.float16) + dk += tl.dot(dsT, tl.trans(qT)) + # Increment pointers. + curr_m += step_m + QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m)) + DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0)) + return dk, dv + + +@triton.jit +def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, + # shared by Q/K/V/DO. + stride_tok, stride_d, H, N_CTX, BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. + start_m, start_n, num_steps, MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + # offs_k = tl.arange(0, BLOCK_DMODEL) + KT_block_ptr = tl.make_block_ptr(base=K, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), + offsets=(0, start_n), block_shape=(BLOCK_DMODEL, BLOCK_N2), order=(0, 1)) + VT_block_ptr = tl.make_block_ptr(base=V, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), + offsets=(0, start_n), block_shape=(BLOCK_DMODEL, BLOCK_N2), order=(0, 1)) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + for blk_idx in range(num_steps): + kT = tl.load(KT_block_ptr) + qk = tl.dot(q, kT) + if alibi_slope is not None: + alibi_block = compute_alibi_block(alibi_slope, N_CTX, N_CTX, offs_m, offs_n) + qk += alibi_block * 1.44269504089 + + p = tl.math.exp2(qk - m) + # Autoregressive masking. + if MASK: + offs_n = curr_n + tl.arange(0, BLOCK_N2) + mask = (offs_m[:, None] >= offs_n[None, :]) + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + vT = tl.load(VT_block_ptr) + dp = tl.dot(do, vT).to(tl.float32) + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.float16) + # Compute dQ.0. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. 
+ curr_n += step_n + KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n)) + VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n)) + return dq + + +@triton.jit +def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, + # shared by Q/K/V/DO. + stride_z, stride_h, stride_tok, stride_d, + # H = 16, N_CTX = 1024 + H, N_CTX, BLOCK_DMODEL: tl.constexpr, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, + BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, BLK_SLICE_FACTOR: tl.constexpr, USE_ALIBI: tl.constexpr): + LN2: tl.constexpr = 0.6931471824645996 # = ln(2) + + bhid = tl.program_id(2) + off_chz = (bhid * N_CTX).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + pid = tl.program_id(0) + + # offset pointers for batch/head + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DK += adj + DV += adj + M += off_chz + D += off_chz + + # offs_k = tl.arange(0, BLOCK_DMODEL) + + start_n = pid * BLOCK_N1 + # This assignment is important. It is what allows us to pick the diagonal + # blocks. Later, when we want to do the lower triangular, we update start_m + # after the first dkdv call. + start_m = start_n + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + # offs_n = start_n + tl.arange(0, BLOCK_N1) + + dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1, 0), + ) + + # load K and V: they stay in SRAM throughout the inner loop for dkdv. + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + + if USE_ALIBI: + a_offset = bhid + alibi_slope = tl.load(alibi_slopes + a_offset) + else: + alibi_slope = None + + # compute dK and dV for blocks close to the diagonal that need to be masked + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + dk, dv = _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX, + MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=True) + + # compute dK and dV for blocks that don't need masking further from the diagonal + start_m += num_steps * MASK_BLOCK_M1 + num_steps = (N_CTX - start_m) // BLOCK_M1 + + dk, dv = _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX, + BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=False) + + DV_block_ptrs = tl.make_block_ptr(base=DV, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_n, 0), block_shape=(BLOCK_N1, BLOCK_DMODEL), order=(1, 0)) + tl.store(DV_block_ptrs, dv.to(v.dtype)) + + # Write back dK. 
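+    # dk was accumulated as dS^T @ Q inside _bwd_kernel_dk_dv, so the softmax scale is applied
+    # once here at write-back instead of on every step of the inner loop.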
+ dk *= sm_scale + DK_block_ptrs = tl.make_block_ptr(base=DK, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_n, 0), block_shape=(BLOCK_N1, BLOCK_DMODEL), order=(1, 0)) + tl.store(DK_block_ptrs, dk.to(k.dtype)) + + # THIS BLOCK DOES DQ: + start_m = pid * BLOCK_M2 + end_n = start_m + BLOCK_M2 + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + offs_m = start_m + tl.arange(0, BLOCK_M2) + + Q_block_ptr = tl.make_block_ptr(base=Q, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0)) + + DO_block_ptr = tl.make_block_ptr(base=DO, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0)) + q = tl.load(Q_block_ptr) + do = tl.load(DO_block_ptr) + dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) + + m = tl.load(M + offs_m) + m = m[:, None] + + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + num_steps = BLOCK_M2 // MASK_BLOCK_N2 + dq = _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, stride_tok, stride_d, H, N_CTX, BLOCK_M2, MASK_BLOCK_N2, + BLOCK_DMODEL, start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, MASK=True) + end_n -= num_steps * MASK_BLOCK_N2 + # stage 2 + num_steps = end_n // BLOCK_N2 + dq = _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, stride_tok, stride_d, H, N_CTX, BLOCK_M2, BLOCK_N2, + BLOCK_DMODEL, start_m, end_n - num_steps * BLOCK_N2, num_steps, MASK=False) + # Write back dQ. + DQ_block_ptr = tl.make_block_ptr(base=DQ, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0)) + dq *= LN2 + tl.store(DQ_block_ptr, dq.to(q.dtype)) + + +def get_shape_from_layout(q, k, metadata): + if metadata.layout == 'thd': + nheads_q, nheads_k = q.shape[1], k.shape[1] + head_size = q.shape[-1] + batch = metadata.num_contexts + elif metadata.layout == 'bhsd': + batch, nheads_q, _, head_size = q.shape + nheads_k = k.shape[1] + elif metadata.layout == 'bshd': + batch, _, nheads_q, head_size = q.shape + nheads_k = k.shape[2] + else: + assert False, "Got unsupported layout." + return batch, nheads_q, nheads_k, head_size + + +# TODO: This can probably optimized to have fewer lines of code. +def get_strides_from_layout(q, k, v, o, metadata): + if metadata.layout == 'thd': + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + elif metadata.layout == 'bhsd': + q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3)) + k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3)) + v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3)) + o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3)) + elif metadata.layout == 'bshd': + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + else: + assert False, 'Got unsupported layout.' 
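+    # Whatever the input layout, strides are returned in (batch, head, seq, head_dim) order, which
+    # is the order attn_fwd consumes them in. For 'thd' the batch stride is 0 because sequences are
+    # packed along the token dimension and per-sequence offsets come from cu_seqlens instead.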
+ return q_strides, k_strides, v_strides, o_strides + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, o, metadata: MetaData): + # NOTE: a large bias tensor leads to overflow during pointer arithmetic + if (metadata.bias is not None): + assert (metadata.bias.numel() < 2**31) + + if o is None: + if not metadata.int8: + o = torch.empty_like(q, dtype=v.dtype) + else: + o = torch.empty_like(q, dtype=torch.float16) + + metadata.check_args(q, k, v, o) + + batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, metadata) + q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, metadata) + + # Get closest power of 2 over or equal to 32. + padded_d_model = 1 << (head_size - 1).bit_length() + # Smallest head_dim supported is 16. If smaller, the tile in the + # kernel is padded - there is no padding in memory for any dims. + padded_d_model = max(padded_d_model, 16) + + # encoded_softmax is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out + # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according + # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing + # only. This return holds no useful output aside from debugging. + if metadata.return_encoded_softmax: + encoded_softmax = torch.zeros((q.shape[0], q.shape[1], q.shape[2], k.shape[2]), device=q.device, + dtype=torch.float32) + else: + encoded_softmax = None + + M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), device=q.device, dtype=torch.float32) + + # Seed the RNG so we get reproducible results for testing. + philox_seed = 0x1BF52 + philox_offset = 0x1D4B42 + + if metadata.bias is not None: + bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1), metadata.bias.stride(2), + metadata.bias.stride(3)) + else: + bias_strides = (0, 0, 0, 0) + + if metadata.alibi_slopes is not None: + alibi_strides = (metadata.alibi_slopes.stride(0), metadata.alibi_slopes.stride(1)) + else: + alibi_strides = (0, 0) + + if metadata.int8: + q_descale, k_descale, p_scale, p_descale, v_descale = metadata.q_descale, metadata.k_descale, metadata.p_scale, metadata.p_descale, metadata.v_descale + else: + q_descale = k_descale = p_scale = p_descale = v_descale = None + + # number of compute units available + NUM_CU = torch.cuda.get_device_properties("cuda").multi_processor_count + + if metadata.persistent is not None: + grid = lambda META: (min(NUM_CU * META['GRID_CU_MULTIP'], + triton.cdiv(metadata.max_seqlens_q, META['BLOCK_M']) * nheads_q * batch), ) + else: + grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META['BLOCK_M']), nheads_q, batch) + + atomic_counter = torch.zeros([1], device=q.device, dtype=torch.int32) + + # test_op_fwd(Z, x_vals_list[1], x_vals_list[2], N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, dtype=torch.float16): + + attn_fwd[grid](q, k, v, metadata.bias, metadata.sm_scale, M, o, *q_strides, *k_strides, *v_strides, *o_strides, + *bias_strides, *alibi_strides, q_descale, k_descale, p_scale, p_descale, v_descale, + metadata.cu_seqlens_q, metadata.cu_seqlens_k, dropout_p=metadata.dropout_p, + philox_seed=philox_seed, philox_offset_base=philox_offset, encoded_softmax=encoded_softmax, + alibi_slopes=metadata.alibi_slopes, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=metadata.max_seqlens_q, MAX_SEQLENS_K=metadata.max_seqlens_k, + IS_CAUSAL=metadata.causal, 
VARLEN=metadata.varlen, BLOCK_DMODEL=padded_d_model, + USE_BIAS=False if metadata.bias is None else True, + USE_ALIBI=False if metadata.alibi_slopes is None else True, ENABLE_DROPOUT=metadata.dropout_p + > 0.0, RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, INT8=metadata.int8, + USE_P_SCALE=metadata.int8 and metadata.use_p_scale, INT8_KV=metadata.int8 and metadata.int8_kv, + PERSISTENT=metadata.persistent is not None, PERSISTENT_DYNAMIC=metadata.persistent == "dynamic", + NUM_CU=NUM_CU, atomic_counter=atomic_counter, B=batch) + + ctx.save_for_backward(q, k, v, o, M) + ctx.grid = grid + ctx.sm_scale = metadata.sm_scale + ctx.BLOCK_DMODEL = head_size + ctx.causal = metadata.causal + ctx.alibi_slopes = metadata.alibi_slopes + ctx.dropout_p = metadata.dropout_p + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.encoded_softmax = encoded_softmax + ctx.return_encoded_softmax = metadata.return_encoded_softmax + return o, encoded_softmax, attn_fwd.best_config + + @staticmethod + def backward(ctx, *gradients): + do = gradients[0] + if torch.version.hip is not None: + BLOCK = 64 + else: + BLOCK = 128 + q, k, v, o, M = ctx.saved_tensors + assert do.is_contiguous() + assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() + seqlen_q = q.shape[2] + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + BATCH, N_HEAD, N_CTX = q.shape[:3] + PRE_BLOCK = 128 + # NUM_WARPS, NUM_STAGES = 4, 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32 + BLK_SLICE_FACTOR = 2 + RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) + arg_k = k + arg_k = arg_k * (ctx.sm_scale * RCP_LN2) + assert N_CTX % PRE_BLOCK == 0 + delta = torch.empty_like(M) + _, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1] + # padded_head = (Lk != ctx.BLOCK_DMODEL) + grid_preprocess = (triton.cdiv(do.shape[2], BLOCK), do.shape[1], do.shape[0]) + _attn_bwd_preprocess[grid_preprocess]( + o, + do, + delta, + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + do.stride(0), + do.stride(1), + do.stride(2), + do.stride(3), + seqlen_q, + head_dim=Lk, + BLOCK_M=BLOCK, + D_HEAD=ctx.BLOCK_DMODEL, + ) + grid = lambda META: (triton.cdiv(N_CTX, META['BLOCK_N1']), 1, BATCH * N_HEAD) + _attn_bwd[grid]( + q, + arg_k, + v, + ctx.sm_scale, + ctx.alibi_slopes, + do, + dq, + dk, + dv, + M, + delta, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + N_HEAD, + N_CTX, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, + BLOCK_M1=BLOCK_M1, + BLOCK_N1=BLOCK_N1, + BLOCK_M2=BLOCK_M2, + BLOCK_N2=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + USE_ALIBI=False if ctx.alibi_slopes is None else True, + ) + + return dq, dk, dv, None, None + + +attention = _attention.apply + +INT8_MAX = 127 + + +def quantize_int8(tensor: torch.Tensor, dim) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + max_vals = tensor.abs().amax(dim=[i for i in range(tensor.dim()) if i != dim], keepdim=True) + + # Avoid division by zero + max_vals[max_vals == 0] = 1e-8 + + # Compute scale factors for each channel + scale = INT8_MAX / max_vals.to(torch.float32) + + # Quantize the tensor + tensor = tensor * scale + tensor = tensor.round_() + tensor.clamp_(-INT8_MAX, INT8_MAX) + tensor_quantized = tensor.to(torch.int8) + + return tensor_quantized, scale, 1 / scale + + +def quantize_input(q, k, v, input_metadata: MetaData, quantize_p=False, int8_kv=False): + assert not (quantize_p and int8_kv) + if input_metadata.layout == 'bhsd': + qunatization_dim = 1 + elif input_metadata.layout == 'bshd': + qunatization_dim = 2 + else: + assert 
False, 'Got unsupported tensor layout' + assert not (quantize_p and int8_kv) + + q_descale = None + if not int8_kv: + q, _, q_descale = quantize_int8(q, dim=qunatization_dim) + k, _, k_descale = quantize_int8(k, dim=qunatization_dim) + v, _, v_descale = quantize_int8(v, dim=qunatization_dim) + + # In real world use case, the p scale would be a parameter trained by the model. + p_scale = p_descale = None + # The p shape is always bhqk + if quantize_p: + _, nheads_q, _, _ = get_shape_from_layout(q, k, input_metadata) + p_scale = torch.full((1, nheads_q, 1, 1), 127, dtype=torch.float32, device="cuda") + p_descale = 1 / p_scale + + # We are not multiplying the scales togather to get qk_desale / o_descale e.g. + # qk_desale = q_descale * k_descale + # o_desale = p_descale * v_descale + # it results in very small fp e.g. 0,0002, losing precision. They are applied on the run. + input_metadata.set_int8_params(q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + # By default p_scaling is not enabled + p_scale=p_scale, p_descale=p_descale) + + return q, k, v + + +def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, requires_grad=True): + torch.manual_seed(20) + + # Initialize q, k, v + if layout == 'bhsd': + q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD) + k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD) + elif layout == 'bshd': + q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD) + k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD) + else: + assert False, 'Got unsupported tensor layout' + q = torch.randn(q_tensor_shape, dtype=dtype, device="cuda", requires_grad=requires_grad) + k = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=requires_grad) + v = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=requires_grad) + + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.max_seqlens_q = N_CTX_Q + input_metadata.max_seqlens_k = N_CTX_K + input_metadata.layout = layout + return q, k, v, input_metadata + + +def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlens=False): + torch.manual_seed(20) + + # Random sequence lengths. 
Using N_CTX as kind of max of sum of individual seqs + if not equal_seqlens: + max_seqlens_q = N_CTX_Q // Z + max_seqlens_k = N_CTX_K // Z + if N_CTX_Q == N_CTX_K: + seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32) + seqlens_k = seqlens_q + else: + seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32) + seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z, ), dtype=torch.int32) + else: + seqlens_q = torch.full((Z, ), N_CTX_Q // Z) + seqlens_k = torch.full((Z, ), N_CTX_K // Z) + + # Calculate cumulative sequence lengths + cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0, dtype=torch.int32)]) + cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0, dtype=torch.int32)]) + cu_seqlens_q = cu_seqlens_q.to(device="cuda") + cu_seqlens_k = cu_seqlens_k.to(device="cuda") + + # Initialize q, k, v with variable lengths + total_q = cu_seqlens_q[-1].item() + total_k = cu_seqlens_k[-1].item() + q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) + return q, k, v, input_metadata + + +@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 24, 1024, 1024, 64), + (1, 24, 6, 8192, 8192, 64), + (1, 4, 2, 16384, 16384, 128), + (2, 16, 4, 1020, 987, 128), + (2, 16, 4, 15498, 2, 128), + (2, 16, 2, 7, 16219, 64), + (4, 48, 12, 1, 1, 64), + (4, 48, 48, 1, 1, 128), + (4, 48, 24, 3, 3, 128), + (4, 48, 48, 1001, 990, 64), + (1, 8, 8, 8081, 7099, 64), + (1, 4, 4, 16330, 15989, 128), + (4, 4, 1, 1024, 1024, 33), + (4, 4, 2, 65, 1018, 65), + (4, 4, 4, 128, 128, 65), + (4, 4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('use_alibi', [True, False]) +@pytest.mark.parametrize('layout', ['bshd', 'bhsd']) +def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, dtype=torch.float16): + torch.manual_seed(20) + q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout) + if causal: + input_metadata.need_causal() + + if use_alibi: + # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) + alibi_slopes = torch.tensor([2**(-8 / HQ * i) for i in range(1, HQ + 1)], dtype=torch.float32, + device="cuda").repeat(Z, 1) + input_metadata.need_alibi(alibi_slopes, Z, HQ) + else: + alibi_slopes = None + + o = torch.empty_like(q) + + # triton implementation + tri_out, _, _ = attention(q, k, v, o, input_metadata) + + # Transpose here if layout is bshd so we have same reference code for all layouts + if layout == 'bshd': + q = q.transpose(1, 2).clone() + k = k.transpose(1, 2).clone() + v = v.transpose(1, 2).clone() + # Replicate K and V if using MQA/GQA + if HQ != HK: + k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], + k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3]) + v = v.view(v.shape[0], v.shape[1], -1, v.shape[2], + v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3]) + + scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * input_metadata.sm_scale + if causal: + mask = 
torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + if use_alibi: + scores += compute_alibi_tensor(alibi_slopes, N_CTX_Q, N_CTX_K) + + p = torch.softmax(scores, dim=-1) + if causal: + # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into + # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix + # this by converting the NaNs to 0s, which is what they should be out of the softmax. + nan_mask = torch.isnan(p) + p[nan_mask == 1] = 0 + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) + # compare + if layout == 'bshd': + ref_out = ref_out.transpose(1, 2).clone() + ref_out = ref_out + 1 + # torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + print("✅ Triton and Torch match") + + +@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 24, 1024, 1024, 64), + (1, 24, 6, 8192, 8192, 64), + (1, 4, 2, 16384, 16384, 128), + (2, 16, 4, 1020, 987, 128), + (2, 16, 4, 15498, 2, 128), + (2, 16, 2, 7, 16219, 64), + (4, 48, 12, 1, 1, 64), + (4, 48, 48, 1, 1, 128), + (4, 48, 24, 3, 3, 128), + (4, 48, 48, 1001, 990, 64), + (1, 8, 8, 8081, 7099, 64), + (1, 4, 4, 16330, 15989, 128), + (4, 4, 1, 1024, 1024, 33), + (4, 4, 2, 65, 1018, 65), + (4, 4, 4, 128, 128, 65), + (4, 4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('use_alibi', [True, False]) +@pytest.mark.parametrize('layout', ['bshd', 'bhsd']) +@pytest.mark.parametrize('persistent', ['fixed', 'dynamic']) +def test_op_persistent_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, persistent, + dtype=torch.float16): + torch.manual_seed(20) + q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout) + if causal: + input_metadata.need_causal() + + if use_alibi: + # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) + alibi_slopes = torch.tensor([2**(-8 / HQ * i) for i in range(1, HQ + 1)], dtype=torch.float32, + device="cuda").repeat(Z, 1) + input_metadata.need_alibi(alibi_slopes, Z, HQ) + else: + alibi_slopes = None + + input_metadata.set_persistent(persistent) + + o = torch.empty_like(q) + + # triton implementation + tri_out, _, _ = attention(q, k, v, o, input_metadata) + + # Transpose here if layout is bshd so we have same reference code for all layouts + if layout == 'bshd': + q = q.transpose(1, 2).clone() + k = k.transpose(1, 2).clone() + v = v.transpose(1, 2).clone() + # Replicate K and V if using MQA/GQA + if HQ != HK: + k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], + k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3]) + v = v.view(v.shape[0], v.shape[1], -1, v.shape[2], + v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3]) + + scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * input_metadata.sm_scale + if causal: + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + if use_alibi: + scores += compute_alibi_tensor(alibi_slopes, N_CTX_Q, N_CTX_K) + + p = torch.softmax(scores, dim=-1) + if causal: + # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into + # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix + # this by converting the NaNs to 0s, which is what they should be out of the softmax. 
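+        # Example of when this happens: with N_CTX_Q = 5 and N_CTX_K = 2, the bottom-right aligned
+        # causal mask leaves the first three query rows with no visible keys, so every score in
+        # those rows is -inf and softmax yields NaN there; they are zeroed to match the kernel,
+        # which writes 0s for fully masked rows.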
+ nan_mask = torch.isnan(p) + p[nan_mask == 1] = 0 + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) + # compare + if layout == 'bshd': + ref_out = ref_out.transpose(1, 2).clone() + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 1024, 1024, 64), + (4, 12, 8192, 8192, 64), + (2, 4, 16384, 16384, 128), + (2, 16, 1020, 987, 128), + (2, 4, 7, 16219, 64), + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 48, 1001, 990, 64), + (1, 8, 8081, 7099, 64), + (1, 8, 16330, 15989, 128), + (4, 4, 1024, 1024, 33), + (4, 4, 65, 1019, 65), + (4, 4, 128, 128, 65), + (4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('quantize_p', [True, False]) +@pytest.mark.parametrize('layout', ['bhsd']) +def test_op_fwd_int8(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, quantize_p, layout, dtype=torch.float16): + torch.manual_seed(20) + + # Disable grad to save memeory it won't run into OOM on CI machine. + q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, requires_grad=False) + if causal: + input_metadata.need_causal() + + o = torch.empty_like(q) + + q_quantized, k_quantized, v_quantized = quantize_input(q, k, v, input_metadata, quantize_p=quantize_p) + + tri_out, _, best_configs = attention(q_quantized, k_quantized, v_quantized, o, input_metadata) + + # Compute scores + q_descale, k_descale, v_descale = input_metadata.q_descale, input_metadata.k_descale, input_metadata.v_descale + scores = (torch.einsum('bhqd,bhkd->bhqk', q_quantized.half(), k_quantized.half()) * q_descale * + k_descale) * input_metadata.sm_scale + + if causal: + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + + # Quantization with tiling + if quantize_p: + tile_size = best_configs.kwargs["BLOCK_N"] # We need the tiling to match Block_N to work + m_i = torch.full((Z, H, N_CTX_Q), float('-inf'), device='cuda', dtype=torch.float32) + acc = torch.zeros((Z, H, N_CTX_Q, D_HEAD), device='cuda', dtype=torch.float32) + l_i = torch.zeros_like(m_i) + + for i in range(0, N_CTX_K, tile_size): + qk_tile = scores[:, :, :, i:i + tile_size] + v_tile = v_quantized[:, :, i:i + tile_size] + m_ij = torch.max(m_i, torch.max(qk_tile, dim=-1).values) + qk_tile -= m_ij.unsqueeze(-1) + p_tile = torch.exp(qk_tile) + l_ij = torch.sum(p_tile, dim=-1) + p_tile = (p_tile * input_metadata.p_scale).to(torch.int8) + + alpha = torch.exp(m_i - m_ij) + # We need float here since both p and v are quantized. So they might overflow the fp16 range. 
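+            # Sketch of the online-softmax identity this loop applies (for reference, not part
+            # of the kernel): with m_old/m_new the running and updated row maxima,
+            #   alpha   = exp(m_old - m_new)
+            #   acc_new = acc_old * alpha + exp(qk_tile - m_new) @ v_tile
+            #   l_new   = l_old * alpha + rowsum(exp(qk_tile - m_new))
+            # Here p_tile is additionally multiplied by p_scale and cast to int8 to mimic the
+            # kernel's quantized accumulation; the matching p_descale is applied after the loop.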
+ acc = acc * alpha.unsqueeze(-1) + torch.einsum('bhqk,bhkd->bhqd', p_tile.float(), v_tile.float()) + m_i = m_ij + l_i = alpha * l_i + l_ij + + l_recip = 1 / l_i.unsqueeze(-1) + acc = acc * input_metadata.p_descale * input_metadata.v_descale * l_recip + ref_out = acc.to(torch.float16) + else: + p = torch.softmax(scores, dim=-1) + ref_out = (torch.einsum('bhqk,bhkd->bhqd', p.float(), v_quantized.float()) * v_descale).to(torch.float16) + + if causal: + nan_mask = torch.isnan(ref_out) + ref_out[nan_mask] = 0 + + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 1024, 1024, 64), + (4, 12, 8192, 8192, 64), + (2, 4, 16384, 16384, 128), + (2, 16, 1020, 987, 128), + (2, 4, 7, 16219, 64), + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 48, 1001, 990, 64), + (1, 8, 8081, 7099, 64), + (1, 8, 16330, 15989, 128), + (4, 4, 1024, 1024, 33), + (4, 4, 65, 1019, 65), + (4, 4, 128, 128, 65), + (4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('layout', ['bhsd']) +def test_op_fwd_int8_kv(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, layout, dtype=torch.float16): + torch.manual_seed(20) + + q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout) + if causal: + input_metadata.need_causal() + + o = torch.empty_like(q) + + _, k_quantized, v_quantized = quantize_input(q, k, v, input_metadata, int8_kv=True) + k_descale, v_descale = input_metadata.k_descale, input_metadata.v_descale + k_dequantized = (k_quantized * k_descale).half() + v_dequantized = (v_quantized * v_descale).half() + + tri_out, _, _ = attention(q, k_quantized, v_quantized, o, input_metadata) + + # Compute scores + scores = torch.einsum('bhqd,bhkd->bhqk', q, k_dequantized).float() * input_metadata.sm_scale + + if causal: + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + + p = torch.softmax(scores, dim=-1) + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v_dequantized).to(torch.float16) + + if causal: + nan_mask = torch.isnan(ref_out) + ref_out[nan_mask] = 0 + + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 1024, 1024, 64), + (4, 12, 8192, 8192, 64), + (2, 4, 16384, 16384, 128), + (2, 16, 1020, 987, 128), + (2, 4, 7, 16219, 64), + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 48, 1001, 990, 64), + (1, 8, 8081, 7099, 64), + (1, 8, 16330, 15989, 128), + (4, 4, 1024, 1024, 33), + (4, 4, 65, 1019, 65), + (4, 4, 128, 128, 65), + # TODO: This config fails. Disabled until triaged and fixed. 
+ # (4, 4, 113, 123, 1), + # (2, 16, 15498, 2, 128), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('use_bias', [True]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype): + torch.manual_seed(20) + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout='bhsd') + if causal: + input_metadata.need_causal() + if use_bias: + bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=dtype, device="cuda") + input_metadata.need_bias(bias, Z, H, N_CTX_Q, N_CTX_K) + else: + bias = None + o = torch.empty_like(q) + + # triton implementation + tri_out, _, _ = attention(q, k, v, o, input_metadata) + # reference implementation:171 + + scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * sm_scale + if causal: + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + if use_bias: + scores += input_metadata.bias + p = torch.softmax(scores, dim=-1) + if causal: + # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into + # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix + # this by converting the NaNs to 0s, which is what they should be out of the softmax. + nan_mask = torch.isnan(p) + p[nan_mask == 1] = 0 + + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(dtype), v) + # compare + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 8192, 64), (4, 48, 256, 64), (4, 48, 512, 64), + (4, 48, 1024, 64), (8, 48, 4096, 64), (4, 48, 8192, 64), + (4, 48, 128, 128), (4, 48, 4096, 128), (4, 48, 16384, 128), + (4, 16, 1024, 128), (4, 16, 8192, 128), (32, 48, 8192, 128)]) +@pytest.mark.parametrize('causal', [True, False]) +def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): + + q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype) + + tri_out = torch.empty_like(q) + ref_out = torch.empty_like(q) + + for i in range(0, input_metadata.num_contexts): + start_q, start_k = input_metadata.cu_seqlens_q[i], input_metadata.cu_seqlens_k[i] + end_q, end_k = input_metadata.cu_seqlens_q[i + 1], input_metadata.cu_seqlens_k[i + 1] + scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k[start_k:end_k]).float() + p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half() + ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v[start_k:end_k]) + attention(q, k, v, tri_out, input_metadata) + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', [(2, 48, 24, 128, 64), (4, 48, 12, 256, 64), (4, 48, 4, 512, 64), + (4, 48, 2, 1024, 64), (8, 48, 6, 4096, 64), (4, 48, 8, 16384, 64), + (4, 64, 16, 128, 128), (4, 64, 4, 4096, 128), + (4, 64, 8, 16384, 128), (4, 16, 4, 1024, 128), + (4, 16, 2, 8192, 128), (32, 128, 32, 8192, 128)]) +@pytest.mark.parametrize('causal', [False]) +def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16): + q, k, v, input_metadata = varlen_input_helper(Z, HQ, HK, N_CTX, N_CTX, D_HEAD, dtype) + ref_out = torch.empty_like(q) + tri_out = torch.empty_like(q) + # Make KV look like HQ/HK "groups" of HK. Later, we will reshape so the + # size aligns with Q. 
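+    # Shape walk-through (for reference): k is [total_k, HK, D_HEAD]; the view/expand below
+    # produces [total_k, HK, HQ // HK, D_HEAD] without copying, and the per-sequence reshape
+    # inside the loop flattens the head dims to [seqlen_k, HQ, D_HEAD], so the reference
+    # einsum sees one K/V head per Q head.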
+ k_ref = k.view(k.shape[0], k.shape[1], 1, k.shape[2]).expand(-1, -1, HQ // HK, -1) + v_ref = v.view(v.shape[0], v.shape[1], 1, v.shape[2]).expand(-1, -1, HQ // HK, -1) + for i in range(0, input_metadata.num_contexts): + start_q, start_k = input_metadata.cu_seqlens_q[i], input_metadata.cu_seqlens_k[i] + end_q, end_k = input_metadata.cu_seqlens_q[i + 1], input_metadata.cu_seqlens_k[i + 1] + k_curr = k_ref[start_k:end_k] + k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3]) + v_curr = v_ref[start_k:end_k] + v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3]) + scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k_curr).float() + p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half() + ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v_curr) + attention(q, k, v, tri_out, input_metadata) + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ + (4, 48, 1024, 64), + (4, 48, 2048, 64), + (2, 48, 4096, 64), + (1, 16, 1024, 64), + (1, 16, 1024, 128), + #(1, 16, 8192, 63), + #(1, 16, 1022, 64), +]) +@pytest.mark.parametrize('qseqlen_not_equal_kseqlen', [None]) +@pytest.mark.parametrize('torch_sdpa_test', [False, True]) +@pytest.mark.parametrize('causal', [True]) +@pytest.mark.parametrize('use_alibi', [False, True]) +def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, use_alibi, + dtype=torch.float16): + pytest.skip() + torch.manual_seed(20) + if qseqlen_not_equal_kseqlen is not None: + seqlen_q = qseqlen_not_equal_kseqlen + else: + seqlen_q = N_CTX + seqlen_k = N_CTX + + if causal and ((N_CTX - 1) & N_CTX): + pytest.skip() + if causal and seqlen_q != seqlen_k: + pytest.skip() + + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.max_seqlens_q = seqlen_q + input_metadata.max_seqlens_k = seqlen_k + + dropout_p = 0 + q = (torch.empty((Z, H, seqlen_q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + o = torch.empty_like(q) + + if causal: + input_metadata.need_causal() + + if use_alibi and not torch_sdpa_test: + # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) + alibi_slopes = torch.tensor([2**(-8 / H * i) for i in range(1, H + 1)], dtype=torch.float32, + device="cuda").repeat(Z, 1) + input_metadata.need_alibi(alibi_slopes, Z, H) + dout = torch.randn_like(q) + # reference implementation + if torch_sdpa_test: + ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, dropout_p=dropout_p, + is_causal=causal, scale=sm_scale, + dropout_mask=None) + ref_out.backward(dout.to(device=ref_out.device, dtype=ref_out.dtype)) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + else: + M = torch.tril(torch.ones((seqlen_q, seqlen_k), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + if use_alibi: + p += compute_alibi_tensor(alibi_slopes, N_CTX, N_CTX) + if causal: + p[:, :, M == 0] = float("-inf") + + p = torch.softmax(p.float(), dim=-1).type(dtype=p.dtype) + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), 
None + + # # triton implementation + tri_out, _, _ = attention(q, k, v, o, input_metadata) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # test + #print("reference") + #print(ref_dv) + #print("tri") + #print(tri_dv) + # compare + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0) + # The current block size for gfx90a and gfx908 series is 64x64. This results in + # larger differences in float results due to rounding. + + if dtype == torch.bfloat16: + ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + if dtype == torch.float32: + ATOL = 1e-3 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + else: + ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + + RTOL = 0 + + torch.testing.assert_close(ref_dv, tri_dv, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(ref_dk, tri_dk, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(ref_dq, tri_dq, atol=ATOL, rtol=RTOL) + + +def nonvarlen_benchmark_configs(): + configs = [ + # (16, 16, 16, 1024, 1024), + # (8, 16, 16, 2048, 2048), + # (4, 16, 16, 4096, 4096), + # (2, 16, 16, 8192, 8192), + # (1, 16, 16, 16384, 16384), + # (2, 48, 48, 1024, 1024), + # (2, 48, 48, 2048, 1024), + # (2, 48, 48, 4096, 8192), + # (2, 48, 48, 8192, 4096), + (2, 48, 48, 16384, 8192), + # (8, 16, 16, 1989, 15344), + # (4, 16, 16, 4097, 163), + # (2, 16, 16, 8122, 2159), + # (1, 16, 16, 16281, 7), + # (2, 48, 48, 1021, 1020), + # (2, 48, 48, 2001, 2048), + # (2, 48, 48, 3996, 9639), + # (2, 48, 48, 8181, 1021), + ] + return configs + + +def varlen_benchmark_configs(): + configs = [ + # (2, 16, 4, 1024, 1024), + # (8, 16, 2, 2048, 2048), + # (4, 16, 8, 4096, 4096), + # (2, 16, 4, 8192, 8192), + # (2, 16, 8, 16384, 16384), + # (2, 48, 12, 1024, 1024), + # (2, 48, 24, 2048, 2048), + # (2, 48, 8, 4096, 4096), + # (2, 48, 4, 8192, 8192), + (2, 48, 2, 16384, 16384), + # (2, 64, 32, 1024, 1024), + # (4, 64, 16, 2048, 2048), + # (4, 64, 8, 4096, 4096), + # (4, 64, 32, 8192, 8192), + # (4, 128, 16, 16384, 16384), + ] + return configs + + +def model_benchmark_configs(args): + config_file = args.model_configs + configs = get_model_configs(config_path=config_file, model_families=["llama3"], model=args.model) + fa_configs = [] + batch_size = args.b if args.b else 1 + + for model_name, config in configs.items(): + HQ = config["num_attention_heads"] + HK = HQ if config["num_key_value_heads"] is None else config["num_key_value_heads"] + N_CTX_Q = args.sq if args.sq else 8192 + N_CTX_K = args.sk if args.sk else N_CTX_Q + HEAD_DIM = config["hidden_size"] // HQ + fa_configs.append((model_name, batch_size, HQ, HK, N_CTX_Q, N_CTX_K, HEAD_DIM)) + + return fa_configs + + +def run_benchmark(custom, args): + + dtype = arg_to_torch_dtype[args.dtype] + hk = args.hq if not args.hk else args.hk + sk = args.sq if not args.sk else args.sk + head_size = 128 if not args.d else args.d + mode = 'fwd' + x_names = ['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K'] + causal = args.causal if not args.model else True + int8 = args.int8 + quantize_p = args.quantize_p and int8 + int8_kv = args.int8_kv and int8 + varlen = True if args.model else args.layout == 'thd' + configs = [] + plot_name = f'fused-attention-{mode}-d{head_size}-layout{args.layout}' + extra_args = {'D_HEAD': head_size, 'dtype': dtype, 'causal': causal, 'mode': mode} + if custom: + x_vals_list = [(args.b, args.hq, hk, args.sq, sk)] + else: + if varlen: + x_vals_list = varlen_benchmark_configs() + else: + x_vals_list = nonvarlen_benchmark_configs() + + 
if mode == 'bwd': + # Only those with N_CTX_Q == N_CTX_K work + new_x = [] + for v in x_vals_list: + if v[-1] == v[-2]: + new_x.append(v) + x_vals_list = new_x + + if args.model: + x_vals_list = model_benchmark_configs(args) + x_names = ['model', 'BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K', 'D_HEAD'] + plot_name = f'fused-attention-{mode}-layout{args.layout}' + extra_args = {'dtype': dtype, 'causal': causal, 'mode': mode} + print_time = args.return_time + + line_vals = ['triton', 'torch'] # 'Time (ms)' if print_time else 'TFLOPS' + configs.append( + triton.testing.Benchmark(x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=line_vals, + line_names=line_vals, styles=[('green', '-'), ('red', '-')], + ylabel='Time (ms)' if print_time else 'TFLOPS', plot_name=plot_name, args=extra_args)) + + @triton.testing.perf_report(configs) + def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda", + model=None): + assert mode in ["fwd", "bwd"] + assert not (int8_kv and quantize_p) + warmup = 25 + rep = 100 + # TODO: Enable bias after testing. + # if use_bias: + # bias = torch.randn((1, H, N_CTX, N_CTX), dtype=torch.float32, device="cuda") + # input_metadata.need_bias(bias, BATCH, H, N_CTX, N_CTX) + # else: + # bias = None + # bias = None + + # Bwd pass only supports causal=True right now + if mode == 'bwd': + causal = True + + flops_per_matmul = 0 + if varlen: + q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, + args.equal_seqlens) + for i in range(0, input_metadata.num_contexts): + seqlen_q = (input_metadata.cu_seqlens_q[i + 1] - input_metadata.cu_seqlens_q[i]).item() + seqlen_k = (input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i]).item() + # x2 in both cases for 2 GEMMs + if causal: + # If seqlen_q != seqlen_k then the causal mask ignores computation + # depending on which seqlen is larger. 
Either the lower triangle, or right triangle + # If seqlen_q is greater than seqlen_k, the lower triangle is non zero + # where the last row has seqlen_k valid element, the second last row has + # seqlen_k - 1 valid elements and so on until one element is valid in the + # seqlen_q - seqlen_k row, hence total valid elements are 1+2+...+seqlen_k + # which is seqlen_k*(seqlen_k+1)/2 + # If seqlen_q is less than seqlen_k, then we count the zero elements + # the first row has seqlen_q-1 zero elements, the second row has seqlen_q-2 + # zero elements and so on until the second last row has 1 zero element + # Total zero elements are 1+2+...+(seqlen_q-1) = seqlen_q*(seqlen_q-1)/2 + # Total non zero elements are seqlen_q*seqlen_k - (seqlen_q*(seqlen_q-1)/2) + valid_out_elements = ((seqlen_k**2 + seqlen_k) / 2) if seqlen_q > seqlen_k else \ + (seqlen_q * seqlen_k - ((seqlen_q**2 - seqlen_q) / 2)) + flops_per_matmul += valid_out_elements * HQ * D_HEAD * 2 + else: + flops_per_matmul += seqlen_q * seqlen_k * HQ * D_HEAD * 2 + else: + q, k, v, input_metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, args.layout) + if causal: + # Same calculation as if varlen/if causal above + valid_out_elements = ((N_CTX_K**2 + N_CTX_K) / 2) if N_CTX_Q > N_CTX_K else \ + (N_CTX_Q * N_CTX_K - ((N_CTX_Q**2 - N_CTX_Q) / 2)) + flops_per_matmul = 2.0 * BATCH * HQ * valid_out_elements * D_HEAD + else: + flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD + if causal: + input_metadata.need_causal() + + if "triton" in provider: + o = torch.empty_like(q) + if int8: + q, k, v = quantize_input(q, k, v, input_metadata, quantize_p=quantize_p, int8_kv=int8_kv) + input_metadata.set_persistent(args.persistent) + fn = lambda: attention(q, k, v, o, input_metadata) + if mode == 'bwd': + o, _, _ = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + + elif "torch" in provider and args.layout in ["thd", "bhsd", "bshd"]: + # torch requires the layout to be (b (optional),...,h,s,d) + if args.layout in ["thd", "bshd"]: + q = q.transpose(-3, -2) + k = k.transpose(-3, -2) + v = v.transpose(-3, -2) + # check if GQA + HQ = q.shape[-3] + HK = k.shape[-3] + if HQ != HK: # TODO: sdpa(..., enable_gqa=True work) should work + k = k.repeat_interleave(q.size(-3) // k.size(-3), -3) + v = v.repeat_interleave(q.size(-3) // v.size(-3), -3) + + fn = lambda: torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=0.0, is_causal=causal, scale=input_metadata.sm_scale) + else: + assert False, f"Unknown provider {provider} in flash-attention." + + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + total_flops = 2 * flops_per_matmul + if mode == "bwd": + total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) + if print_time: + return ms + else: + return total_flops / ms * 1e-9 + + bench_flash_attention.run(save_path=".", print_data=True, show_plots=True) + + +def supported_layouts(): + layouts = \ + 'bhsd: Q, K, V are individual tensors of [batch, num_heads, seqlen_q/k, head_size]' \ + 'bshd: Q, K, V are individual tensors of [batch, seqlen_q/k, num_heads, head_size]' \ + 'thd: Q, K, V are individual tensors of [total_q/k, num_heads, head_size]' \ + 'This layout is sometimes called "varlen" or "grouped" layout.' 
+ return layouts + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="Benchmark FlashAttention", + allow_abbrev=False, + ) + parser.add_argument('-model_configs', type=str, default="model_configs.json", help="Model config json file.") + + available_models = get_available_models(model_families=["llama3"]) # Dynamically load model names + model_help = ( + "Model name to benchmark. Select from: [" + ", ".join(available_models) + + "]. Use 'all' to benchmark all models. Not providing runs the default benchmark script with custom configs.") + parser.add_argument('-model', type=str, default=None, help=model_help) + parser.add_argument("-b", type=int, default=0) + parser.add_argument("-hq", type=int, default=0) + parser.add_argument("-hk", type=int, default=0) + parser.add_argument("-sq", type=int, default=0) + parser.add_argument("-sk", type=int, default=0) + parser.add_argument("-equal_seqlens", action='store_true', default=False, + help='If specified, each context within the thd layout' \ + ' has same seqlen as sq and sk') + parser.add_argument("-d", type=int, default=0) + parser.add_argument("-causal", action='store_true', default=False) + parser.add_argument("-int8", action='store_true', default=False) + parser.add_argument("-quantize_p", action='store_true', default=False) + parser.add_argument("-int8_kv", action='store_true', default=False) + parser.add_argument("-dtype", default='fp16') + parser.add_argument("-return_time", action='store_true', default=False) + parser.add_argument("-layout", type=str, default='bhsd', help=supported_layouts()) + parser.add_argument( + "-persistent", nargs='?', const='fixed', choices=['fixed', 'dynamic'], default=None, + help="Enable persistent kernels. Use '-persistent dynamic' for dynamic scheduling of the tiles.") + return parser.parse_args() + + +arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32} + + +def main(): + args = parse_args() + custom_config = False + assert args.layout == 'thd' or not args.equal_seqlens or args.model, \ + "Equal sequence lengths arg must be used with the thd layout or a model config." + if args.hq or args.hk or args.d: + custom_config = True + assert args.b and args.hq and args.sq and args.d, \ + "If custom config is specified, please provide \ + all of batch, number of Q heads, Q sequence length \ + and head size." + + if args.model: + assert not (args.hq or args.hk or args.d), \ + "Specifying model fixes hq, hk and d already. Do not provide them!" + + assert args.dtype in arg_to_torch_dtype, \ + "Only fp16, bf16 and f32 types currently supported." 
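+
+    # Example invocations (illustrative; see parse_args above for the full flag list):
+    #   python flash-attention.py -b 2 -hq 16 -hk 4 -sq 8192 -d 128 -causal
+    #   python flash-attention.py -model llama3-8B
+    #   python flash-attention.py -layout thd -equal_seqlens -int8 -quantize_p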
+ + if args.model: + print("Note: Model config sets causal masking and THD layout (varlen) by default.") + + run_benchmark(custom_config, args) + + +if __name__ == '__main__': + test_op_fwd(2, 48, 48, 16384, 8192, 128, False, False, 'bshd', dtype=torch.float16) + sys.exit(main()) diff --git a/fa/model_configs.json b/fa/model_configs.json new file mode 100644 index 000000000000..5f0c28cd0e23 --- /dev/null +++ b/fa/model_configs.json @@ -0,0 +1,42 @@ +{ + "llama3": { + "8B": { + "num_attention_heads": 32, + "num_key_value_heads": 8, + "hidden_size": 4096, + "intermediate_size": 14336, + "vocab_size": 128256 + }, + "70B": { + "num_attention_heads": 64, + "num_key_value_heads": 8, + "hidden_size": 8192, + "intermediate_size": 28672, + "vocab_size": 128256 + }, + "405B": { + "num_attention_heads": 128, + "num_key_value_heads": 8, + "hidden_size": 16384, + "intermediate_size": 53248, + "vocab_size": 128256 + } + }, + "mistral": { + "7B": { + "hidden_size": 4096, + "intermediate_size": 14336, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "vocab_size": 32000 + }, + "22B": { + "hidden_size": 6144, + "intermediate_size": 16384, + "num_attention_heads": 48, + "num_key_value_heads": 8, + "vocab_size": 32000 + } + + } +} diff --git a/fa/utils/__init__.py b/fa/utils/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/fa/utils/benchmark_utils.py b/fa/utils/benchmark_utils.py new file mode 100644 index 000000000000..11c19bcd0c18 --- /dev/null +++ b/fa/utils/benchmark_utils.py @@ -0,0 +1,71 @@ +import os +import json + +# Base directory where configs are located +BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../")) + + +def get_model_configs(config_path='model_configs.json', model_families=["llama3"], model="all"): + """ + Load model names from the configuration file. + + Args: + config_path (str): User-provided path to the configuration JSON file. + model_families (list): List of model family names to retrieve. + + Returns: + dict: A dictionary of available models and their configurations for the specified families. + """ + # Resolve config path relative to ./perf-kernels/ + config_path = os.path.join(BASE_DIR, config_path) + + with open(config_path, 'r') as f: + configs = json.load(f) + + # Extract models and their configurations for the specified families + filtered_configs = {} + + for family in model_families: + if family in configs: + # Check if model filtering is required + if model == "all": + # Include all models in the family + for model_size, model_configs in configs[family].items(): + filtered_configs[f"{family}-{model_size}"] = model_configs + else: + # Parse the model string (e.g., llama3_8B or llama3-8B) + delimiter = "_" if "_" in model else "-" + model_parts = model.split(delimiter) + + # Check if the family and size match + if len(model_parts) == 2 and model_parts[0] == family: + model_size = model_parts[1] + if model_size in configs[family]: + filtered_configs[f"{family}-{model_size}"] = configs[family][model_size] + + if not filtered_configs: + print(f"Warning: No models selected for families: {model_families} with filter: '{model}'") + + return filtered_configs + + +def get_available_models(config_file='model_configs.json', model_families=["llama3"]): + """ + Load model names from the configuration file. + + Args: + config_file (str): Path to the configuration JSON file. + model_families (list): List of model family names to retrieve. + + Returns: + list: A list of available models for the specified families. 
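+
+    Example (illustrative, assuming the default model_configs.json shipped with this package):
+        >>> get_available_models(model_families=["llama3"])
+        ['llama3-8B', 'llama3-70B', 'llama3-405B']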
+ """ + # Resolve config path relative to ./perf-kernels/ + config_path = os.path.join(BASE_DIR, config_file) + + with open(config_path, 'r') as f: + configs = json.load(f) + + models = [f"{family}-{model}" for family in model_families if family in configs for model in configs[family]] + + return models diff --git a/fa/utils/rocprof_benchmark.py b/fa/utils/rocprof_benchmark.py new file mode 100644 index 000000000000..edccb7f3d314 --- /dev/null +++ b/fa/utils/rocprof_benchmark.py @@ -0,0 +1,59 @@ +import subprocess +import os +import pandas as pd +from prettytable import PrettyTable + + +def run_profiling(triton_dir, batch_size, output_file): + command = [ + "rocprof", "--stats", "-o", output_file, "python", f"{triton_dir}/python/perf-kernels/MLA_decode_rope.py", "-B", + str(batch_size), "-dtype", "bf16", "-use_rope" + ] + subprocess.run(command, check=True) + + +def parse_profiling_output(output_file, kernel_names): + df = pd.read_csv(output_file) + results = {} + for kernel in kernel_names: + kernel_data = df[df['Name'].str.strip('"') == kernel] + if not kernel_data.empty: + results[kernel] = kernel_data['AverageNs'].iloc[0] / 1000.0 + else: + results[kernel] = None + + # Calculate sum of other kernels + other_kernels = df[~df['Name'].str.strip('"').isin(kernel_names)] + other_kernels_sum = other_kernels['AverageNs'].sum() / 1000.0 + results['other_kernels_sum'] = other_kernels_sum + + return results + + +def main(): + triton_dir = os.environ.get("TRITONDIR", "~/triton") # Default to ~/triton if not set + output_file = os.path.expanduser("~/profiling.csv") + kernel_names = ["_fwd_grouped_kernel_stage1_rope.kd", "_fwd_grouped_kernel_stage1.kd"] + batch_sizes = [1, 4, 32, 64, 128] + + results = {B: {} for B in batch_sizes} + for B in batch_sizes: + print(f"Running profiling for B={B}...") + run_profiling(triton_dir, B, output_file) + output_stats_file = os.path.expanduser("~/profiling.stats.csv") + kernel_results = parse_profiling_output(output_stats_file, kernel_names) + results[B] = kernel_results + + table = PrettyTable() + table.field_names = ["B"] + kernel_names + ["Other Kernels Sum (µs)"] + for B in batch_sizes: + row = [B] + [results[B].get(kernel, "N/A") + for kernel in kernel_names] + [results[B].get('other_kernels_sum', "N/A")] + table.add_row(row) + + print("\nProfiling Summary (in microseconds):") + print(table) + + +if __name__ == "__main__": + main() diff --git a/fa/utils/rotary_embedding.py b/fa/utils/rotary_embedding.py new file mode 100644 index 000000000000..a864710601f1 --- /dev/null +++ b/fa/utils/rotary_embedding.py @@ -0,0 +1,283 @@ +# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py +"""Rotary Positional Embeddings.""" +import math +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +# from vllm.model_executor.custom_op import CustomOp + +# from sglang.srt.layers.custom_op_util import register_custom_op + + +def _rotate_neox(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + + return torch.cat((-x2, x1), dim=-1) + + +def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, 
head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. + """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +class RotaryEmbedding(nn.Module): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / (base**(torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim( + num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> float: + return (dim * 
math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + +# Find dim range bounds based on rotations +def _yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> Tuple[int, int]: + low = math.floor(_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(_yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask(low: float, high: float, dim: int, dtype: torch.dtype, device) -> torch.Tensor: + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + device: Optional[str] = "cuda", + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. 
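+        # With yarn_get_mscale as defined above, for scaling_factor s > 1 this evaluates to
+        #   (0.1 * mscale * ln(s) + 1) / (0.1 * mscale_all_dim * ln(s) + 1) * attn_factor,
+        # and reduces to attn_factor when s <= 1.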
+ self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) / + yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * attn_factor) + self.device = device + super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base**(torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device) / + self.rotary_dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.rotary_dim, + self.base, + self.max_position_embeddings, + ) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = (1 - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float, + device=self.device)) * self.extrapolation_factor + inv_freq = (inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, + device=self.device, + dtype=torch.float32, + ) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale + cache = torch.cat((cos, sin), dim=-1) + + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward().""" + query_rot = query[..., :self.rotary_dim] + key_rot = key[..., :self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim:] + key_pass = key[..., self.rotary_dim:] + + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device) + cos_sin = self.cos_sin_cache[torch.add(positions, offsets) if offsets is not None else positions] + # (max_seq, 64). 32 sin, 32 cos + cos, sin = cos_sin.chunk(2, dim=-1) + + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. + cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + query_rot = query_rot * cos + rotate_fn(query_rot) * sin + key_rot = key_rot * cos + rotate_fn(key_rot) * sin + + if self.rotary_dim < self.head_size: + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + else: + query = query_rot + key = key_rot + return query, key diff --git a/fa/utils/sglang_ref.py b/fa/utils/sglang_ref.py new file mode 100644 index 000000000000..f862aee48f1e --- /dev/null +++ b/fa/utils/sglang_ref.py @@ -0,0 +1,619 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Memory-efficient attention for decoding. +It supports page size = 1. +""" + +# Adapted from +# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py +# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py + +import logging + +import triton +import triton.language as tl + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +is_hip_ = is_hip() + +logger = logging.getLogger(__name__) + +# TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy. +logger.warning("The following error message 'operation scheduled before its operands' can be ignored.") + + +@triton.jit +def tanh(x): + # Tanh is just a scaled sigmoid + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def _fwd_kernel_stage1( + Q, + K_Buffer, + V_Buffer, + sm_scale, + kv_indptr, + kv_indices, + Att_Out, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + split_kv_id = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + + cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) + cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + q = tl.load(Q + off_q, mask=mask_d, other=0.0) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + e_max = -float("inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_loc = tl.load( + kv_indices + cur_batch_kv_start_idx + offs_n, + mask=offs_n < split_kv_end, + other=0, + ) + offs_buf_k = (kv_loc[:, None] * stride_buf_kbs + cur_kv_head * stride_buf_kh + offs_d[None, :]) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]), + other=0.0, + ) + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(offs_n < split_kv_end, qk, float("-inf")) + + offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + cur_kv_head * stride_buf_vh + offs_dv[None, :]) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) + + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max + + offs_mid_o = (cur_batch * stride_mid_ob + cur_head * 
stride_mid_oh + split_kv_id * stride_mid_os + offs_dv) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum, + mask=(mask_dv), + ) + + offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + split_kv_id * stride_mid_os + Lv) + + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + ) + + +def _decode_att_m_fwd( + q, + k_buffer, + v_buffer, + att_out, + kv_indptr, + kv_indices, + num_kv_splits, + sm_scale, + logit_cap, +): + BLOCK = 64 + # [TODO] work around SGPR limit on MI3xx + if is_hip_: + BLOCK = 32 + NUM_KV_SPLITS = num_kv_splits + Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] + + batch, head_num = kv_indptr.shape[0] - 1, q.shape[1] + + grid = (batch, head_num, NUM_KV_SPLITS) + kv_group_num = q.shape[1] // k_buffer.shape[1] + + if kv_group_num == 1: + num_warps = 4 + else: + num_warps = 2 + if is_hip_: + num_warps = 1 + + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DV = triton.next_power_of_2(Lv) + + _fwd_kernel_stage1[grid]( + q, + k_buffer, + v_buffer, + sm_scale, + kv_indptr, + kv_indices, + att_out, + q.stride(0), + q.stride(1), + k_buffer.stride(0), + k_buffer.stride(1), + v_buffer.stride(0), + v_buffer.stride(1), + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK, + NUM_KV_SPLITS=NUM_KV_SPLITS, + logit_cap=logit_cap, + num_warps=num_warps, + num_stages=2, + Lk=Lk, + Lv=Lv, + ) + + +@triton.jit +def _fwd_grouped_kernel_stage1( + Q, + K_Buffer, + V_Buffer, + sm_scale, + kv_indptr, + kv_indices, + Att_Out, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + q_head_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_H: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head_id = tl.program_id(1) + cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) + split_kv_id = tl.program_id(2) + + if BLOCK_H < kv_group_num: + VALID_BLOCK_H: tl.constexpr = BLOCK_H + else: + VALID_BLOCK_H: tl.constexpr = kv_group_num + cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) + mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H + mask_h = mask_h & (cur_head < q_head_num) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + + cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) + cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx + + offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] + q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + mask_dpe = offs_dpe < Lk + off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]) + qpe = tl.load(Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) + + if 
split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_loc = tl.load( + kv_indices + cur_batch_kv_start_idx + offs_n, + mask=offs_n < split_kv_end, + other=0, + ) + offs_buf_k = (kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh + offs_d[:, None]) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), + other=0.0, + ) + qk = tl.dot(q, k.to(q.dtype)) + if BLOCK_DPE > 0: + offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh + offs_dpe[:, None]) + kpe = tl.load( + K_Buffer + offs_buf_kpe, + mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), + other=0.0, + ) + qk += tl.dot(qpe, kpe.to(qpe.dtype)) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")) + + offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + cur_kv_head * stride_buf_vh + offs_dv[None, :]) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v.dtype), v) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + offs_mid_o = (cur_batch * stride_mid_ob + cur_head[:, None] * stride_mid_oh + split_kv_id * stride_mid_os + + offs_dv[None, :]) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum[:, None], + mask=(mask_h[:, None]) & (mask_dv[None, :]), + ) + + offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + split_kv_id * stride_mid_os + Lv) + + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + mask=mask_h, + ) + + +def _decode_grouped_att_m_fwd( + q, + k_buffer, + v_buffer, + att_out, + kv_indptr, + kv_indices, + num_kv_splits, + sm_scale, + logit_cap, +): + BLOCK = 32 + Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] + + # [TODO] work around shmem limit on MI3xx + if is_hip_ and Lk >= 576: + BLOCK = 16 + + if Lk == 576: + BLOCK_DMODEL = 512 + BLOCK_DPE = 64 + elif Lk == 288: + BLOCK_DMODEL = 256 + BLOCK_DPE = 32 + else: + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DPE = 0 + BLOCK_DV = triton.next_power_of_2(Lv) + + batch, head_num = kv_indptr.shape[0] - 1, q.shape[1] + kv_group_num = q.shape[1] // k_buffer.shape[1] + + BLOCK_H = 16 + NUM_KV_SPLITS = num_kv_splits + grid = ( + batch, + triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), + NUM_KV_SPLITS, + ) + + extra_kargs = {} + num_stages = 2 + if is_hip_: + # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html + # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py + extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2} + num_stages = 1 + + _fwd_grouped_kernel_stage1[grid]( + q, + k_buffer, + v_buffer, + sm_scale, + kv_indptr, + kv_indices, + att_out, + q.stride(0), + q.stride(1), + k_buffer.stride(0), + k_buffer.stride(1), + v_buffer.stride(0), + v_buffer.stride(1), + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + q_head_num=head_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DPE=BLOCK_DPE, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK, + BLOCK_H=BLOCK_H, + NUM_KV_SPLITS=NUM_KV_SPLITS, + logit_cap=logit_cap, + num_warps=4, + num_stages=num_stages, + Lk=Lk, + Lv=Lv, + **extra_kargs, + 
) + + +@triton.jit +def _fwd_kernel_stage2( + Mid_O, + O, + kv_indptr, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + NUM_KV_SPLITS: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(kv_indptr + cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv + + for split_kv_id in range(0, NUM_KV_SPLITS): + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0) + tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + + +def _decode_softmax_reducev_fwd( + logits, + q, + o, + v_buffer, + kv_indptr, + num_kv_splits, +): + batch, head_num = q.shape[0], q.shape[1] + Lv = v_buffer.shape[-1] + BLOCK_DV = triton.next_power_of_2(Lv) + + NUM_KV_SPLITS = num_kv_splits + + extra_kargs = {} + if is_hip_: + # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html + # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py + extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} + + grid = (batch, head_num) + _fwd_kernel_stage2[grid]( + logits, + o, + kv_indptr, + logits.stride(0), + logits.stride(1), + logits.stride(2), + o.stride(0), + o.stride(1), + NUM_KV_SPLITS=NUM_KV_SPLITS, + BLOCK_DV=BLOCK_DV, + Lv=Lv, + num_warps=4, + num_stages=2, + **extra_kargs, + ) + + +def decode_attention_fwd_normal( + q, + k_buffer, + v_buffer, + o, + kv_indptr, + kv_indices, + attn_logits, + num_kv_splits, + sm_scale, + logit_cap=0.0, +): + _decode_att_m_fwd( + q, + k_buffer, + v_buffer, + attn_logits, + kv_indptr, + kv_indices, + num_kv_splits, + sm_scale, + logit_cap, + ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits) + + +def decode_attention_fwd_grouped( + q, + k_buffer, + v_buffer, + o, + kv_indptr, + kv_indices, + attn_logits, + num_kv_splits, + sm_scale, + logit_cap=0.0, +): + _decode_grouped_att_m_fwd( + q, + k_buffer, + v_buffer, + attn_logits, + kv_indptr, + kv_indices, + num_kv_splits, + sm_scale, + logit_cap, + ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits) + + +def decode_attention_fwd( + q, + k_buffer, + v_buffer, + o, + kv_indptr, + kv_indices, + attn_logits, + num_kv_splits, + sm_scale, + logit_cap=0.0, +): + assert num_kv_splits == attn_logits.shape[2] + kv_group_num = q.shape[1] // v_buffer.shape[1] + + if kv_group_num == 1: + # MHA + decode_attention_fwd_normal( + q, + k_buffer, + v_buffer, + o, + kv_indptr, + kv_indices, + attn_logits, + num_kv_splits, + sm_scale, + logit_cap, + ) + else: + # GQA/MQA/MLA + 
decode_attention_fwd_grouped( + q, + k_buffer, + v_buffer, + o, + kv_indptr, + kv_indices, + attn_logits, + num_kv_splits, + sm_scale, + logit_cap, + ) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 1159a48ca471..54930e45f8c8 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -376,9 +376,13 @@ class SharedMemoryObject { return types; } - SmallVector getStrides(triton::gpu::MemDescType memDesc, Location loc, - RewriterBase &rewriter) const { + SmallVector + getStrides(triton::gpu::MemDescType memDesc, Location loc, + RewriterBase &rewriter, + ArrayRef overwriteAllocSize = {}) const { auto allocShape = memDesc.getAllocShape(); + if (!overwriteAllocSize.empty()) + allocShape = overwriteAllocSize; auto allocShapePerCTA = triton::gpu::getAllocationShapePerCTA( memDesc.getEncoding(), allocShape); auto layoutOrder = triton::gpu::getOrder(memDesc); @@ -698,13 +702,15 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, Type elemLlvmTy, std::optional maxVecElems, const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, - std::function perVectorCallback); + std::function perVectorCallback, + bool forceLane0 = false, ArrayRef overwriteAllocSize = {}); [[nodiscard]] bool emitTransferBetweenRegistersAndShared( LinearLayout ®Layout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy, std::optional maxVecElems, const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, - std::function perVectorCallback); + std::function perVectorCallback, + bool forceLane0 = false, ArrayRef overwriteAllocSize = {}); SmallVector loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp, Type elemLlvmTy, diff --git a/include/triton/Dialect/TritonGPU/Transforms/Schedule.h b/include/triton/Dialect/TritonGPU/Transforms/Schedule.h index 9aae78062324..e34e41e54c19 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Schedule.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Schedule.h @@ -115,6 +115,11 @@ class CoarseSchedule { bool insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster, bool includeArg, bool insertIfEarlier = false); + bool insertDepsOfOp( + Operation *op, bool includeArg, bool insertIfEarlier, + llvm::function_ref(Operation *)> + getStageClusterForOp); + void erase(Operation *op) { opToStageAndCluster.erase(op); } int count(Operation *op) { return opToStageAndCluster.count(op); } diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index adf8131b9263..0e0bde48ef52 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -34,8 +34,12 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "TRITON_HIP_GLOBAL_PREFETCH", "TRITON_HIP_LOCAL_PREFETCH", "TRITON_HIP_USE_ASYNC_COPY", + "TRITON_HIP_ASYNC_COPY_BYPASS_PERMUTE", + "TRITON_HIP_ASYNC_COPY_OVERLAP", + "TRITON_HIP_ENABLE_F16_ASYNC_PINGPONG", "TRITON_HIP_USE_BLOCK_PINGPONG", "TRITON_HIP_USE_IN_THREAD_TRANSPOSE", + "TRITON_HIP_ASYNC_FAST_SWIZZLE", "TRITON_LLVM_DEBUG_ONLY", "TRITON_ENABLE_ASAN", "TRITON_OVERRIDE_ARCH", diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 7edc5c45aa30..bfade5b0d823 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -3,6 +3,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include 
"third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Utility.h" @@ -1043,6 +1044,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver) CastOpAxisInfoVisitor, CastOpAxisInfoVisitor, CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, CastOpAxisInfoVisitor>(); visitors.append(); visitors.append(); diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt index 693d222f2f39..dac700dc1335 100644 --- a/lib/Analysis/CMakeLists.txt +++ b/lib/Analysis/CMakeLists.txt @@ -17,4 +17,5 @@ add_triton_library(TritonAnalysis TritonIR TritonGPUIR TritonNvidiaGPUIR + TritonAMDGPUIR ) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 473f79170240..5f0368401a16 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -294,8 +294,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion b.shl(b.lshr(offset, b.i32_val(rshiftVal)), b.i32_val(lshiftVal)), offset); } - auto vecAddr = b.gep(sharedPtrTy, elemTy, smemBase, offset, - LLVM::GEPNoWrapFlags::inbounds); + auto vecAddr = b.gep(sharedPtrTy, elemTy, smemBase, offset); + vecAddr.setInbounds(true); return vecAddr; }; diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index efe00265eb11..94f26772bf09 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -277,7 +277,8 @@ Value getSmemVecAddr(const LinearLayout ®Layout, const SharedMemoryObject &smemObj, triton::gpu::MemDescType sharedTy, Type elemLlvmTy, Value regId, Value laneId, Value warpId, Value blockId, - Location loc, RewriterBase &rewriter) { + Location loc, RewriterBase &rewriter, + ArrayRef overwriteAllocSize) { auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); StringAttr kBlock = str_attr("block"); @@ -292,7 +293,8 @@ Value getSmemVecAddr(const LinearLayout ®Layout, auto smemBase = smemObj.getBase(); auto smemOffsets = smemObj.getOffsets(); - auto smemStrides = smemObj.getStrides(sharedTy, loc, rewriter); + auto smemStrides = + smemObj.getStrides(sharedTy, loc, rewriter, overwriteAllocSize); Value smemOffset; // When loading or storing to shared memory, we consider two cases for // performance reasons: @@ -398,8 +400,8 @@ Value getSmemVecAddr(const LinearLayout ®Layout, smemOffset = b.sub(smemOffset, baseToAllocBaseDist); } auto ptrTy = smemBase.getType(); - auto vecAddr = b.gep(ptrTy, elemLlvmTy, smemBase, smemOffset, - LLVM::GEPNoWrapFlags::inbounds); + auto vecAddr = b.gep(ptrTy, elemLlvmTy, smemBase, smemOffset); + vecAddr.setInbounds(true); return vecAddr; } @@ -409,7 +411,8 @@ bool emitTransferBetweenRegistersAndShared( LinearLayout ®Layout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy, std::optional maxVecElems, const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, - std::function perVectorCallback) { + std::function perVectorCallback, + bool forceLane0, ArrayRef overwriteAllocSize) { MLIRContext *ctx = rewriter.getContext(); auto b = TritonLLVMOpBuilder(loc, rewriter); @@ -452,6 +455,17 @@ bool emitTransferBetweenRegistersAndShared( auto withCTAOffset = triton::gpu::getNumCTAs(sharedTy.getEncoding()) > 1; auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + if (forceLane0) { + laneId = 
b.i32_val(0); + // NFC it's copied from getLaneAndWarpId but we add a shuffleIdx(0) to the + // tid so LLVM sees that warpId is a scalar + // This is not optimal as it adds a readlane which is not necessary but + // better than getting readfirstlanes for every direct-to-lds load + Value tid = target.shuffleIdx(rewriter, loc, getThreadId(rewriter, loc), 0); + int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(rewriter); + Value warpSizeVal = b.i32_val(threadsPerWarp); + warpId = b.udiv(tid, warpSizeVal); + } Value blockId = withCTAOffset ? target.getClusterCTAId(rewriter, loc) : b.i32_val(0); @@ -473,9 +487,10 @@ bool emitTransferBetweenRegistersAndShared( SmallVector ret; for (int i = 0; i < numElems / vecElems; i++) { auto regId = b.i32_val(i * vecElems); - auto vecAddr = getSmemVecAddr( - regLayout, regToSharedLayout, invertAllocSharedLayout, smemObj, - sharedTy, elemLlvmTy, regId, laneId, warpId, blockId, loc, rewriter); + auto vecAddr = + getSmemVecAddr(regLayout, regToSharedLayout, invertAllocSharedLayout, + smemObj, sharedTy, elemLlvmTy, regId, laneId, warpId, + blockId, loc, rewriter, overwriteAllocSize); perVectorCallback(vecTy, vecAddr); } return true; @@ -486,12 +501,13 @@ bool emitTransferBetweenRegistersAndShared( Type elemLlvmTy, std::optional maxVecElems, const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, - std::function perVectorCallback) { + std::function perVectorCallback, + bool forceLane0, ArrayRef overwriteAllocSize) { auto regLayout = triton::gpu::toLinearLayout(registerTy.getShape(), registerTy.getEncoding()); return emitTransferBetweenRegistersAndShared( regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter, - target, perVectorCallback); + target, perVectorCallback, forceLane0, overwriteAllocSize); } SmallVector loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp, @@ -502,11 +518,28 @@ SmallVector loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp, auto srcTy = localLoadOp.getSrc().getType(); auto dstTy = localLoadOp.getResult().getType(); + // We overwrite the alloc size if we are a subview to fix subviews in the + // fastest dim + SmallVector overwriteSmemAllocSize; + auto src = localLoadOp.getSrc(); + if (auto subView = src.getDefiningOp()) { + auto subViewSrcTy = + dyn_cast(subView.getSrc().getType()); + if (subViewSrcTy) { + auto origAllocSize = subViewSrcTy.getAllocShape(); + auto srcAllocSize = srcTy.getAllocShape(); + if (origAllocSize.size() == 3 && srcAllocSize.size() == 2) { + overwriteSmemAllocSize = to_vector(origAllocSize.drop_front()); + } + } + } + auto b = TritonLLVMOpBuilder(loc, rewriter); SmallVector ret; bool success = emitTransferBetweenRegistersAndShared( dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc, - rewriter, target, [&](VectorType vecTy, Value vecAddr) { + rewriter, target, + [&](VectorType vecTy, Value vecAddr) { auto vecVal = b.load(vecTy, vecAddr); target.localLoadOpAnnotation(localLoadOp, vecVal); vecVal.setAlignment(vecTy.getNumElements() * @@ -515,7 +548,8 @@ SmallVector loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp, for (int v = 0; v < vecTy.getNumElements(); v++) { ret.push_back(b.extract_element(elemLlvmTy, vecVal, b.i32_val(v))); } - }); + }, + false, overwriteSmemAllocSize); if (!success) llvm::report_fatal_error("Failed to emit transfer from shared to register"); diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp index 6fab87c8a562..22496cf67192 100644 
--- a/lib/Dialect/Triton/Transforms/Combine.cpp +++ b/lib/Dialect/Triton/Transforms/Combine.cpp @@ -239,6 +239,29 @@ class RankedReduceDescriptorLoads : public mlir::OpRewritePattern { } }; +class CombineDotScaledAddPattern : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::DotScaledOp dotOp, + mlir::PatternRewriter &rewriter) const override { + if (!dotOp->hasOneUse() || !isZero(dotOp.getC())) + return failure(); + auto user = dotOp->getUsers().begin(); + if (auto addOp = llvm::dyn_cast(*user)) { + auto acc = (addOp.getRhs() == dotOp) ? addOp.getLhs() : addOp.getRhs(); + IRMapping mapping; + mapping.map(dotOp.getC(), acc); + auto newOp = rewriter.clone(*dotOp, mapping); + rewriter.replaceOp(addOp, newOp->getResults()); + rewriter.eraseOp(dotOp); + return success(); + } + return failure(); + } +}; + } // anonymous namespace class CombineOpsPass : public impl::TritonCombineOpsBase { @@ -253,6 +276,8 @@ class CombineOpsPass : public impl::TritonCombineOpsBase { patterns.add(context); patterns.add(context); patterns.add(context); + + patterns.add(context); // %} patterns.add(context); patterns.add(context); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 216e4dc2efb1..f63fda14ecde 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1797,6 +1797,15 @@ SwizzledSharedEncodingAttr AMDMfmaEncodingAttr::composeSharedLayoutForOperand( int innerDimLength = operandShape[sharedOrder[0]]; int elemsPerOneBanksRow = (numBanks * bankBitWidth) / elemBitWidth; + // This is a hack optimization for the V tensor shared layout, which + // - is not kContig + // - local_load from the tensor will have kWidth=4 + // - ds_read_tr is used + // In this case, we can set vecSize to nonkDim of the mfma instruction + // to avoid read bank conflicts + if (!isKContig) + vectorSize = getMDim(); + int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); int maxPhase = std::max(std::min(simdWidth / perPhase, innerDimLength / vectorSize), 1u); @@ -1805,6 +1814,12 @@ SwizzledSharedEncodingAttr AMDMfmaEncodingAttr::composeSharedLayoutForOperand( if (getMDim() == 4) maxPhase = 4; + // Disable swizzling for scales + if (operandIdx >= 2) { + return SwizzledSharedEncodingAttr::get(getContext(), 1, 1, 1, sharedOrder, + ctaLayout); + } + return SwizzledSharedEncodingAttr::get(getContext(), vectorSize, perPhase, maxPhase, sharedOrder, ctaLayout); } diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index d1397549de27..43796875be51 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -1537,7 +1537,7 @@ std::optional chooseMfmaLikeStoreLayout(RankedTensorType valType) { auto mfmaLayout = cast(valType.getEncoding()); - // Currently support transposed [B]F16 MFMA32x32 on CDNA4 + // We currently only support transposed [B]F16 MFMA32x32 on CDNA4. 
bool isMfma32 = mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32; Type elemType = valType.getElementType(); if (!(valType.getRank() == 2 && (elemType.isF16() || elemType.isBF16()) && @@ -1545,32 +1545,27 @@ chooseMfmaLikeStoreLayout(RankedTensorType valType) { isMfma32)) return {}; - MLIRContext *ctx = mfmaLayout.getContext(); - StringAttr kRegister = S("register"); - StringAttr kLane = S("lane"); - StringAttr kWarp = S("warp"); - StringAttr kBlock = S("block"); - - SmallVector order = getDefaultMmaOrder(mfmaLayout); - auto standardOutDims = standardOutDimNames(ctx, 2); - // We make each thread handle 8 consecutive elements to enable 128-bit - // global stores for [b]f16 types and keep the thread pattern in each lane - // similar to the canonical mfmaLayout. - LinearLayout mfma8Layout = LinearLayout::empty(); - mfma8Layout = - LinearLayout({{kRegister, {{1, 0}, {2, 0}, {4, 0}}}, - {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}, - {kWarp, {}}, - {kBlock, {}}}, - {standardOutDims[order[0]], standardOutDims[order[1]]}); - - LinearLayout warpLayout = - identityStandardND(kWarp, mfmaLayout.getWarpsPerCTA(), order); - LinearLayout ctaLayout = mfma8Layout.transposeOuts(standardOutDims) * - warpLayout.transposeOuts(standardOutDims); - mfma8Layout = combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), - valType.getShape()); - return mfma8Layout; + auto valShape = valType.getShape(); + LinearLayout mfmaLL = mfmaLayout.toLinearLayout(valShape); + auto mfmaOutDims = llvm::to_vector(mfmaLL.getOutDimNames()); + StringAttr dimM = mfmaOutDims[0]; + StringAttr dimN = mfmaOutDims[1]; + + auto swapLL = LinearLayout::empty(); + // The rows are kept as is with an identity linear layout. + swapLL *= LinearLayout::identity1D(valShape[0], dimM, dimM); + // In transposed mfma32 layout, each thread holds 4 consecutive values along N + // dim. We want to exchange column 4-7 (owned by thread 32-63) and column 8-11 + // (owned by thread 0-31) every 16 columns to make each thread holds 8 + // elements. This would mean exchange the 2nd and 3rd basis vector from an + // identity linear layout. 
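Since every basis vector of the identity layout over the N dimension is a power of two, exchanging the 2nd and 3rd bases acts as a bit permutation on the column index. A rough Python sketch (illustrative only, not the LinearLayout API) of the resulting column mapping:

```python
# Swapping basis vectors 2 and 3 of an identity layout over N is the same as
# swapping bits 2 and 3 of the column index: columns 4-7 and 8-11 trade places
# inside every group of 16 columns, so each thread ends up owning 8
# consecutive columns instead of 4.
def swap_bases_2_and_3(col: int) -> int:
    bit2 = (col >> 2) & 1
    bit3 = (col >> 3) & 1
    col &= ~0b1100  # clear bits 2 and 3
    return col | (bit2 << 3) | (bit3 << 2)

print([swap_bases_2_and_3(c) for c in range(16)])
# [0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15]
```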
+ std::vector> dimNBases(mfmaLL.getOutDimSizeLog2(dimN)); + std::generate(dimNBases.begin(), dimNBases.end(), + [i = 0]() mutable { return std::vector{1 << i++}; }); + std::swap(dimNBases[2], dimNBases[3]); + swapLL *= LinearLayout({{dimN, dimNBases}}, {dimN}); + + return mfmaLL.compose(swapLL); } LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType, diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp index aafbd5e8e8ac..3227af618ff2 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp @@ -53,8 +53,17 @@ bool tt::CoarseSchedule::insertMinimum(Operation *op, int stage, bool tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage, tt::CoarseSchedule::Cluster cluster, bool includeArg, bool insertIfEarlier) { - auto tryInsert = [&](Operation *op, int stage, - tt::CoarseSchedule::Cluster cluster) { + auto func = [=](Operation *) { return std::make_pair(stage, cluster); }; + return insertDepsOfOp(op, includeArg, insertIfEarlier, func); +} + +bool tt::CoarseSchedule::insertDepsOfOp( + Operation *op, bool includeArg, bool insertIfEarlier, + llvm::function_ref(Operation *)> + getStageAndClusterForOp) { + auto tryInsert = [&insertIfEarlier, + this](Operation *op, int stage, + tt::CoarseSchedule::Cluster cluster) { if (!insertIfEarlier) return insertIfAbsent(op, stage, cluster); return insertMinimum(op, stage, cluster); @@ -78,9 +87,11 @@ bool tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage, } Operation *defOp = v.getDefiningOp(); if (defOp && defOp->getBlock() == op->getBlock()) { - if (tryInsert(defOp, stage, cluster)) { + auto [defStage, defCluster] = getStageAndClusterForOp(defOp); + if (tryInsert(defOp, defStage, defCluster)) { inserted = true; - insertDepsOfOp(defOp, stage, cluster, includeArg, insertIfEarlier); + insertDepsOfOp(defOp, includeArg, insertIfEarlier, + getStageAndClusterForOp); } } } diff --git a/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir b/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir index 04fc1397d626..2892fcec625c 100644 --- a/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir +++ b/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir @@ -145,10 +145,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr %arg2: !ttg.memdesc<64xf32, #shared, #smem, mutable>) { %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked> // The first constant 0 skips the LDS offset which is also 0 - // COMMON: llvm.getelementptr + // COMMON: rocdl.make.buffer.rsrc + // COMMON: llvm.select // COMMON: llvm.mlir.constant(0 : i32) : i32 // COMMON: %[[aux_ca:.*]] = llvm.mlir.constant(0 : i32) : i32 - // COMMON: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_ca]] + // COMMON: llvm.mlir.constant(0 : i32) : i32 + // COMMON-: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_ca]] %1 = amdgpu.buffer_load_to_local %arg0[%0] cacheModifier = ca into %arg2: [tensor<64xi32, #blocked>] -> <64xf32, #shared, #smem, mutable> // COMMON: llvm.getelementptr // COMMON: %[[aux_cg:.*]] = llvm.mlir.constant(3 : i32) : i32 diff --git a/test/Conversion/amd/invalid_concat_op.mlir b/test/Conversion/amd/invalid_concat_op.mlir new file mode 100644 index 000000000000..2b359dc059ed --- /dev/null +++ b/test/Conversion/amd/invalid_concat_op.mlir @@ -0,0 +1,174 @@ +// RUN: triton-opt -split-input-file %s 
--convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics + + +// Invalid ranks +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @invalid_concat( + %arg0: tensor<32x64xf32, #blocked>, + %arg1: tensor<32x64xf32, #blocked>, + %arg2: tensor<32x64xf32, #blocked>, + %arg3: tensor<32x64xf32, #blocked>, + %arg4: tensor<32x64xf32, #blocked>, + %arg5: tensor<32x64xf32, #blocked>, + %arg6: tensor<32x64xf32, #blocked>, + %arg7: tensor<32x64xf32, #blocked>) { + + // expected-error @+1 {{Source and destination tensors must have the same rank.}} + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7: + tensor<32x64xf32, #blocked>,tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked> -> tensor<256xf32, #blocked> + tt.return + } +} + +// ----- + +// Invalid shapes 1 +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @invalid_concat( + %arg0: tensor<32x64xf32, #blocked>, + %arg1: tensor<32x64xf32, #blocked>, + %arg2: tensor<32x64xf32, #blocked>, + %arg3: tensor<32x64xf32, #blocked>, + %arg4: tensor<32x64xf32, #blocked>, + %arg5: tensor<32x64xf32, #blocked>, + %arg6: tensor<32x64xf32, #blocked>, + %arg7: tensor<32x64xf32, #blocked>) { + + // expected-error @+1 {{Source and destination tensor shapes don't match.}} + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7: + tensor<32x64xf32, #blocked>,tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked> -> tensor<257x128xf32, #blocked> + tt.return + } +} + +// ----- + +// Invalid shapes 2 +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @invalid_concat( + %arg0: tensor<32x64xf32, #blocked>, + %arg1: tensor<32x64xf32, #blocked>, + %arg2: tensor<32x64xf32, #blocked>, + %arg3: tensor<32x64xf32, #blocked>, + %arg4: tensor<32x64xf32, #blocked>, + %arg5: tensor<32x64xf32, #blocked>, + %arg6: tensor<32x64xf32, #blocked>, + %arg7: tensor<32x64xf32, #blocked>) { + + // expected-error @+1 {{Number of source tiles (8) doesn't match required count (16).}} + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7: + tensor<32x64xf32, #blocked>,tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked> -> tensor<256x128xf32, #blocked> + tt.return + } +} + + +// ----- + +// Invalid shapes 3 +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +module 
attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @invalid_concat( + %arg0: tensor<32x64xf32, #blocked>, + %arg1: tensor<32x64xf32, #blocked>, + %arg2: tensor<32x64xf32, #blocked>, + %arg3: tensor<32x64xf32, #blocked>, + %arg4: tensor<32x64xf32, #blocked>, + %arg5: tensor<32x64xf32, #blocked>, + %arg6: tensor<32x64xf32, #blocked>, + %arg7: tensor<32x64xf32, #blocked>) { + + // expected-error @+1 {{CTA tile shapes must match between source and destination tensors.}} + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7: + tensor<32x64xf32, #blocked>,tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked> -> tensor<128x128xf32, #blocked1> + tt.return + } +} + +// ----- + +// Different types +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @invalid_concat( + %arg0: tensor<32x64xf32, #blocked1>, + %arg1: tensor<32x64xf32, #blocked>, + %arg2: tensor<32x64xf32, #blocked>, + %arg3: tensor<32x64xf32, #blocked>, + %arg4: tensor<32x64xf32, #blocked>, + %arg5: tensor<32x64xf32, #blocked>, + %arg6: tensor<32x64xf32, #blocked>, + %arg7: tensor<32x64xf32, #blocked>) { + + // expected-error @+1 {{All sources must have identical tensor types.}} + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7: + tensor<32x64xf32, #blocked1>,tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked> -> tensor<128x128xf32, #blocked> + tt.return + } +} + +// ----- + +// Invalid element types +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @invalid_concat( + %arg0: tensor<32x64xf32, #blocked>, + %arg1: tensor<32x64xf32, #blocked>, + %arg2: tensor<32x64xf32, #blocked>, + %arg3: tensor<32x64xf32, #blocked>, + %arg4: tensor<32x64xf32, #blocked>, + %arg5: tensor<32x64xf32, #blocked>, + %arg6: tensor<32x64xf32, #blocked>, + %arg7: tensor<32x64xf32, #blocked>) { + + // expected-error @+1 {{Element types of sources and destination must match.}} + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7: + tensor<32x64xf32, #blocked>,tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked> -> tensor<256x64xf16, #blocked> + tt.return + } +} + + +// ----- + +// Different layouts 1 +#src_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [64, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}> +#dst_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 
0], [16, 0], [0, 4], [0, 0]], warp=[[0, 32], [32, 0]], block=[]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @invalid_concat( + %arg0: tensor<128x128xf32, #src_layout>, + %arg1: tensor<128x128xf32, #src_layout>, + %arg2: tensor<128x128xf32, #src_layout>, + %arg3: tensor<128x128xf32, #src_layout>) { + + // expected-error @+1 {{Lane and warp dim basis must match between source and destination layout.}} + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3: + tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout> -> tensor<256x256xf32, #dst_layout> + tt.return + } +} + +// ----- + +// Different layouts 2 +#src_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [64, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}> +#dst_layout = #ttg.linear<{register=[[0, 0], [0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @invalid_concat( + %arg0: tensor<128x128xf32, #src_layout>, + %arg1: tensor<128x128xf32, #src_layout>, + %arg2: tensor<128x128xf32, #src_layout>, + %arg3: tensor<128x128xf32, #src_layout>) { + + // expected-error @+1 {{Register basis must match on a CTA tile between source and destination.}} + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3: + tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout> -> tensor<256x256xf32, #dst_layout> + tt.return + } +} diff --git a/test/Conversion/cvt_to_llvm.mlir b/test/Conversion/cvt_to_llvm.mlir index 5ec73e4c8a32..f577bc5af53e 100644 --- a/test/Conversion/cvt_to_llvm.mlir +++ b/test/Conversion/cvt_to_llvm.mlir @@ -48,7 +48,7 @@ tt.func private @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xi32, #bl // CHECK-DAG: [[X_MOD_2:%.*]] = and i32 [[TID]], 1 // CHECK-DAG: [[X_2_4_LOWER:%.*]] = shl {{.*}} i32 [[IS_UPPER_HALF]], 1 - // CHECK-DAG: [[X_2_4_UPPER0:%.*]] = shl {{.*}} i32 [[TID]], 1 + // CHECK-DAG: [[X_2_4_UPPER0:%.*]] = shl i32 [[TID]], 1 // CHECK-DAG: [[X_2_4_UPPER1:%.*]] = and i32 [[X_2_4_UPPER0]], 24 // CHECK-DAG: [[X_GE_16:%.*]] = and i32 [[TID]], 16 // CHECK-DAG: [[X_GE_16_2:%.*]] = lshr exact i32 [[X_GE_16]], 3 diff --git a/test/TritonGPU/amd/amd-concat-op.mlir b/test/TritonGPU/amd/amd-concat-op.mlir new file mode 100644 index 000000000000..715b32587bd2 --- /dev/null +++ b/test/TritonGPU/amd/amd-concat-op.mlir @@ -0,0 +1,105 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s + +// ----- + +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @concat_blocked( + %arg0: tensor<32x64xf32, #blocked1>, + %arg1: tensor<32x64xf32, #blocked1>, + %arg2: tensor<32x64xf32, #blocked1>, + %arg3: tensor<32x64xf32, #blocked1>, + %arg4: tensor<32x64xf32, #blocked1>, + %arg5: tensor<32x64xf32, #blocked1>, + %arg6: tensor<32x64xf32, #blocked1>, + %arg7: tensor<32x64xf32, #blocked1>) { + // CHECK: 
llvm.func @concat_blocked + + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg0[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg1[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg2[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg3[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg4[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg5[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg6[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg7[{{.*}}] : !llvm.struct + + // CHECK-COUNT-64: %{{[0-9]*}} = llvm.insertvalue %{{.*}} : !llvm.struct + + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7: + tensor<32x64xf32, #blocked1>,tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1> -> tensor<128x128xf32, #blocked1> + tt.return + } +} + +// ----- + +#src_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [64, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}> +#dst_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @concat_ll_2d_1( + %arg0: tensor<128x128xf32, #src_layout>, + %arg1: tensor<128x128xf32, #src_layout>, + %arg2: tensor<128x128xf32, #src_layout>, + %arg3: tensor<128x128xf32, #src_layout>){ + // CHECK: llvm.func @concat_ll_2d_1 + + // CHECK-COUNT-64: %{{.*}} = llvm.extractvalue %arg0[{{.*}}] : !llvm.struct + // CHECK-COUNT-64: %{{.*}} = llvm.extractvalue %arg1[{{.*}}] : !llvm.struct + // CHECK-COUNT-64: %{{.*}} = llvm.extractvalue %arg2[{{.*}}] : !llvm.struct + // CHECK-COUNT-64: %{{.*}} = llvm.extractvalue %arg3[{{.*}}] : !llvm.struct + // CHECK-COUNT-256: %{{.*}} = llvm.insertvalue %{{.*}} : !llvm.struct + + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3: + tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout> -> tensor<256x256xf32, #dst_layout> + tt.return + } +} + +// ----- + +#src_layout = #ttg.linear<{register=[[1, 0], [2, 0], [4, 0]], lane=[[0, 1], [0, 2], [0, 4], [0, 8], [8, 0], [16, 0]], warp=[[0, 16]], block=[]}> +#dst_layout = #ttg.linear<{register=[[1, 0], [2, 0], [4, 0], [32, 0], [0, 32]], lane=[[0, 1], [0, 2], [0, 4], [0, 8], [8, 0], [16, 0]], warp=[[0, 16]], block=[]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @concat_ll_2d_2( + %arg0: tensor<32x32xf32, #src_layout>, + %arg1: tensor<32x32xf32, #src_layout>, + %arg2: tensor<32x32xf32, #src_layout>, + %arg3: tensor<32x32xf32, #src_layout>){ + // CHECK: llvm.func @concat_ll_2d_2 + + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg0[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg1[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg2[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg3[{{.*}}] : !llvm.struct + // CHECK-COUNT-32: %{{.*}} = 
llvm.insertvalue %{{.*}} : !llvm.struct + + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3: + tensor<32x32xf32, #src_layout>, tensor<32x32xf32, #src_layout>, tensor<32x32xf32, #src_layout>, tensor<32x32xf32, #src_layout> -> tensor<64x64xf32, #dst_layout> + tt.return + } +} + +// ----- + +#src_layout = #ttg.linear<{register=[[1]], lane=[[2], [4], [8], [16], [32], [64]], warp=[[128]], block=[]}> +#dst_layout = #ttg.linear<{register=[[1], [256], [512]], lane=[[2], [4], [8], [16], [32], [64]], warp=[[128]], block=[]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @concat_ll_1d( + %arg0: tensor<256xf32, #src_layout>, + %arg1: tensor<256xf32, #src_layout>, + %arg2: tensor<256xf32, #src_layout>, + %arg3: tensor<256xf32, #src_layout>){ + // CHECK: llvm.func @concat_ll_1d + + // CHECK-COUNT-2: %{{.*}} = llvm.extractvalue %arg0[{{.*}}] : !llvm.struct + // CHECK-COUNT-2: %{{.*}} = llvm.extractvalue %arg1[{{.*}}] : !llvm.struct + // CHECK-COUNT-2: %{{.*}} = llvm.extractvalue %arg2[{{.*}}] : !llvm.struct + // CHECK-COUNT-2: %{{.*}} = llvm.extractvalue %arg3[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.insertvalue %{{.*}} : !llvm.struct + + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3: + tensor<256xf32, #src_layout>, tensor<256xf32, #src_layout>, tensor<256xf32, #src_layout>, tensor<256xf32, #src_layout> -> tensor<1024xf32, #dst_layout> + tt.return + } +} diff --git a/test/TritonGPU/amd/amd-optimize-epilogue.mlir b/test/TritonGPU/amd/amd-optimize-epilogue.mlir index 9c0d91881f2f..e84485fcd5bc 100644 --- a/test/TritonGPU/amd/amd-optimize-epilogue.mlir +++ b/test/TritonGPU/amd/amd-optimize-epilogue.mlir @@ -43,7 +43,7 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} // ----- // CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[32, 0], [64, 0]], block = []}> -// CHECK-LABEL: store_dword +// CHECK-LABEL: store_dword_128x128 // CHECK-NOT: ttg.convert_layout %{{.*}} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked> // CHECK-DAG: %[[PTR:.+]] = ttg.convert_layout %{{.*}} : tensor<128x128x!tt.ptr, #mma> -> tensor<128x128x!tt.ptr, #linear> // CHECK-DAG: %[[VAL:.+]] = ttg.convert_layout %{{.*}} : tensor<128x128xf16, #mma> -> tensor<128x128xf16, #linear> @@ -51,7 +51,7 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [0, 1]}> #mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - tt.func public @store_dword(%arg0: !tt.ptr) attributes {noinline = false} { + tt.func public @store_dword_128x128(%arg0: !tt.ptr) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> %cst_0 = arith.constant dense<1.230000e+02> : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %cst_1 = arith.constant dense<1.230000e+02> : tensor<128x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> @@ -63,3 +63,26 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} tt.return } } + +// ----- +// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], 
[0, 16], [0, 128], [64, 0], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[0, 32], [0, 64], [32, 0]], block = []}> +// CHECK-LABEL: store_dword_256x256 +// CHECK-NOT: ttg.convert_layout %{{.*}} : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> +// CHECK-DAG: %[[PTR:.+]] = ttg.convert_layout %{{.*}} : tensor<256x256x!tt.ptr, #mma> -> tensor<256x256x!tt.ptr, #linear> +// CHECK-DAG: %[[VAL:.+]] = ttg.convert_layout %{{.*}} : tensor<256x256xf16, #mma> -> tensor<256x256xf16, #linear> +// CHECK: tt.store %[[PTR]], %[[VAL]] : tensor<256x256x!tt.ptr, #linear> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}> +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @store_dword_256x256(%arg0: !tt.ptr) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %cst_1 = arith.constant dense<1.230000e+02> : tensor<256x256xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %0 = tt.dot %cst_0, %cst_1, %cst : tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<256x256xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %1 = ttg.convert_layout %0 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> + %2 = arith.truncf %1 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<256x256x!tt.ptr, #blocked> + tt.store %3, %2 : tensor<256x256x!tt.ptr, #blocked> + tt.return + } +} diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index cfdc5e9f0134..07f4dfd75ee7 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -236,7 +236,18 @@ def make_ttgir(mod, metadata, options): if options.schedule_hint == "local-prefetch": global_prefetch = local_prefetch = 1 + # passes.ttgpuir.add_pipeline(pm, options.num_stages, False) amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy) + + if False: + pm.run(mod) + with open("mid.mlir", "w") as f: + print(mod, file=f) + context = mod.context + mod = ir.parse_mlir_module("mod.mlir", context) + mod.context = context + pm = ir.pass_manager(mod.context) + if use_async_copy: amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch) passes.common.add_canonicalizer(pm) @@ -250,8 +261,9 @@ def make_ttgir(mod, metadata, options): passes.ttgpuir.add_remove_layout_conversions(pm) amd.passes.ttgpuir.add_reorder_instructions(pm) use_block_pingpong = is_pingpong_schedule_enabled(options.arch) - if use_block_pingpong and options.num_stages == 2: - amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages) + if use_block_pingpong and options.num_stages in [2, 4]: + amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages, use_async_copy) + passes.ttgpuir.add_remove_layout_conversions(pm) if knobs.amd.use_buffer_ops: amd.passes.ttgpuir.add_canonicalize_pointers(pm) @@ -264,6 +276,7 @@ def make_ttgir(mod, metadata, options): passes.common.add_symbol_dce(pm) if use_async_copy: amd.passes.ttgpuir.add_update_async_wait_count(pm, options.arch) + 
passes.ttgpuir.add_remove_layout_conversions(pm) pm.run(mod) return mod @@ -397,6 +410,11 @@ def make_amdgcn(src, metadata, options): if knobs.amd.dump_amdgcn: print("// -----// AMDGCN Dump //----- //") print(amdgcn) + # if amdgcn.find("_attn_fwd") + # with open("out.amdgcn", "r") as f: + # amdgcn = f.read() + with open("out.amdgcn", "w") as f: + f.write(amdgcn) return amdgcn @staticmethod diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index 17d9409468d8..b487c1402332 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -119,6 +119,75 @@ def ExtractSliceOp : TT_AMDGPU_Op<"extract_slice", [Pure]> { let hasVerifier = 1; } +def ConcatOp : TT_AMDGPU_Op<"concat", [Pure]> { + let summary = "concat operation"; + let description = [{ + The "concat" operation combines a list of source n-dimensional tensors into a single larger destination tensor. + + All source tensors must have the same shape, element type, and encoding. + The concatenation dimension is inferred from the source and destination shapes provided by the user. + For example, two tensors of shape 64x128 can produce a destination shape of 128x128, + indicating concatenation along dimension 0; or 64x256, indicating concatenation along dimension 1. + + Generally, source tensors passed as op arguments can be arranged into the resulting shape in multiple ways. + For example, given four tensors of shape 64x64: + concat s0<64x64>, s1<64x64>, s2<64x64>, s3<64x64> -> <128x128> + + They can be laid out in different configurations within the result tensor: + 1) s0 s1 2) s0 s2 + s2 s3 s1 s3 + + From a logical tensor perspective, the source tensors are treated as elements of a tensor of tensors. + In other words, the 1-D array of input tensors is conceptually reshaped into an n-D grid. + The semantics of this op assume a row-major order (or its n-D generalization), + meaning the fastest-varying dimension is filled first, and the slowest-varying dimension is filled last. + In the example above, this corresponds to layout 1). + + The source and destination tensors must have identical linear layouts at the CTA tile level. + That is, all base vectors for input dimensions must match, except for the register input dimension. + The register basis must align on the subset that defines the logical tensor shape of a single CTA tile. + + This ensures that the concatenation is a no-op, meaning no data rearrangement among threads is required + to assemble the destination tensor with the given shape and layout. + However, the order of CTA tiles within the layout does not need to match between source and destination layouts. + It is the responsibility of the op's lowering logic to handle this correctly. + + This op is designed to work on logical tensors directly, avoiding the need for complex layout reinterpretation or reshaping. + For example, the `tt.join` operation only supports concatenation along the innermost dimension, + and requires that the resulting innermost dimension provide 2 elements per thread, distributed across registers. + In contrast, this `concat` op imposes no constraints on the concatenation dimension or the size of dimensions. + + * sources: a list of the input tensors. 
+ + Example 1: + + ```mlir + #blocked = #ttg.blocked<{sizePerThread = [1, 8], + threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> + %0 = amdgpu.concat %arg0, %arg1: tensor<32x64xf32, #blocked>,tensor<32x64xf32, #blocked>, + -> tensor<64x64xf32, #blocked> + ``` + + Example 2: + ```mlir + #src_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [64, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}> + #dst_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}> + %0 = amdgpu.concat %arg0, %arg1, %arg2, %arg3 : tensor<128x128xf16, #src_layout>, tensor<128x128xf16, #src_layout>, tensor<128x128xf16, #src_layout>, + tensor<128x128xf16, #src_layout> -> tensor<256x256xf16, #dst_layout> + ``` + + }]; + + let arguments = (ins Variadic:$sources); + let results = (outs AnyRankedTensor:$result); + + let assemblyFormat = [{ + $sources attr-dict `:` type($sources) `->` type($result) + }]; + + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // InstructionSchedHint //===----------------------------------------------------------------------===// diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/AsyncUtility.h b/third_party/amd/include/TritonAMDGPUToLLVM/AsyncUtility.h new file mode 100644 index 000000000000..82063e528ce6 --- /dev/null +++ b/third_party/amd/include/TritonAMDGPUToLLVM/AsyncUtility.h @@ -0,0 +1,12 @@ +#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_ASYNCUTILITY_H_ +#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_ASYNCUTILITY_H_ + +#include "mlir/IR/Value.h" + +namespace mlir::triton::AMD { +// Traverses the def-chain including control flow of the token and returns true +// if all defining operations are an AsyncWait +bool comesFromAsyncWait(mlir::Value value); +} // namespace mlir::triton::AMD + +#endif diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h b/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h index 6763de2eba22..724849f01bbf 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h @@ -11,6 +11,9 @@ void populateExtractSliceOpToLLVMPatterns( void populateInThreadTransposeOpToTTGPatterns(mlir::RewritePatternSet &patterns, mlir::PatternBenefit benefit); +void populateConcatOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, + mlir::RewritePatternSet &patterns, + mlir::PatternBenefit benefit); } // namespace mlir::triton::AMD diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h index fccb65d061ab..9d48e1ffe208 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h @@ -34,7 +34,8 @@ std::unique_ptr createTritonAMDGPUConvertToBufferOpsPass( std::string archGenName = std::string()); std::unique_ptr -createTritonAMDGPUBlockPingpongPass(int32_t numStages = 2); +createTritonAMDGPUBlockPingpongPass(int32_t numStages = 2, + bool useAsyncCopy = false); std::unique_ptr createTritonAMDGPUInThreadTransposePass(); diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index 91bd40000222..9f9bf9cf7b0e 100644 --- 
a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -168,11 +168,12 @@ def TritonAMDGPUBlockPingpong: Pass<"tritonamdgpu-block-pingpong", "mlir::Module let dependentDialects = ["mlir::ROCDL::ROCDLDialect, mlir::triton::amdgpu::TritonAMDGPUDialect"]; - let options = [ - Option<"numStages", "num-stages", - "int32_t", /*default*/"2", - "Number of Pipeline stages">, - ]; + let options = + [Option<"numStages", "num-stages", "int32_t", /*default*/ "2", + "Number of Pipeline stages">, + Option<"useAsyncCopy", "use_async_copy", "bool", /*default*/ "false", + "Use AsyncCopyGlobalToLocal to directly load to shared memory">, + ]; } def TritonAMDGPUInThreadTranspose: Pass<"tritonamdgpu-in-thread-transpose", "mlir::triton::FuncOp"> { diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 7543805fc084..586ebfda9dc1 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -302,4 +302,96 @@ InThreadTransposeOp::deduceOutputLayout(ArrayRef shape, return transposedLL; } +LogicalResult ConcatOp::verify() { + auto sources = getSources(); + auto result = getResult(); + + auto srcType = cast(sources.front().getType()); + auto dstType = cast(result.getType()); + + auto srcShape = srcType.getShape(); + auto dstShape = dstType.getShape(); + unsigned rank = srcShape.size(); + + // 1) Shape related checks. + if (rank != dstShape.size()) + return emitError() + << "Source and destination tensors must have the same rank."; + + unsigned numTiles = 1; + for (int i = 0; i < rank; ++i) { + if (dstShape[i] % srcShape[i] != 0) + return emitError() << "Source and destination tensor shapes don't match."; + numTiles *= dstShape[i] / srcShape[i]; + } + + if (numTiles != sources.size()) + return emitError() << "Number of source tiles (" << sources.size() + << ") doesn't match required count (" << numTiles + << ")."; + + // 2) Check that all sources have same type and element type match. + for (auto src : sources) { + auto curr = dyn_cast(src.getType()); + if (curr != srcType) + return emitError() << "All sources must have identical tensor types."; + } + + if (dstType.getElementType() != srcType.getElementType()) + return emitError() + << "Element types of sources and destination must match."; + + // 3) Verify that source and destination layout match on a CTA tile. 
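The check implemented below can be summarized in a few lines of Python; this is a hedged sketch with illustrative inputs, not the verifier itself. Lane and warp bases must match exactly, while register bases only need to agree on the prefix that addresses elements inside a single CTA tile:

```python
import math

def layouts_match_on_cta_tile(src_bases, dst_bases, elems_per_thread_per_tile):
    # Lane and warp bases must be identical between source and destination.
    if src_bases["lane"] != dst_bases["lane"] or src_bases["warp"] != dst_bases["warp"]:
        return False
    # Register bases only have to agree on the first log2(elements per thread
    # per CTA tile) entries; later bases merely select which CTA tile is used.
    n = int(math.log2(elems_per_thread_per_tile))
    return src_bases["register"][:n] == dst_bases["register"][:n]

src = {"register": [[0, 1], [0, 2], [0, 8]], "lane": [[1, 0], [2, 0]], "warp": [[0, 4]]}
dst = {"register": [[0, 1], [0, 2], [0, 8], [16, 0]], "lane": [[1, 0], [2, 0]], "warp": [[0, 4]]}
print(layouts_match_on_cta_tile(src, dst, 8))  # True: only a tile-selecting basis differs
```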
+ auto srcLL = triton::gpu::toLinearLayout(srcShape, srcType.getEncoding()); + auto dstLL = triton::gpu::toLinearLayout(dstShape, dstType.getEncoding()); + + auto getBases = [&](StringRef name) { + auto key = StringAttr::get(getContext(), name); + return std::pair{srcLL.getBases().lookup(key), + dstLL.getBases().lookup(key)}; + }; + + auto [regSrc, regDst] = getBases("register"); + auto [laneSrc, laneDst] = getBases("lane"); + auto [warpSrc, warpDst] = getBases("warp"); + + auto shapeCTASrc = mlir::triton::gpu::getShapePerCTATile(srcType); + auto shapeCTADst = mlir::triton::gpu::getShapePerCTATile(dstType); + if (shapeCTASrc != shapeCTADst) + return emitError() << "CTA tile shapes must match between source and " + "destination tensors."; + + unsigned numCTAs = 1; + for (int d = 0; d < rank; ++d) + numCTAs *= srcShape[d] / shapeCTASrc[d]; + unsigned elemsPerThread = + triton::gpu::getTotalElemsPerThread(srcType) / numCTAs; + unsigned regCompareLen = std::log2(elemsPerThread); + + auto compareBasis = [&](auto &srcBasis, auto &dstBasis, StringRef message, + int limit = -1) { + int n = (limit < 0 ? srcBasis.size() + : std::min(srcBasis.size(), limit)); + for (size_t i = 0; i < n; ++i) { + if (srcBasis[i] != dstBasis[i]) { + emitError() << message; + return false; + } + } + return true; + }; + + if (laneSrc != laneDst || warpSrc != warpDst) { + return emitError() << "Lane and warp dim basis must match between source " + "and destination layout."; + } + + if (!compareBasis(regSrc, regDst, + "Register basis must match on a CTA tile between source " + "and destination.", + regCompareLen)) + return failure(); + + return success(); +} } // namespace mlir::triton::amdgpu diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt index 693bd41bc55a..35310b86eecd 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt @@ -2,6 +2,7 @@ add_triton_library(TritonAMDGPUDialectToLLVM TritonAMDGPUToLLVMPatterns.cpp ExtractSliceOpToLLVM.cpp InThreadTransposeOpToTTG.cpp + ConcatOpToLLVM.cpp DEPENDS TritonAMDGPUIR diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ConcatOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ConcatOpToLLVM.cpp new file mode 100644 index 000000000000..9d75b3b7d204 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ConcatOpToLLVM.cpp @@ -0,0 +1,171 @@ +#include "Dialect/TritonAMDGPU/IR/Dialect.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +template +std::vector multiDimElementwise(ArrayRef lhs, ArrayRef rhs, + BinaryOp op) { + assert(lhs.size() == rhs.size() && "Input dimensions must match"); + std::vector result; + result.reserve(lhs.size()); + for (size_t i = 0, n = lhs.size(); i < n; ++i) { + unsigned a = static_cast(lhs[i]); + unsigned b = static_cast(rhs[i]); + result.push_back(op(a, b)); + } + return result; +} + +template unsigned getNumElements(const ArrayRef shape) { + return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>()); +} + +// Determine the order in which CTA tiles are laid out across the tensor. +// That is, create vector of dimensions from fastest to slowest varying. 
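As a rough Python sketch of the idea described above (hypothetical names, not the pass code): the register bases past the per-tile registers each point at the first element of another CTA tile, and the dimension of their first nonzero entry gives the next fastest-varying tile dimension.

```python
import math

def cta_tile_order(register_bases, regs_per_thread_per_tile, default_order):
    start = int(math.log2(regs_per_thread_per_tile))
    order = []
    for basis in register_bases[start:]:
        for dim, value in enumerate(basis):
            if value != 0:
                if dim not in order:
                    order.append(dim)
                break
    for dim in default_order:  # append any dimension not selected above
        if dim not in order:
            order.append(dim)
    return order

# Registers 0-1 stay inside one CTA tile; the next bases step along dim 1 and
# then dim 0, so tiles vary fastest along dimension 1.
print(cta_tile_order([[0, 1], [1, 0], [0, 32], [32, 0]], 4, [1, 0]))  # [1, 0]
```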
+SmallVector getCTATileOrder(MLIRContext *ctx, + const LinearLayout &layout) { + auto llEnc = triton::gpu::LinearEncodingAttr::get(ctx, layout); + auto regDim = StringAttr::get(ctx, "register"); + auto &bases = layout.getBases().find(regDim)->second; + + // Compute number of CTA tiles in a layout. + unsigned totalElems = layout.getTotalOutDimSize(); + auto ctaShape = llEnc.getShapePerCTATile(); + unsigned elemsPerCTA = + std::accumulate(ctaShape.begin(), ctaShape.end(), 1, std::multiplies<>()); + assert((totalElems % elemsPerCTA) == 0 && + "Total elements must be divisible by elemsPerCTA"); + unsigned numCTAs = totalElems / elemsPerCTA; + + // To determine the CTA tile order, start by identifying the register basis + // vector that corresponds to the first element of the second CTA tile. The + // nonzero index in the logical tensor it maps to indicates the fastest + // varying dimension. Then, for each subsequent basis register (first element + // of some CTA tile), extract the next nonzero index to build the full + // dimension order. + unsigned registersPerThreadPerCTA = + product(llEnc.basesPerDim(regDim, /*skipBroadcast=*/false)) / numCTAs; + unsigned startIndex = + static_cast(std::log2(registersPerThreadPerCTA)); + + llvm::SmallSetVector order; + for (unsigned i = startIndex; i < bases.size(); ++i) { + auto range = llvm::make_range(bases[i].begin(), bases[i].end()); + auto it = llvm::find_if(range, [](unsigned v) { return v != 0; }); + if (it != bases[i].end()) + order.insert(std::distance(bases[i].begin(), it)); + } + + // Append any dims missing from our default order. + for (unsigned dim : llEnc.getOrder()) + order.insert(dim); + + return order.takeVector(); +} + +struct ConcatOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(amdgpu::ConcatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + RankedTensorType resultType = + cast(op.getResult().getType()); + + ArrayRef dstShape = resultType.getShape(); + Attribute dstEncoding = resultType.getEncoding(); + + Value srcVal = op.getSources()[0]; + RankedTensorType srcType = cast(srcVal.getType()); + ArrayRef srcShape = srcType.getShape(); + Attribute srcEncoding = srcType.getEncoding(); + + MLIRContext *context = resultType.getContext(); + auto linearLayoutSrc = triton::gpu::toLinearLayout(srcShape, srcEncoding); + auto linearLayoutDst = triton::gpu::toLinearLayout(dstShape, dstEncoding); + auto srcCTAOrder = getCTATileOrder(context, linearLayoutSrc); + auto dstCTAOrder = getCTATileOrder(context, linearLayoutSrc); + + auto rank = srcShape.size(); + auto shapePerCTATile = triton::gpu::getShapePerCTATile(resultType); + auto sources = adaptor.getSources(); + + unsigned totalElems = ::getNumElements(dstShape); + unsigned elemsPerTile = ::getNumElements(shapePerCTATile); + unsigned numCTATiles = totalElems / elemsPerTile; + + // Default order is fastest to slowest varying dimension. 
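The gather loop below boils down to per-tile index arithmetic. A simplified Python model, assuming row-major tile order everywhere (the pass derives the actual orders from the layouts):

```python
def delinearize(idx, shape):  # shape is (slowest, ..., fastest)
    coords = []
    for dim in reversed(shape):
        coords.append(idx % dim)
        idx //= dim
    return list(reversed(coords))

def linearize(coords, shape):
    idx = 0
    for coord, dim in zip(coords, shape):
        idx = idx * dim + coord
    return idx

def concat_per_thread(unpacked_sources, dst_cta_shape, src_cta_shape, elems_per_tile):
    srcs_per_dim = [d // s for d, s in zip(dst_cta_shape, src_cta_shape)]
    num_tiles = 1
    for d in dst_cta_shape:
        num_tiles *= d
    result = []
    for i in range(num_tiles):
        tile = delinearize(i, dst_cta_shape)
        # Which source operand owns this tile, and which tile inside it is it?
        src_idx = linearize([t // s for t, s in zip(tile, src_cta_shape)], srcs_per_dim)
        tile_in_src = linearize([t % s for t, s in zip(tile, src_cta_shape)], src_cta_shape)
        start = tile_in_src * elems_per_tile
        result += unpacked_sources[src_idx][start:start + elems_per_tile]
    return result

# Four 1-tile sources arranged 2x2 into the destination, 2 values per thread per tile.
srcs = [[0, 1], [10, 11], [20, 21], [30, 31]]
print(concat_per_thread(srcs, [2, 2], [1, 1], 2))  # [0, 1, 10, 11, 20, 21, 30, 31]
```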
+ std::vector defaultOrder(rank); + std::iota(defaultOrder.rbegin(), defaultOrder.rend(), 0); + + auto dstCTAShape = multiDimElementwise( + dstShape, shapePerCTATile, std::divides()); + auto srcCTAShape = multiDimElementwise( + srcShape, shapePerCTATile, std::divides()); + auto srcToDstShape = multiDimElementwise( + dstShape, srcShape, std::divides()); + + unsigned elemsPerThreadPerCTA = + triton::gpu::getTotalElemsPerThread(srcType) / + ::getNumElements(srcCTAShape); + + llvm::SmallVector resultVals; + llvm::SmallVector> unpackedSources; + unpackedSources.reserve(sources.size()); + + for (size_t i = 0; i < sources.size(); i++) { + Value currSrc = sources[i]; + unpackedSources.push_back(unpackLLElements(loc, currSrc, rewriter)); + } + + // Traverse CTA tiles in the result tensor + for (int i = 0; i < numCTATiles; ++i) { + auto currTileIdx = mlir::LLVM::delinearize(i, dstCTAShape, dstCTAOrder); + // The n-dim destination tensor is built by arranging n-dim source tensors + // into a destination tensor shape. Determine which source tensor contains + // the current CTA tile. + auto multiDimSrcIdx = multiDimElementwise( + currTileIdx, srcCTAShape, std::divides()); + // Compute linear index of the current source tensor. + // Concat operands are laid out in the destination tensor + // in fastest slowest varying dimension order. + auto linearSrcIdx = + mlir::LLVM::linearize(multiDimSrcIdx, srcToDstShape, defaultOrder); + + // After determining which source tensor the current CTA tile belongs to, + // compute the index of this CTA tile within that source tensor, + // considering the source tensors may include CTA tiles. + auto multiDimSrcCTAIdx = multiDimElementwise( + currTileIdx, srcCTAShape, std::modulus()); + auto linearSrcCTAIdx = + mlir::LLVM::linearize(multiDimSrcCTAIdx, srcCTAShape, srcCTAOrder); + auto unpackedElements = unpackedSources[linearSrcIdx]; + + auto startIt = + unpackedElements.begin() + linearSrcCTAIdx * elemsPerThreadPerCTA; + auto endIt = startIt + elemsPerThreadPerCTA; + llvm::append_range(resultVals, llvm::make_range(startIt, endIt)); + } + + Value packedResult = packLLElements(loc, this->getTypeConverter(), + resultVals, rewriter, resultType); + + rewriter.replaceOp(op, packedResult); + return success(); + } +}; +} // namespace + +namespace mlir::triton::AMD { +void populateConcatOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, + mlir::RewritePatternSet &patterns, + mlir::PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} +} // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp index 07cf91870fed..ed915577bf85 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp @@ -1,3 +1,4 @@ +#include "../TritonAMDGPUToLLVM/Utility.h" #include "Dialect/TritonAMDGPU/IR/Dialect.h" #include "TritonAMDGPUToLLVM/GCNAsmFormat.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" @@ -49,6 +50,7 @@ using namespace mlir::triton; // clang-format on namespace { + struct ExtractSliceOpConversion : public ConvertOpToLLVMPattern { explicit ExtractSliceOpConversion(LLVMTypeConverter &typeConverter, @@ -60,61 +62,61 @@ struct ExtractSliceOpConversion ConversionPatternRewriter &rewriter) const { Location loc = op->getLoc(); auto srcTy = cast(op.getSource().getType()); - auto srcLayout = srcTy.getEncoding(); + auto dstTy = cast(op.getType()); 
auto srcShape = srcTy.getShape(); - auto resultTy = cast(op.getType()); - auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter); - auto elemsPerThread = triton::gpu::getElemsPerThread(srcTy); - auto contigPerThread = triton::gpu::getContigPerThread(srcTy); - auto totalContigPerThread = product(contigPerThread); - auto order = triton::gpu::getOrder(srcTy); + auto dstShape = dstTy.getShape(); - // Calculate valid total number of workers in each dimension + auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter); auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcTy); - shapePerCTATile[0] = - std::min(static_cast(srcShape[0]), shapePerCTATile[0]); - shapePerCTATile[1] = - std::min(static_cast(srcShape[1]), shapePerCTATile[1]); - - // Rank == 2 checked in the verifier - SmallVector sizes; - for (auto i = 0; i < 2; ++i) { - sizes.push_back(resultTy.getDimSize(i)); - } + auto srcCTAShape = LLVM::AMD::multiDimElementwise( + srcShape, shapePerCTATile, std::divides()); + auto dstCTAShape = LLVM::AMD::multiDimElementwise( + dstShape, shapePerCTATile, std::divides()); + auto numCTATiles = std::accumulate(dstCTAShape.begin(), dstCTAShape.end(), + 1, std::multiplies<>()); auto offsets = op.getStaticOffsets(); + auto firstTileCoordinate = + LLVM::AMD::multiDimElementwise( + offsets, shapePerCTATile, std::divides()); - // Calculate offsets and sizes in terms of CTA units. - std::array CTAOffsets{offsets[0] / shapePerCTATile[0], - offsets[1] / shapePerCTATile[1]}; - std::array CTASizes{sizes[0] / shapePerCTATile[0], - sizes[1] / shapePerCTATile[1]}; - std::array CTAPerShape{srcShape[0] / shapePerCTATile[0], - srcShape[1] / shapePerCTATile[1]}; - - // The diagram above illustrates the graphical representation of the - // skipElems, tensorStride, and lastIdx variables. - auto skipElems = CTAOffsets[order[1]] * (elemsPerThread[order[0]] * - contigPerThread[order[1]]) + - CTAOffsets[order[0]] * totalContigPerThread; - auto tensorStride = - (CTAPerShape[order[0]] - CTASizes[order[0]]) * totalContigPerThread; - auto lastIdx = - (CTAOffsets[order[1]] + CTASizes[order[1]] - 1) * - elemsPerThread[order[0]] * contigPerThread[order[1]] + - (CTAOffsets[order[0]] + CTASizes[order[0]]) * totalContigPerThread; - - assert(lastIdx <= vals.size()); + Attribute srcEncoding = srcTy.getEncoding(); + Attribute dstEncoding = dstTy.getEncoding(); + auto linearLayoutSrc = triton::gpu::toLinearLayout(srcShape, srcEncoding); + auto linearLayoutDst = triton::gpu::toLinearLayout(dstShape, dstEncoding); + auto srcCTAOrder = + LLVM::AMD::getCTATileOrder(srcTy.getContext(), linearLayoutSrc); + auto dstCTAOrder = + LLVM::AMD::getCTATileOrder(srcTy.getContext(), linearLayoutDst); + + unsigned elemsPerThreadPerCTA = + triton::gpu::getTotalElemsPerThread(srcTy) / + std::accumulate(srcCTAShape.begin(), srcCTAShape.end(), 1, + std::multiplies<>()); + + // 1. Process CTA tiles in the destination tensor according to the + // destination's linear layout order of CTA tiles. + // 2. For each tile position in the destination tensor, compute its + // corresponding position in the source tensor. + // 3. Copy the values from the source tile to the destination slice. 
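The three steps above amount to a tile-wise gather of the already unpacked per-thread values. A rough Python model under the assumption of row-major tile order (the lowering itself derives the order from the linear layouts):

```python
def extract_slice_per_thread(vals, src_cta_shape, dst_cta_shape,
                             first_tile_coord, elems_per_tile):
    rows, cols = dst_cta_shape  # 2-D for illustration
    result = []
    for r in range(rows):
        for c in range(cols):
            # Shift the destination tile coordinate by the slice offset to get
            # the corresponding tile in the source tensor.
            src_r = r + first_tile_coord[0]
            src_c = c + first_tile_coord[1]
            src_tile = src_r * src_cta_shape[1] + src_c
            start = src_tile * elems_per_tile
            result += vals[start:start + elems_per_tile]
    return result

# A 2x2 grid of CTA tiles with 2 values per thread each; slice out the
# bottom-right tile (tile offset [1, 1]).
vals = [0, 1, 10, 11, 20, 21, 30, 31]
print(extract_slice_per_thread(vals, (2, 2), (1, 1), (1, 1), 2))  # [30, 31]
```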
SmallVector resultVals; - for (int i = skipElems; i < lastIdx; i += tensorStride) { - for (int j = 0; j < totalContigPerThread * CTASizes[order[0]]; ++j, ++i) { - assert(i < lastIdx); - resultVals.push_back(vals[i]); + for (size_t i = 0; i < numCTATiles; i++) { + auto coordInDstTensor = + mlir::LLVM::delinearize(i, dstCTAShape, dstCTAOrder); + auto coordInSrcTensor = + LLVM::AMD::multiDimElementwise( + coordInDstTensor, firstTileCoordinate, std::plus()); + auto linearIdxInSrcTensor = + mlir::LLVM::linearize(coordInSrcTensor, srcCTAShape, srcCTAOrder); + + for (size_t j = 0; j < elemsPerThreadPerCTA; j++) { + resultVals.push_back( + vals[linearIdxInSrcTensor * elemsPerThreadPerCTA + j]); } } Value ret = packLLElements(loc, this->getTypeConverter(), resultVals, - rewriter, resultTy); + rewriter, dstTy); rewriter.replaceOp(op, ret); return success(); @@ -124,11 +126,7 @@ struct ExtractSliceOpConversion matchAndRewrite(amdgpu::ExtractSliceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcTy = op.getSource().getType(); - if (isa( - op.getSource().getType().getEncoding())) { - return processLayout(op, adaptor, rewriter); - } - return failure(); + return processLayout(op, adaptor, rewriter); } }; } // namespace diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp index a84d84b2819d..c0cf0fb5fefa 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp @@ -7,5 +7,6 @@ void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, PatternBenefit benefit) { populateExtractSliceOpToLLVMPatterns(typeConverter, patterns, benefit); populateInThreadTransposeOpToTTGPatterns(patterns, benefit); + populateConcatOpToLLVMPatterns(typeConverter, patterns, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp new file mode 100644 index 000000000000..484422021894 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp @@ -0,0 +1,62 @@ +#include "third_party/amd/include/TritonAMDGPUToLLVM/AsyncUtility.h" +#include "Dialect/TritonAMDGPU/IR/Dialect.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/IR/Operation.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir::triton::AMD { + +// Traverses the def-chain including control flow of the token and returns true +// if all defining operations are an AsyncWait +bool comesFromAsyncWait(mlir::Value token) { + if (auto defOp = token.getDefiningOp()) { + if (isa(defOp)) + return true; + else if (auto castOp = dyn_cast(defOp)) + return comesFromAsyncWait(castOp.getInputs()[0]); + else + return false; + } + + auto blockArg = llvm::dyn_cast(token); + // If the token has no defining op and is not an BlockArgument bail out + if (!blockArg) { + return false; + } + + auto block = blockArg.getOwner(); + auto argId = blockArg.getArgNumber(); + + auto destOperandFromAsyncWait = [argId](auto &&operands) { + assert(argId < operands.size()); + return comesFromAsyncWait(operands[argId]); + }; + + // Check all predecessor block's terminator and follow the passed value at + // argId to see if they are immediately an AsyncWait. 
+ for (auto *pred : block->getPredecessors()) { + auto terminator = pred->getTerminator(); + if (auto br = llvm::dyn_cast(terminator)) { + if (!destOperandFromAsyncWait(br.getDestOperands())) + return false; + } else if (auto condBr = llvm::dyn_cast(terminator)) { + if (condBr.getTrueDest() == block) { + if (!destOperandFromAsyncWait(condBr.getTrueDestOperands())) + return false; + } + if (condBr.getFalseDest() == block) { + if (!destOperandFromAsyncWait(condBr.getFalseDestOperands())) + return false; + } + } else if (auto br = llvm::dyn_cast(terminator)) { + if (!destOperandFromAsyncWait(br.getDestOperands())) + return false; + } else { + llvm::dbgs() << "no terminator!" << *terminator << "\n"; + return false; + } + } + return true; +} + +} // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp index 7fe495ff3dd5..9cec2cd8b51d 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp @@ -204,7 +204,7 @@ struct ConvertBuiltinFuncToLLVM ModuleOp mod = getOperation(); GreedyRewriteConfig config; - config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Aggressive); + config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; RewritePatternSet patterns(context); patterns.add(context, this->ftz); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt index 2842cc76bfc4..4a3db488a21f 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt @@ -1,4 +1,5 @@ add_triton_library(TritonAMDGPUToLLVM + AsyncUtility.cpp AtomicRMWOpsEmitter.cpp BufferOpsEmitter.cpp ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 7af92231b4a1..1d2e2e039491 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -189,21 +189,10 @@ static bool matchMFMAAndLinearLayoutCase(RankedTensorType srcTy, if (!mfmaLayout || !linearLayout) return false; - std::optional srcLL = + std::optional storeLL = mlir::triton::gpu::chooseMfmaLikeStoreLayout(srcTy); - if (!srcLL) - return false; - - MLIRContext *ctx = linearLayout.getContext(); - StringAttr kLane = StringAttr::get(ctx, "lane"); - StringAttr kRegister = StringAttr::get(ctx, "register"); - auto srcBase = srcLL.value().getBases(); - auto srcReg = srcBase.lookup(kRegister); - auto srcLane = srcBase.lookup(kLane); - auto dstBases = linearLayout.getLinearLayout().getBases(); - auto dstReg = dstBases.lookup(kRegister); - auto dstLane = dstBases.lookup(kLane); - return dstReg == srcReg && dstLane == srcLane; + return linearLayout.getLinearLayout() == + storeLL.value_or(LinearLayout::empty()); }; struct ConvertLayoutOpMFMAToLinearConversion diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index fa9c1f48d72a..96e0cc00c629 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -14,6 +14,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include 
"triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" using namespace mlir; @@ -238,11 +239,119 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase { } } - // Emits the computation to get the lane index which holds the source + SmallVector emitSharedBaseAddr(RewriterBase &rewriter, Operation *op, + RankedTensorType srcTy, + MemDescType dstTy, bool hasSwizzling, + Type resElemTy, Value llDst, + VectorType &vecTy) const { + auto emitSharedAddresses = [&](MemDescType dstTy, + SmallVector &shmemAddrs, + VectorType &vecTy, bool forceLane0) { + auto loc = op->getLoc(); + auto smemObj = mlir::LLVM::getSharedMemoryObjectFromStruct( + loc, llDst, resElemTy, rewriter); + bool ok = emitTransferBetweenRegistersAndShared( + srcTy, dstTy, resElemTy, {}, smemObj, loc, rewriter, targetInfo, + [&](VectorType vecTy_, Value shmemAddr) { + vecTy = vecTy_; + shmemAddrs.push_back(shmemAddr); + }, + forceLane0); + assert(ok); + }; + + if (hasSwizzling) { + // Rewrite dstTy to be coalesced + auto dstEnc = cast(dstTy.getEncoding()); + auto flatSharedEnc = SwizzledSharedEncodingAttr::get( + op->getContext(), dstEnc.getVec(), 1, 1, dstEnc.getOrder(), + dstEnc.getCTALayout()); + dstTy = MemDescType::get(dstTy.getShape(), dstTy.getElementType(), + flatSharedEnc, dstTy.getMemorySpace()); + } + SmallVector ldsAddrs; + emitSharedAddresses(dstTy, ldsAddrs, vecTy, true); + return ldsAddrs; + } + + SmallVector emitSwizzleOffsets(Operation *op, RewriterBase &rewriter, + RankedTensorType srcTy, + MemDescType dstTy, VectorType vecTy, + int numberOfLoads) const { + auto loc = op->getLoc(); + TritonLLVMOpBuilder b(loc, rewriter); + + // Compute swizzle offsets + auto regLayout = + triton::gpu::toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + auto shape = dstTy.getShape(); + LinearLayout sharedLayout = + triton::gpu::toLinearLayout(shape, dstTy.getEncoding()); + LinearLayout regToSharedLayout = regLayout.invertAndCompose(sharedLayout); + + auto dstEnc = cast(dstTy.getEncoding()); + auto flatSharedEnc = SwizzledSharedEncodingAttr::get( + srcTy.getContext(), dstEnc.getVec(), 1, 1, dstEnc.getOrder(), + dstEnc.getCTALayout()); + auto flatDst = MemDescType::get(dstTy.getShape(), dstTy.getElementType(), + flatSharedEnc, dstTy.getMemorySpace()); + + auto regToSharedFlat = regLayout.invertAndCompose( + triton::gpu::toLinearLayout(shape, flatDst.getEncoding())); + // llvm::outs() << "Flat: " << regToSharedFlat << "\n"; + + MLIRContext *ctx = rewriter.getContext(); + StringAttr kBlock = str_attr("block"); + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + SmallVector swizzleOffsets; + for (int i = 0; i < numberOfLoads; i++) { + auto regId = b.i32_val(i * vecTy.getNumElements()); + + // for (int l = 0; l < 64; l++) { + // SmallVector> input = { + // {kRegister, i * vecTy.getNumElements()}, + // {kLane, l}, + // {kWarp, 0}, + // {kBlock, 0}}; + + // auto swizzOff = regToSharedLayout.apply(input)[0].second; + // auto flatOff = regToSharedFlat.apply(input)[0].second; + + // auto laneOff = (swizzOff - flatOff) / vecTy.getNumElements(); + + // llvm::outs() << l << ": " << swizzOff << ", " << flatOff << " = " + // << laneOff << "\n"; + // } + + auto swizzleOffset = + llvm::to_vector(llvm::drop_end(llvm::make_second_range( + applyLinearLayout(loc, rewriter, regToSharedLayout, + 
{{kRegister, regId}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, b.i32_val(0)}}))))[0]; + auto flatOffset = llvm::to_vector(llvm::drop_end(llvm::make_second_range( + applyLinearLayout(loc, rewriter, regToSharedFlat, + {{kRegister, regId}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, b.i32_val(0)}}))))[0]; + auto laneOffet = b.sdiv(b.sub(swizzleOffset, flatOffset), + b.i32_val(vecTy.getNumElements())); + swizzleOffsets.push_back(laneOffet); + } + + return swizzleOffsets; + } + + // Emits the computation to get the lane id offset which holds the source // pointers/offsets we need to store to shared memory - Value emitSwizzledLaneIndex(RewriterBase &rewriter, TritonLLVMOpBuilder &b, - Location loc, Value coalescedShmem, - Value swizzledShmem, Value vecBytes) const { + Value emitSwizzledLaneOffset(RewriterBase &rewriter, TritonLLVMOpBuilder &b, + Location loc, Value coalescedShmem, + Value swizzledShmem, Value vecBytes) const { // Compute the laneOffset based on the difference in elements between // the two shmem addresses. laneOffset will be negative for half the // lanes because a smaller laneId might hold our global_ptr. @@ -250,9 +359,7 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase { auto swizzledAddr = b.ptrtoint(i64_ty, swizzledShmem); auto diff = b.trunc(i32_ty, b.sub(swizzledAddr, coalescedAddr)); Value laneOffset = b.sdiv(diff, vecBytes); - // laneId + laneOffset will always stay inside the warp [0, - // threadsPerWarp) because we only swizzle inside a warp - return b.add(getLaneId(rewriter, loc), laneOffset); + return laneOffset; } // Swizzle the mask (1bit) based on selectLane via ballot @@ -525,6 +632,11 @@ struct BufferLoadToLocalOpConversion llDst, coalescedShmemAddr, swizzledShmemAddr, vecTy); assert(vecTy.getNumElements() == vec); + auto ldsBaseAddresses = emitSharedBaseAddr( + rewriter, op, ptrType, dstTy, hasSwizzling, resElemTy, llDst, vecTy); + auto swizzleOffsets = emitSwizzleOffsets(op, rewriter, ptrType, dstTy, + vecTy, ldsBaseAddresses.size()); + int vecBytes = (vecTy.getNumElements() * vecTy.getElementTypeBitWidth()) / 8; assert(llvm::isPowerOf2_32(vecBytes)); @@ -534,18 +646,38 @@ struct BufferLoadToLocalOpConversion // based on the collected shared addresses and vector size Value rsrcDesc = bufferEmitter.createResourceDescriptor(llPtr, llStride); - for (int i = 0; i < coalescedShmemAddr.size(); i++) { + bool useFastSwizzling = tools::getBoolEnv("TRITON_HIP_ASYNC_FAST_SWIZZLE"); + + for (int i = 0; i < ldsBaseAddresses.size(); i++) { auto srcIdx = i * vec; auto offsetIn = offsetElems[srcIdx]; + auto ldsDst = + useFastSwizzling ? ldsBaseAddresses[i] : coalescedShmemAddr[i]; + Value pred = mask ? 
maskElems[srcIdx] : b.true_val(); if (hasSwizzling) { // Apply swizzling to the src offsets - Value swizzledLaneId = - emitSwizzledLaneIndex(rewriter, b, loc, coalescedShmemAddr[i], - swizzledShmemAddr[i], vecBytesVal); - offsetIn = - targetInfo.shuffleIdx(rewriter, loc, offsetIn, swizzledLaneId); + Value laneOffset = + emitSwizzledLaneOffset(rewriter, b, loc, coalescedShmemAddr[i], + swizzledShmemAddr[i], vecBytesVal); + + if (useFastSwizzling) { + laneOffset = swizzleOffsets[i]; + } + + // laneId + laneOffset will always stay inside the warp [0, + // threadsPerWarp) because we only swizzle inside a warp + Value swizzledLaneId = b.add(getLaneId(rewriter, loc), laneOffset); + + if (tools::getBoolEnv("TRITON_HIP_ASYNC_COPY_BYPASS_PERMUTE")) { + offsetIn = b.add( + offsetIn, b.mul(laneOffset, b.i32_val(vecTy.getNumElements()))); + } else { + offsetIn = + targetInfo.shuffleIdx(rewriter, loc, offsetIn, swizzledLaneId); + } + if (mask) { pred = shuffleMask(rewriter, b, loc, targetInfo, swizzledLaneId, pred); @@ -553,8 +685,7 @@ struct BufferLoadToLocalOpConversion } auto bufferLoadToLds = bufferEmitter.emitLoadToLds( - vecTy, vecBytesVal, rsrcDesc, offsetIn, coalescedShmemAddr[i], pred, - op.getCache()); + vecTy, vecBytesVal, rsrcDesc, offsetIn, ldsDst, pred, op.getCache()); LLVM::AMD::addAsyncCopyAliasScope(bufferLoadToLds); if (!otherElems.empty()) { Value storeVal = packElementRangeIntoVector( @@ -666,9 +797,12 @@ struct AsyncCopyGlobalToLocalOpConversion if (hasSwizzling) { // Apply swizzling to the src pointers - Value swizzledLaneId = - emitSwizzledLaneIndex(rewriter, b, loc, coalescedShmemAddr[i], - swizzledShmemAddr[i], vecBytesVal); + Value laneOffset = + emitSwizzledLaneOffset(rewriter, b, loc, coalescedShmemAddr[i], + swizzledShmemAddr[i], vecBytesVal); + // laneId + laneOffset will always stay inside the warp [0, + // threadsPerWarp) because we only swizzle inside a warp + Value swizzledLaneId = b.add(getLaneId(rewriter, loc), laneOffset); srcPtr = targetInfo.shuffleIdx(rewriter, loc, srcPtr, swizzledLaneId); if (!maskElements.empty()) { pred = diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/MembarUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/MembarUtility.cpp index 32c9f4c4c730..26673d320a33 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/MembarUtility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/MembarUtility.cpp @@ -1,55 +1,12 @@ #include "third_party/amd/include/TritonAMDGPUToLLVM/MembarUtility.h" #include "Dialect/TritonAMDGPU/IR/Dialect.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "third_party/amd/include/TritonAMDGPUToLLVM/AsyncUtility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" namespace mlir::triton::AMD { namespace { -// Traverses the def-chain including control flow of the token and returns true -// if all defining operations are an AsyncWait -bool comesFromAsyncWait(Value token) { - if (auto defOp = token.getDefiningOp()) { - return isa(defOp); - } - - auto blockArg = dyn_cast(token); - // If the token has no defining op and is not an BlockArgument bail out - if (!blockArg) { - return false; - } - - auto block = blockArg.getOwner(); - auto argId = blockArg.getArgNumber(); - - auto destOperandFromAsyncWait = [argId](auto &&operands) { - assert(argId < operands.size()); - return comesFromAsyncWait(operands[argId]); - }; - - // Check all predecessor block's terminator and follow the passed value at - // argId to see if they are immediately an AsyncWait. 
- for (auto *pred : block->getPredecessors()) { - auto terminator = pred->getTerminator(); - if (auto br = dyn_cast(terminator)) { - if (!destOperandFromAsyncWait(br.getDestOperands())) - return false; - } else if (auto condBr = dyn_cast(terminator)) { - if (condBr.getTrueDest() == block) { - if (!destOperandFromAsyncWait(condBr.getTrueDestOperands())) - return false; - } - if (condBr.getFalseDest() == block) { - if (!destOperandFromAsyncWait(condBr.getFalseDestOperands())) - return false; - } - } else { - return false; - } - } - return true; -} - // Returns true if one of the operands is a LocalLoad synced via AsyncWait. bool filterAsyncLocalLoadsDeppendencies(Operation *op1, Operation *op2) { auto isAsyncLoad = [](Operation *op) { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index 1647f4b0680a..02014f732838 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -11,7 +11,6 @@ #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" -#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -210,16 +209,9 @@ struct ConvertTritonAMDGPUToLLVM mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns); - FailureOr maybeChipset = - mlir::amdgpu::Chipset::parse(this->arch); - if (failed(maybeChipset)) { - emitError(UnknownLoc::get(&getContext()), - "Invalid AMDGPU chipset name: " + this->arch); - return signalPassFailure(); - } // Native lowering patterns - mlir::populateGpuToROCDLConversionPatterns( - typeConverter, patterns, mlir::gpu::amd::HIP, *maybeChipset); + mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns, + mlir::gpu::amd::HIP); mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 8ec4ff2de468..a78a10bf4f59 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -4,12 +4,14 @@ #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/IR/PatternMatch.h" +#include "third_party/amd/include/TritonAMDGPUToLLVM/AsyncUtility.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" namespace tt = mlir::triton; using mlir::triton::ModuleAxisInfoAnalysis; +using mlir::triton::AMD::comesFromAsyncWait; using mlir::triton::AMD::DppCtrl; using mlir::triton::AMD::ISAFamily; using mlir::triton::gpu::appendOrGetExternFuncOp; @@ -734,8 +736,9 @@ void addAsyncCopyAliasScope(AliasAnalysisOpInterface directToLdsOp) { void addLocalLoadNoAliasScope(triton::gpu::LocalLoadOp localLoadOp, AliasAnalysisOpInterface llLoadOp) { auto token = localLoadOp.getToken(); - if (!token || !token.getDefiningOp()) + if (!token || !comesFromAsyncWait(token)) { return; + } return addLocalLoadNoAliasScope(llLoadOp); } @@ -752,4 +755,43 @@ void addLocalLoadNoAliasScope(AliasAnalysisOpInterface llLoadOp) { llLoadOp.setAliasScopes(aliasScopes); } +SmallVector 
getCTATileOrder(MLIRContext *ctx, + const LinearLayout &layout) { + auto llEnc = triton::gpu::LinearEncodingAttr::get(ctx, layout); + auto regDim = StringAttr::get(ctx, "register"); + auto &bases = layout.getBases().find(regDim)->second; + + // Compute number of CTA tiles in a layout. + unsigned totalElems = layout.getTotalOutDimSize(); + auto ctaShape = llEnc.getShapePerCTATile(); + unsigned elemsPerCTA = + std::accumulate(ctaShape.begin(), ctaShape.end(), 1, std::multiplies<>()); + assert((totalElems % elemsPerCTA) == 0 && + "Total elements must be divisible by elemsPerCTA"); + unsigned numCTAs = totalElems / elemsPerCTA; + + // To determine the CTA tile order, start by identifying the register basis + // vector that corresponds to the first element of the second CTA tile. The + // nonzero index in the logical tensor it maps to indicates the most minor + // dimension. Then, for each subsequent basis register (first element of + // some CTA tile), extract the next nonzero index to build the full dimension + // order. + unsigned totalPerThread = + product(llEnc.basesPerDim(regDim, /*skipBroadcast=*/false)) / numCTAs; + unsigned startIndex = static_cast(std::log2(totalPerThread)); + + llvm::SmallSetVector order; + for (unsigned i = startIndex; i < bases.size(); ++i) { + auto it = std::find_if(bases[i].begin(), bases[i].end(), + [](unsigned v) { return v != 0; }); + if (it != bases[i].end()) + order.insert(std::distance(bases[i].begin(), it)); + } + + // Append any dims missing from our default order. + for (unsigned dim : llEnc.getOrder()) + order.insert(dim); + + return SmallVector(order.begin(), order.end()); +} } // namespace mlir::LLVM::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h index dda259360c61..f2d1e62f97c6 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -137,6 +137,23 @@ void addLocalLoadNoAliasScope(AliasAnalysisOpInterface llLoadOp); // Attaches the "AsyncCopies" alias scope to llLoadDirectToLdsOp void addAsyncCopyAliasScope(AliasAnalysisOpInterface llLoadDirectToLdsOp); +// Determine the order in which CTA tiles are laid out across the tensor. 
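A worked example of the order-recovery logic in `getCTATileOrder` above, using a hand-written (hypothetical) register-basis table rather than a real Triton linear layout: the bases below `startIndex` address elements inside one CTA tile, and each later basis reveals the next tile dimension through its first nonzero output. Deduplication via `SmallSetVector` and the fallback to the default order are omitted for brevity.

```cpp
// Toy illustration with assumed data, not a real layout: recover the CTA
// tile order from a register-basis table the way getCTATileOrder does.
#include <cstdio>
#include <vector>

int main() {
  // Each basis maps one register bit to a logical (dim0, dim1) offset.
  // Bases 0-1 stay inside one CTA tile (4 elements per thread), so
  // startIndex = log2(4) = 2; bases 2-3 step to neighbouring CTA tiles.
  std::vector<std::vector<unsigned>> bases = {
      {0, 1}, {0, 2}, // within-tile bases
      {32, 0},        // first off-tile basis: nonzero in dim0
      {0, 64},        // next off-tile basis: nonzero in dim1
  };
  unsigned startIndex = 2;

  std::vector<unsigned> order;
  for (unsigned i = startIndex; i < bases.size(); ++i)
    for (unsigned d = 0; d < bases[i].size(); ++d)
      if (bases[i][d] != 0) {
        order.push_back(d);
        break;
      }

  // Prints "CTA tile order: 0 1" -- dim0 varies fastest across tiles.
  std::printf("CTA tile order:");
  for (unsigned d : order)
    std::printf(" %u", d);
  std::printf("\n");
  return 0;
}
```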
+SmallVector getCTATileOrder(MLIRContext *ctx, + const LinearLayout &layout); + +template +std::vector multiDimElementwise(const ArrayRef &lhs, + const ArrayRef &rhs, BinaryOp op) { + assert(lhs.size() == rhs.size() && "Input dimensions must match"); + std::vector result; + result.reserve(lhs.size()); + for (size_t i = 0, n = lhs.size(); i < n; ++i) { + unsigned a = static_cast(lhs[i]); + unsigned b = static_cast(rhs[i]); + result.push_back(op(a, b)); + } + return result; +} } // namespace mlir::LLVM::AMD #endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_UTILITY_H_ diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp index 7be03c4e6fda..163080e46be8 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp @@ -7,6 +7,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" @@ -23,6 +24,25 @@ namespace tt = mlir::triton; namespace { +template std::optional getSingleUserOf(Value val) { + auto users = llvm::to_vector(val.getUsers()); + if (users.size() == 1) { + auto targetOp = dyn_cast(users[0]); + if (targetOp != nullptr) + return targetOp; + } + return std::nullopt; +} + +template +std::optional getIndex(T consumer, Value target) { + auto it = llvm::find_if(consumer->getOperands(), + [&](Value v) { return v == target; }); + if (it == consumer->getOperands().end()) + return std::nullopt; + return std::distance(consumer->getOperands().begin(), it); +} + // This pass transforms a for-loop calculating a GEMM. Main purpose of the // transform is improve the efficiency of the GPU dot instruction (mfma) // by interleaving the execution of two warps on each SIMD. 
Especially it groups @@ -37,12 +57,20 @@ class Pingponger { SmallVector gLoadOps; SmallVector lLoadOps; SmallVector lStoreOps; + SmallVector asyncCopyOps; + SmallVector asyncWaitOps; + SmallVector asyncCommitOps; + DenseSet preservedAsyncCommits; + DenseMap> newAsyncGroups; + DenseMap> asyncTokenReassociation; SmallVector dotOps; + SmallVector dotSOps; SmallVector> subViewOps; SmallVector> loadSliceOps; SmallVector dotSliceOps; SmallVector constOffsets; Operation *lastInsertedOp; + const static inline std::string sliceAttrName = "sliceIdx"; // rocdl.s.setprio will be mapped to `s_setprio` instruction which set the // priority of the warp within a SIMD, determines which warp to occupy the @@ -56,10 +84,13 @@ class Pingponger { int32_t kWidth; int32_t numWarps; int32_t numStages; + bool useAsyncCopy; public: - Pingponger(scf::ForOp forOp, int32_t numWarps, int32_t numStages) - : forOp(forOp), numWarps(numWarps), numStages(numStages) {} + Pingponger(scf::ForOp forOp, int32_t numWarps, int32_t numStages, + bool useAsyncCopy) + : forOp(forOp), numWarps(numWarps), numStages(numStages), + useAsyncCopy(useAsyncCopy) {} void getDotPingponged(); private: @@ -67,12 +98,26 @@ class Pingponger { int64_t sliceWidth); LogicalResult genLocalSlice(OpBuilder &builder, Value v, Attribute dotEncoding, unsigned opIdx, - unsigned numSlices, int64_t sliceWidth); + unsigned numSlices, int64_t sliceWidth, + bool needCopySliced); + LogicalResult genLocalSliceScales(OpBuilder &builder, Value v, + Attribute dotEncoding, unsigned opIdx, + unsigned numSlices, int64_t sliceWidth, + bool needCopySliced); LogicalResult sliceDot(OpBuilder &builder, Location loc, tt::DotOp op, unsigned numSlices); + LogicalResult sliceDotScaled(OpBuilder &builder, Location loc, + tt::DotScaledOp op, unsigned numSlices); + LogicalResult genAsyncCopySlices(OpBuilder &builder); + LogicalResult updateForOpSignature(OpBuilder &builder); + LogicalResult adjustRefinedAsyncTokens(OpBuilder &builder); + void transformOnePPClusters(OpBuilder &builder, Location loc); LogicalResult transformFourPPClusters(OpBuilder &builder, Location loc); LogicalResult transformTwoPPClusters(OpBuilder &builder, Location loc); + LogicalResult transformFAv3(OpBuilder &builder, Location loc); + LogicalResult transformFP4(OpBuilder &builder, Location loc); + LogicalResult transformFP4s(OpBuilder &builder, Location loc); void addAsymmetricSyncToLoop(OpBuilder &builder, Location loc); void updateOpInsertion(Operation *Op); void appendOp(Operation *Op); @@ -92,8 +137,28 @@ class Pingponger { DenseSet &dotGlobalLoads, DenseSet &dotLocalLoads, DenseSet &dotLocalStores); + LogicalResult pruneDotMemoryOps(DenseSet &dotGlobalLoads, + DenseSet &dotLocalLoads, + DenseSet &dotLocalStores, + bool assumeNotTaken); + void determineDotAsyncMemoryOps( + tt::DotOp dotOp, + DenseSet &dotAsyncGlobalLoads, + DenseSet &dotLocalLoads, + DenseSet &dotAsyncCommitGroups, + DenseSet &dotAsyncWaits); + LogicalResult pruneDotAsyncMemoryOps( + DenseSet &dotGlobalLoads, + DenseSet &dotLocalLoads, + DenseSet &dotAsyncCommitGroups, + DenseSet &dotAsyncWaits, bool assumeNotTaken); template void findClosestPredOps(Value v, DenseSet &matchingOps); + + LogicalResult genLocalSliceHelper(OpBuilder &builder, Value v, unsigned opIdx, + unsigned numSlices, int64_t sliceWidth, + RankedTensorType tensorType, + bool needCopySliced); }; void Pingponger::updateOpInsertion(Operation *op) { lastInsertedOp = op; } @@ -358,6 +423,208 @@ void Pingponger::determineDotMemoryOps( findClosestPredOps(localStore.getSrc(), 
dotGlobalLoads); } +// Populate the dotAsyncGlobalLoads, dotLocalLoads, dotAsyncCommitGroups, and +// dotAsyncWaits set with any loads that are generated by the current dot +// product. This occurs in steps to: +// 1. Determine which loads are generated by the dot product via getA() +// and getB(). +// 2. Determine which asyncCopyGlobalToLcals are used to populate the +// inputs to the local loads. +// 3. Determine which async commit are using asyncCopyGlobalToLcals. +// 4. Determine which async waits are consuming async commits +// Note: This function currently depends on num_stages=2, which is a +// precondition for the pingpong scheduling. +void Pingponger::determineDotAsyncMemoryOps( + tt::DotOp dotOp, + DenseSet &dotAsyncGlobalLoads, + DenseSet &dotLocalLoads, + DenseSet &dotAsyncCommitGroups, + DenseSet &dotAsyncWaits) { + // Find the locals loads used to compute the dot inputs. These + // must come before the dot op. + findClosestPredOps(dotOp.getA(), dotLocalLoads); + findClosestPredOps(dotOp.getB(), dotLocalLoads); + + // Determine the local stores from the local loads. + // With pipelining we expect this to be a single local + // store within the loop based on a block argument after routing through + // a ttg.MemDescSubviewOp. + DenseSet subviews; + for (auto &&localLoad : dotLocalLoads) + findClosestPredOps(localLoad.getSrc(), subviews); + + for (auto &&subview : subviews) { + for (auto &&user : subview->getUsers()) { + if (auto globalLoad = dyn_cast(user)) { + if (!globalLoad->hasOneUse()) + continue; + auto asyncCommitGroup = + dyn_cast(*globalLoad->getUsers().begin()); + if (!asyncCommitGroup) + continue; + + dotAsyncGlobalLoads.insert(globalLoad); + dotAsyncCommitGroups.insert(asyncCommitGroup); + } + } + } + + // Looks for AsyncWaitOp, which after StreamPipeliner should be + // located/consumed by the iter arg which represent the AsyncCommits. + for (auto &&asyncCommitGroup : dotAsyncCommitGroups) { + if (!asyncCommitGroup->hasOneUse()) + return; + auto asyncWaitOp = + dyn_cast(*asyncCommitGroup->getUsers().begin()); + if (!asyncWaitOp) + return; + dotAsyncWaits.insert(asyncWaitOp); + } +} + +LogicalResult +Pingponger::pruneDotMemoryOps(DenseSet &dotGlobalLoads, + DenseSet &dotLocalLoads, + DenseSet &dotLocalStores, + bool assumeNotTaken) { + // Prune Memory operations that may be moved to only those involved in dot + // computation. To understand the "cluster assumptions" we also estimate + // the impact of any additional loads/stores. + auto gLoadIt = std::stable_partition( + gLoadOps.begin(), gLoadOps.end(), + [&dotGlobalLoads](tt::LoadOp op) { return dotGlobalLoads.contains(op); }); + auto lLoadIt = std::stable_partition(lLoadOps.begin(), lLoadOps.end(), + [&dotLocalLoads](ttg::LocalLoadOp op) { + return dotLocalLoads.contains(op); + }); + auto lStoreIt = + std::stable_partition(lStoreOps.begin(), lStoreOps.end(), + [&dotLocalStores](ttg::LocalStoreOp op) { + return dotLocalStores.contains(op); + }); + + if (estimateNonDotMemoryImpact(gLoadIt, gLoadOps.end(), + assumeNotTaken) != 0) { + std::stringstream message; + message << "Unable to match ping pong scheduling pattern. Details: " + << "Non-dot global loads found in non-persistent GEMM"; + LDBG(message.str()); + return failure(); + } + if (estimateNonDotMemoryImpact(lLoadIt, lLoadOps.end(), + assumeNotTaken) != 0) { + std::stringstream message; + message << "Unable to match ping pong scheduling pattern. 
Details: " + << "Non-dot local loads found in non-persistent GEMM"; + LDBG(message.str()); + return failure(); + } + if (estimateNonDotMemoryImpact(lStoreIt, lStoreOps.end(), + assumeNotTaken) != 0) { + std::stringstream message; + message << "Unable to match ping pong scheduling pattern. Details: " + << "Non-dot local stores found in non-persistent GEMM"; + LDBG(message.str()); + return failure(); + } + + // Remove non-dot memory operations. + gLoadOps.erase(gLoadIt, gLoadOps.end()); + lLoadOps.erase(lLoadIt, lLoadOps.end()); + lStoreOps.erase(lStoreIt, lStoreOps.end()); + // All PingPong Scheduler assumes there are 2 movable global loads and 2 + // movable local loads. + if (gLoadOps.size() != 2 || lLoadOps.size() != 2) { + std::stringstream message; + message << "Unable to match ping pong slicing pattern. Details: " + << gLoadOps.size() << " global loads in dot computation, " + << lLoadOps.size() << " local loads in dot computation"; + LDBG(message.str()); + return failure(); + } + return success(); +} + +LogicalResult Pingponger::pruneDotAsyncMemoryOps( + DenseSet &dotGlobalLoads, + DenseSet &dotLocalLoads, + DenseSet &dotAsyncCommitGroups, + DenseSet &dotAsyncWaits, bool assumeNotTaken) { + // Prune Memory operations that may be moved to only those involved in dot + // computation. To understand the "cluster assumptions" we also estimate + // the impact of any additional loads/stores. + auto asyncCopyIt = std::stable_partition( + asyncCopyOps.begin(), asyncCopyOps.end(), + [&dotGlobalLoads](ttg::AsyncCopyGlobalToLocalOp op) { + return dotGlobalLoads.contains(op); + }); + auto lLoadIt = std::stable_partition(lLoadOps.begin(), lLoadOps.end(), + [&dotLocalLoads](ttg::LocalLoadOp op) { + return dotLocalLoads.contains(op); + }); + auto asyncCommitIt = std::stable_partition( + asyncCommitOps.begin(), asyncCommitOps.end(), + [&dotAsyncCommitGroups](ttg::AsyncCommitGroupOp op) { + return dotAsyncCommitGroups.contains(op); + }); + auto asyncWaitIt = + std::stable_partition(asyncWaitOps.begin(), asyncWaitOps.end(), + [&dotAsyncWaits](ttg::AsyncWaitOp op) { + return dotAsyncWaits.contains(op); + }); + + if (estimateNonDotMemoryImpact( + asyncCopyIt, asyncCopyOps.end(), assumeNotTaken) != 0) { + std::stringstream message; + message << "Unable to match ping pong scheduling pattern. Details: " + << "Non-dot global loads found in non-persistent GEMM"; + LDBG(message.str()); + return failure(); + } + if (estimateNonDotMemoryImpact(lLoadIt, lLoadOps.end(), + assumeNotTaken) != 0) { + std::stringstream message; + message << "Unable to match ping pong scheduling pattern. Details: " + << "Non-dot local loads found in non-persistent GEMM"; + LDBG(message.str()); + return failure(); + } + if (estimateNonDotMemoryImpact( + asyncCommitIt, asyncCommitOps.end(), assumeNotTaken) != 0) { + std::stringstream message; + message << "Unable to match ping pong scheduling pattern. Details: " + << "Non-dot local stores found in non-persistent GEMM"; + LDBG(message.str()); + return failure(); + } + if (estimateNonDotMemoryImpact( + asyncWaitIt, asyncWaitOps.end(), assumeNotTaken) != 0) { + std::stringstream message; + message << "Unable to match ping pong scheduling pattern. Details: " + << "Non-dot local stores found in non-persistent GEMM"; + LDBG(message.str()); + return failure(); + } + + // Remove non-dot memory operations. 
+ asyncCopyOps.erase(asyncCopyIt, asyncCopyOps.end()); + lLoadOps.erase(lLoadIt, lLoadOps.end()); + asyncCommitOps.erase(asyncCommitIt, asyncCommitOps.end()); + asyncWaitOps.erase(asyncWaitIt, asyncWaitOps.end()); + // All PingPong Scheduler assumes there are 2 movable global loads and 2 + // movable local loads. + if (asyncCopyOps.size() != 2 || lLoadOps.size() != 2 || + asyncWaitOps.size() != 2) { + std::stringstream message; + message << "Unable to match ping pong slicing pattern. Details: " + << asyncCopyOps.size() << " global loads in dot computation, " + << lLoadOps.size() << " local loads in dot computation"; + LDBG(message.str()); + return failure(); + } + return success(); +} + // Transform a loop into one Dot - Memory (ping - pong) clusters // Each cluster, especially the Dot cluster is guarded with setprio(1->0) so // each warp can complete the execution of the cluster without being @@ -407,12 +674,11 @@ void Pingponger::genOffsetConstants(Location loc, OpBuilder &builder, // generates ops when succeed, return fail() otherwise. LogicalResult Pingponger::genLocalSlice(OpBuilder &builder, Value v, Attribute dotEncoding, unsigned opIdx, - unsigned numSlices, - int64_t sliceWidth) { - SmallVector slices; - SmallVector subviews; + unsigned numSlices, int64_t sliceWidth, + bool needCopySliced) { // TODO: support transformed input to dot auto localLoad = v.getDefiningOp(); + if (!localLoad) return failure(); auto memDesc = localLoad.getSrc(); @@ -426,24 +692,114 @@ LogicalResult Pingponger::genLocalSlice(OpBuilder &builder, Value v, return failure(); auto dotOperandEnc = ttg::DotOperandEncodingAttr::get( builder.getContext(), opIdx, dotEncoding, kWidth); + + auto tensorType = RankedTensorType::get(shape, elementType, dotOperandEnc); + + return genLocalSliceHelper(builder, v, opIdx, numSlices, sliceWidth, + tensorType, needCopySliced); +} + +LogicalResult Pingponger::genLocalSliceScales( + OpBuilder &builder, Value v, Attribute dotEncoding, unsigned opIdx, + unsigned numSlices, int64_t sliceWidth, bool needCopySliced) { + auto localLoad = v.getDefiningOp(); + if (!localLoad) + return failure(); + auto memDesc = localLoad.getSrc(); + auto type = cast(memDesc.getType()); + SmallVector shape = llvm::to_vector(type.getShape()); + Type elementType = type.getElementType(); + int64_t kIdx = opIdx == 0 ? 1 : 0; + shape[kIdx] = sliceWidth; + + auto ll = mlir::triton::gpu::toLinearLayout(shape, dotEncoding); + auto dotOperandEnc = ttg::LinearEncodingAttr::get(type.getContext(), ll); + auto tensorType = RankedTensorType::get(shape, elementType, dotOperandEnc); + + return genLocalSliceHelper(builder, v, 0, numSlices, sliceWidth, tensorType, + needCopySliced); +} + +LogicalResult Pingponger::genLocalSliceHelper( + OpBuilder &builder, Value v, unsigned opIdx, unsigned numSlices, + int64_t sliceWidth, RankedTensorType tensorType, bool needCopySliced) { + auto localLoad = v.getDefiningOp(); + if (!localLoad) + return failure(); + + auto waitToken = localLoad.getToken(); + auto memDesc = localLoad.getSrc(); + auto type = cast(memDesc.getType()); + SmallVector shape = llvm::to_vector(type.getShape()); + Type elementType = type.getElementType(); + int64_t kIdx = opIdx == 0 ? 
1 : 0; + shape[kIdx] = sliceWidth; + + auto resEncoding = localLoad.getResult().getType().getEncoding(); + auto dotOperandResEncoding = + dyn_cast(resEncoding); + const bool refineOrigSubview = dotOperandResEncoding != nullptr; + + auto arg = mlir::dyn_cast(memDesc); + if (!arg) { + LDBG("failed to cast input to `ttg.LocalLoadOp` to `BlockArgument`"); + return failure(); + } + + auto forOp = localLoad->getParentOfType(); + auto argIdx = arg.getArgNumber(); + auto yieldOperand = forOp.getTiedLoopYieldedValue(arg); + auto yieldOp = cast(yieldOperand->getOwner()); + auto origMemDesc = + cast(yieldOperand->get().getDefiningOp()); + auto subviewDescType = ttg::MemDescType::get( shape, elementType, type.getEncoding(), type.getMemorySpace(), type.getMutableMemory(), type.getAllocShape()); + + SmallVector slices; + SmallVector subviews; + MLIRContext *ctx = localLoad->getContext(); + auto intType = mlir::IntegerType::get(ctx, 32); for (int i = 0; i < numSlices; i++) { + auto ip = builder.saveInsertionPoint(); + builder.setInsertionPoint(&forOp.front()); + SmallVector offsetsVal; SmallVector offsets = {0, 0}; - offsets[kIdx] = i; + offsets[opIdx == 0 ? 1 : 0] = i; for (int64_t off : offsets) { - offsetsVal.push_back(constOffsets[off]); + offsetsVal.push_back(builder.create( + v.getLoc(), off * sliceWidth, 32)); + } + + builder.setInsertionPointAfter(origMemDesc); + auto sliceIdAttr = mlir::IntegerAttr::get(intType, i); + if (needCopySliced && refineOrigSubview) { + Value newOrigSmem = builder.create( + origMemDesc.getLoc(), subviewDescType, origMemDesc, offsetsVal); + + // set attributes - i.e., which dot-operand, which slice + newOrigSmem.getDefiningOp()->setAttr(Pingponger::sliceAttrName, + sliceIdAttr); + newOrigSmem.getDefiningOp()->setAttr( + triton::amdgpu::OpIdxAttr::getMnemonic(), + triton::amdgpu::OpIdxAttr::get(ctx, + dotOperandResEncoding.getOpIdx())); } + builder.restoreInsertionPoint(ip); + Value newSmem = builder.create( v.getLoc(), subviewDescType, memDesc, offsetsVal); Value prefetchSlice = builder.create( - v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc), - newSmem); + v.getLoc(), tensorType, newSmem, waitToken); + + prefetchSlice.getDefiningOp()->setAttr(Pingponger::sliceAttrName, + sliceIdAttr); subviews.push_back(newSmem.getDefiningOp()); slices.push_back(prefetchSlice.getDefiningOp()); } + subViewOps.push_back(subviews); loadSliceOps.push_back(slices); return success(); @@ -461,11 +817,14 @@ LogicalResult Pingponger::sliceDot(OpBuilder &builder, Location loc, if (shapeB[0] % numSlices != 0) return failure(); genOffsetConstants(loc, builder, numSlices, sliceWidth); - builder.setInsertionPointAfter(gLoadOps[0]); + builder.setInsertionPointAfter(useAsyncCopy ? 
asyncCopyOps[0] : gLoadOps[0]); auto dotEncoding = op.getType().getEncoding(); - if (genLocalSlice(builder, op.getA(), dotEncoding, 0, numSlices, sliceWidth) + const bool needCopySliced = false; + if (genLocalSlice(builder, op.getA(), dotEncoding, 0, numSlices, sliceWidth, + needCopySliced) .failed() || - genLocalSlice(builder, op.getB(), dotEncoding, 1, numSlices, sliceWidth) + genLocalSlice(builder, op.getB(), dotEncoding, 1, numSlices, sliceWidth, + needCopySliced) .failed()) return failure(); @@ -488,72 +847,568 @@ LogicalResult Pingponger::sliceDot(OpBuilder &builder, Location loc, return success(); } -// Transform a loop into four Dot - Memory (ping - pong) clusters -// This transform is useful when the original dot tile is too large that there's -// not enough registers to hold data for a Dot cluster. This path slices the dot -// into four pieces and pair with four clusters of reordered memory operations. -// There are multiple guards at the boundary of each cluster. -// (1) sched.barrier : with mask0 to prevent compiler backed from reordering -// instructions across the boundary -// (2) gpu.barrier : ensures asymmetric synchronization at each point -// (3) setprio (1->0) : in order to avoid incoming warp overtaking resource -// while the other warp is actively using it. -// -// Here's overview of the instruction clusters -// mem0: global load A, local load A(1/4), local load B(1/4) -// dot0: dot A(1/4) * B(1/4) -// mem1: global load B, local load A(2/4), local load B(2/4) -// dot1: dot A(2/4) * B(2/4) -// mem2: local load A(3/4, 4/4), local load B(3/4, 4/4) -// dot2: dot A(3/4) * B(3/4) -// mem3: local store A and B -// dot3: dot A(4/4) * B(4/4) +LogicalResult Pingponger::genAsyncCopySlices(OpBuilder &builder) { + if (asyncCopyOps.empty()) + return success(); -LogicalResult Pingponger::transformFourPPClusters(OpBuilder &builder, - Location loc) { - // First, slice local_loads and dot into 4 parts - if (sliceDot(builder, loc, dotOps[0], 4).failed()) - return failure(); - builder.setInsertionPointAfter(gLoadOps[1]); - // Reorder operations into four mem/dot clusters + auto &loopBody = forOp.getRegion().front(); + auto yieldOp = cast(loopBody.getTerminator()); - // mem0: global load A, local load A(1/4), local load B(1/4) - // set insertion point at the last global_load where all the addresses are - // ready to be used. 
- updateOpInsertion(gLoadOps[1]); - appendSlicedLoadAB(/*slice=*/0); - appendClusterBarrier(builder, loc); + for (auto asyncCopy : asyncCopyOps) { + MLIRContext *ctx = asyncCopy->getContext(); + auto srcPointers = asyncCopy.getSrc(); + auto subView = cast( + asyncCopy.getResult().getDefiningOp()); + auto subViewEncoding = subView.getType().getEncoding(); - // dot0 (1/4) - appendOpWithPrio(builder, dotSliceOps[0], loc); - appendClusterBarrier(builder, loc); + DenseMap subViews; + for (auto user : subView->getUsers()) { + if (auto subView = dyn_cast(user)) { + if (auto attr = subView->getAttrOfType( + Pingponger::sliceAttrName)) { + auto sliceIdx = attr.getValue().getSExtValue(); + subViews.insert({sliceIdx, subView}); - // mem1: global load B, local load A(2/4), local load B(2/4) - appendOp(gLoadOps[1]); - appendSlicedLoadAB(/*slice=*/1); - appendClusterBarrier(builder, loc); + if (!newAsyncGroups.contains(sliceIdx)) { + newAsyncGroups.insert({sliceIdx, {}}); + } + } + } + } - // dot1 (2/4) - appendOpWithPrio(builder, dotSliceOps[1], loc); - appendClusterBarrier(builder, loc); + if (subViews.empty()) { + auto commit = + getSingleUserOf(asyncCopy.getToken()); + preservedAsyncCommits.insert(commit->getResult()); + continue; + } - // mem2: local load A(3/4, 4/4), local load B(3/4, 4/4) - appendSlicedLoadAB(/*slice=*/2); - appendSlicedLoadAB(/*slice=*/3); - appendClusterBarrier(builder, loc); + // infer the sliced shape + triton::gpu::MemDescSubviewOp subViewSlice = subViews[0]; + auto origShape = subView.getType().getShape(); + auto slicedShape = subViewSlice.getType().getShape(); + assert(origShape.size() == slicedShape.size()); + const auto numDims = origShape.size(); + int64_t slicedDim = -1; + for (size_t dim = 0; dim < numDims; ++dim) { + if (origShape[dim] != slicedShape[dim]) { + slicedDim = dim; + break; + } + } - // dot2 (3/4) - appendOpWithPrio(builder, dotSliceOps[2], loc); - appendClusterBarrier(builder, loc); + builder.setInsertionPointAfter(asyncCopy); - // mem3: local store A and B - // Matmul kernels may use the output of the dot product in another operation - // before the local store (e.g. persistent matmul epilogue). To accommodate - // such cases, we need to move the local store up in the loop. 
- moveOpAndPredecessorsUpSameBlock(lStoreOps[0]); - moveOpAndPredecessorsUpSameBlock(lStoreOps[1]); - appendClusterBarrier(builder, loc); + auto elementType = srcPointers.getType().getElementType(); + auto encoding = + cast(srcPointers.getType().getEncoding()); + auto warpsPerCTA = encoding.getWarpsPerCTA(); + auto sizePerThread = encoding.getSizePerThread(); + SmallVector threadPerWarp(warpsPerCTA.size(), 0); + for (size_t dim = 0; dim < numDims; ++dim) { + threadPerWarp[dim] = + slicedShape[dim] / (warpsPerCTA[dim] * sizePerThread[dim]); + } + assert(mlir::product(threadPerWarp) == 64); + + auto newEncoding = ttg::BlockedEncodingAttr::get( + ctx, sizePerThread, threadPerWarp, warpsPerCTA, encoding.getOrder(), + encoding.getCTALayout()); + + auto convertTensor = [&](mlir::TypedValue tensor) { + RankedTensorType newType = nullptr; + Value newTensor = nullptr; + RankedTensorType slicedTensorType = nullptr; + if (tensor) { + assert(encoding == tensor.getType().getEncoding()); + auto elemType = tensor.getType().getElementType(); + newType = RankedTensorType::get(origShape, elemType, newEncoding); + newTensor = + builder + .create(tensor.getLoc(), newType, tensor) + .getResult(); + slicedTensorType = + RankedTensorType::get(slicedShape, elemType, newEncoding); + } + + return std::make_tuple(newType, newTensor, slicedTensorType); + }; + + mlir::TypedValue origMask = nullptr; + mlir::TypedValue origOtherTensor = nullptr; + + if (auto value = asyncCopy.getMask()) { + origMask = dyn_cast(value); + } + if (auto value = asyncCopy.getOther()) { + origOtherTensor = cast(value); + } + + auto [newSrcType, newSrcPointers, slicedSrcType] = + convertTensor(srcPointers); + auto [newMaskType, newMask, slicedMaskType] = convertTensor(origMask); + auto [newOtherType, newOther, slicedOtherType] = + convertTensor(origOtherTensor); + + auto extract = [&builder](Type resType, Value src, + DenseI64ArrayAttr &offset) { + Value resValue = nullptr; + if (src) { + resValue = builder.create( + src.getLoc(), resType, src, offset); + } + return resValue; + }; + + auto origSubViewType = subView.getType(); + auto subViewDescType = ttg::MemDescType::get( + slicedShape, origSubViewType.getElementType(), + origSubViewType.getEncoding(), origSubViewType.getMemorySpace(), + origSubViewType.getMutableMemory(), + subView.getSrc().getType().getShape()); + Value subViewSelector = subView.getOffsets().front(); + + assert(slicedDim != -1); + SmallVector newCommits; + auto numReps = origShape[slicedDim] / slicedShape[slicedDim]; + for (size_t rep = 0; rep < numReps; ++rep) { + SmallVector offset(slicedShape.size(), 0); + offset[slicedDim] = slicedShape[slicedDim] * rep; + auto offsetAttr = DenseI64ArrayAttr::get(ctx, offset); + + auto extractedSrc = extract(slicedSrcType, newSrcPointers, offsetAttr); + auto extractedMask = extract(slicedMaskType, newMask, offsetAttr); + auto extractedOther = extract(slicedOtherType, newOther, offsetAttr); + + SmallVector newSubviewOffset = {subViewSelector}; + llvm::for_each(offset, [&](auto off) { + newSubviewOffset.push_back( + builder.create(subView.getLoc(), off, 32)); + }); + + auto newSlicedSubView = builder.create( + subView.getLoc(), subViewDescType, subView.getSrc(), + newSubviewOffset); + + auto newAsyncCopy = builder.create( + asyncCopy->getLoc(), extractedSrc, + Value{newSlicedSubView.getResult()}, extractedMask, extractedOther, + asyncCopy.getCache(), asyncCopy.getEvict(), + asyncCopy.getIsVolatile()); + + auto newCommit = builder.create( + asyncCopy->getLoc(), newAsyncCopy.getToken()); + 
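For orientation, the shape bookkeeping this slicing loop relies on can be reproduced with plain integers. The sketch below uses assumed values (a 128x64 copy split along dimension 0 into 32x64 slices, 64-lane wavefronts, hypothetical warpsPerCTA/sizePerThread) to show how the repetition count, the per-repetition extract offsets, and the rescaled threads-per-warp are derived; it illustrates the arithmetic only, not the pass itself.

```cpp
// Illustrative arithmetic only (assumed shapes): how genAsyncCopySlices-style
// slicing derives the repetition count, per-slice offsets, and the blocked
// layout's threads-per-warp so one wavefront still covers a whole slice.
#include <cassert>
#include <cstdio>

int main() {
  const int origShape[2] = {128, 64};  // full async-copy tile
  const int slicedShape[2] = {32, 64}; // shape carried by each slice
  const int warpsPerCTA[2] = {4, 1};
  const int sizePerThread[2] = {1, 8};

  // The sliced dimension is the one whose extent changed.
  int slicedDim = (origShape[0] != slicedShape[0]) ? 0 : 1;
  int numReps = origShape[slicedDim] / slicedShape[slicedDim];

  // Recompute threads-per-warp for the sliced shape and check the warp size.
  int threadsPerWarp[2], product = 1;
  for (int d = 0; d < 2; ++d) {
    threadsPerWarp[d] = slicedShape[d] / (warpsPerCTA[d] * sizePerThread[d]);
    product *= threadsPerWarp[d];
  }
  assert(product == 64 && "one wavefront must cover a slice");

  for (int rep = 0; rep < numReps; ++rep) {
    int offset[2] = {0, 0};
    offset[slicedDim] = slicedShape[slicedDim] * rep;
    std::printf("rep %d: extract offset = (%d, %d)\n", rep, offset[0],
                offset[1]);
  }
  std::printf("threadsPerWarp = (%d, %d), numReps = %d\n", threadsPerWarp[0],
              threadsPerWarp[1], numReps);
  return 0;
}
```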
+ // propagate all attributes from `mem-view` to the commit token + newSlicedSubView->setAttrs(subViews[rep]->getAttrs()); + newAsyncCopy->setAttrs(subViews[rep]->getAttrs()); + newCommit->setAttrs(subViews[rep]->getAttrs()); + + newAsyncGroups[rep].push_back(newCommit); + newCommits.push_back(newCommit); + + subViews[rep]->erase(); + } + + auto origCommitGroup = getSingleUserOf(asyncCopy); + auto maybeResultIdx = getIndex(yieldOp, origCommitGroup->getResult()); + assert(maybeResultIdx.has_value()); + auto origYieldOperand = yieldOp->getOperand(maybeResultIdx.value()); + asyncTokenReassociation.insert({origYieldOperand, newCommits}); + } + + return success(); +} + +LogicalResult Pingponger::updateForOpSignature(OpBuilder &builder) { + // Note: call this method at the very end when reference to the + // original ops are not needed anymore + + if (asyncCopyOps.empty()) + return llvm::success(); + + Block &oldBlock = forOp.getRegion().front(); + auto origYieldOp = cast(oldBlock.getTerminator()); + auto orgiInitArgs = forOp.getInitArgs(); + + SmallVector newInputArgTokens; + for (auto &[origCommit, newCommits] : asyncTokenReassociation) { + auto maybeIdx = getIndex(origYieldOp, origCommit); + assert(maybeIdx.has_value()); + auto initCommitArgValue = orgiInitArgs[maybeIdx.value()]; + auto initCommitOp = + cast(initCommitArgValue.getDefiningOp()); + builder.setInsertionPointAfter(initCommitOp); + for (size_t i = 0; i < newCommits.size(); ++i) { + auto newInputArgToken = builder.create( + initCommitOp->getLoc(), initCommitOp.getAsyncToken().getType(), + initCommitOp.getInputTokens()); + newInputArgTokens.push_back(newInputArgToken); + } + } + + builder.setInsertionPointAfter(forOp); + DenseSet preservedArgsIndices; + DenseSet removedArgsIndices; + auto origYeildValues = forOp.getYieldedValues(); + for (auto [idx, value] : llvm::enumerate(origYeildValues)) { + bool copyable = dyn_cast(value.getType()) == nullptr; + copyable |= preservedAsyncCommits.contains(value); + if (copyable) { + preservedArgsIndices.insert(idx); + } else { + removedArgsIndices.insert(idx); + } + } + + DenseMap argIndicesMap; + SmallVector newInitArgs; + for (auto [idx, value] : llvm::enumerate(orgiInitArgs)) { + if (preservedArgsIndices.contains(idx)) { + argIndicesMap.insert({idx, newInitArgs.size()}); + newInitArgs.push_back(value); + } + } + + for (auto newInputToken : newInputArgTokens) { + newInitArgs.push_back(newInputToken); + } + + // Create a new ForOp + scf::ForOp newForOp = builder.create( + forOp->getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), newInitArgs); + + // Map original block arguments to new ones + Block &newBlock = newForOp.getRegion().front(); + + IRMapping mapping; + auto oldIterArgs = forOp.getRegionIterArgs(); + auto newIterArgs = newForOp.getRegionIterArgs(); + mapping.map(oldBlock.getArgument(0), newBlock.getArgument(0)); // loop index + for (auto [origIdx, newIdx] : argIndicesMap) { + mapping.map(oldIterArgs[origIdx], newIterArgs[newIdx]); + } + + // Clone the body of the loop + builder.setInsertionPointToStart(&newBlock); + for (auto &op : oldBlock.without_terminator()) { + builder.clone(op, mapping); + } + + // Clone the yield terminator + builder.setInsertionPointToEnd(&newBlock); + SmallVector newYieldResults; + for (auto [idx, value] : llvm::enumerate(forOp.getYieldedValues())) { + if (preservedArgsIndices.contains(idx)) { + newYieldResults.push_back(mapping.lookup(value)); + } + } + + for (auto &[origCommit, newCommits] : asyncTokenReassociation) { + for (auto 
commit : newCommits) + newYieldResults.push_back(mapping.lookup(commit)); + } + + builder.create(origYieldOp.getLoc(), newYieldResults); + + auto newForOpResults = newForOp.getResults(); + DenseSet adjustedUsers; + for (auto [idx, oldResult] : llvm::enumerate(forOp->getResults())) { + if (preservedArgsIndices.contains(idx)) { + auto newArgIdx = argIndicesMap[idx]; + oldResult.replaceAllUsesWith(newForOpResults[newArgIdx]); + } else { + for (auto user : oldResult.getUsers()) { + adjustedUsers.insert(user); + } + } + } + + // Adjust async-wait outside the newForOp + assert(adjustedUsers.size() == 1); + auto asyncWaitEpilogue = dyn_cast(*adjustedUsers.begin()); + assert(asyncWaitEpilogue != nullptr); + + builder.setInsertionPointAfter(asyncWaitEpilogue); + SmallVector newOperands; + for (auto newResult : newForOp->getResults()) { + if (dyn_cast(newResult.getType())) { + newOperands.push_back(newResult); + } + } + auto newAsyncWaitEpilogue = builder.create( + asyncWaitEpilogue->getLoc(), newOperands, 0); + asyncWaitEpilogue->replaceAllUsesWith(newAsyncWaitEpilogue); + asyncWaitEpilogue->erase(); + + SmallVector newAsyncTokens; + for (auto &arg : newForOp.getRegionIterArgs()) { + if (dyn_cast(arg.getType())) + newAsyncTokens.push_back(arg); + } + + // adjust async-wait inside the newForOp block + ttg::AsyncWaitOp asyncWait = nullptr; + newForOp.walk([&asyncWait](ttg::AsyncWaitOp op) { + asyncWait = op; + return WalkResult::interrupt(); + }); + assert(asyncWait != nullptr); + builder.setInsertionPointAfter(asyncWait); + auto newAsyncToken = + builder.create(asyncWait->getLoc(), newAsyncTokens, 0); + asyncWait->replaceAllUsesWith(newAsyncToken); + asyncWait.erase(); + + this->forOp->erase(); + this->forOp = newForOp; + return success(); +} + +LogicalResult Pingponger::adjustRefinedAsyncTokens(OpBuilder &builder) { + auto yeildOp = cast(forOp.getBody()->getTerminator()); + auto forOpArgs = forOp.getRegionIterArgs(); + + DenseMap, Value> refinedTokens; + SmallVector nonRefinedTokens; + forOp->walk([&](ttg::AsyncCommitGroupOp commit) { + auto tokenIdx = getIndex(yeildOp, commit.getResult()); + if (!tokenIdx.has_value()) + return WalkResult::advance(); + + int32_t opIdx = -1; + if (auto attr = commit->getAttrOfType( + triton::amdgpu::OpIdxAttr::getMnemonic())) { + opIdx = attr.getValue(); + } + + int32_t sliceId = -1; + if (auto attr = commit->getAttrOfType( + Pingponger::sliceAttrName)) { + sliceId = attr.getValue().getSExtValue(); + } + + assert(tokenIdx.has_value()); + auto asyncToken = forOpArgs[tokenIdx.value()]; + if ((opIdx > -1) && (sliceId > -1)) { + refinedTokens.insert({{opIdx, sliceId}, asyncToken}); + } else { + nonRefinedTokens.push_back(asyncToken); + } + return WalkResult::advance(); + }); + + // leave only `scaleA` and `scaleB` wait-tokens + ttg::AsyncWaitOp origAsyncWait; + forOp->walk([&origAsyncWait](ttg::AsyncWaitOp wait) { + origAsyncWait = wait; + return WalkResult::interrupt(); + }); + builder.setInsertionPointAfter(origAsyncWait); + auto newAsyncWait = builder.create(origAsyncWait->getLoc(), + nonRefinedTokens, 0); + origAsyncWait->replaceAllUsesWith(newAsyncWait); + origAsyncWait->erase(); + + // collect all refined localLoads + DenseMap, ttg::LocalLoadOp> refinedLocalLoads; + forOp->walk([&](ttg::LocalLoadOp localLoad) { + int32_t opIdx = -1; + auto resultType = cast(localLoad.getResult().getType()); + if (auto encding = + dyn_cast(resultType.getEncoding())) { + opIdx = encding.getOpIdx(); + } + + int32_t sliceId = -1; + if (auto attr = localLoad->getAttrOfType( + 
Pingponger::sliceAttrName)) { + sliceId = attr.getValue().getSExtValue(); + } + + if ((opIdx > -1) && (sliceId > -1)) { + refinedLocalLoads.insert({{opIdx, sliceId}, localLoad}); + } + }); + + // create new local load preceeded by new wait-tokens + for (auto &item : refinedTokens) { + auto [opIdx, sliceIdx] = item.first; + auto commit = item.second; + if (!refinedLocalLoads.contains({opIdx, sliceIdx})) + continue; + auto localLoad = refinedLocalLoads[{opIdx, sliceIdx}]; + builder.setInsertionPointAfter(localLoad); + auto token = builder.create(localLoad->getLoc(), + ValueRange{commit}, 0); + auto newLocalLoad = builder.create( + localLoad->getLoc(), localLoad.getResult().getType(), + localLoad.getSrc(), token); + localLoad->replaceAllUsesWith(newLocalLoad); + localLoad->erase(); + } + + return success(); +} + +LogicalResult Pingponger::sliceDotScaled(OpBuilder &builder, Location loc, + tt::DotScaledOp op, + unsigned numSlices) { + builder.setInsertionPointToStart(forOp.getBody()); + auto typeB = op.getB().getType(); + auto typeScaleB = op.getBScale().getType(); + auto shapeB = typeB.getShape(); + auto shapeScaleB = typeScaleB.getShape(); + + int64_t sliceWidth = shapeB[0] / numSlices; + int64_t sliceScaleWidth = shapeScaleB[1] / numSlices; + if (shapeB[1] % numSlices != 0) + return failure(); + + if (!gLoadOps.empty()) + builder.setInsertionPointAfter(gLoadOps[0]); + else if (!asyncCopyOps.empty()) { + builder.setInsertionPointAfter(asyncCopyOps[0]); + } else { + return failure(); + } + auto dotEncoding = op.getType().getEncoding(); + + bool needCopySliced = true; + // Generate slices for operands A and B + if (genLocalSlice(builder, op.getA(), dotEncoding, 0, numSlices, sliceWidth, + needCopySliced) + .failed() || + genLocalSlice(builder, op.getB(), dotEncoding, 1, numSlices, sliceWidth, + needCopySliced) + .failed()) + return failure(); + + // Generate slices for scale tensors if they exist + Value aScale = op.getAScale(); + Value bScale = op.getBScale(); + + needCopySliced = false; + if (aScale) { + if (genLocalSliceScales(builder, aScale, + op.getAScale().getType().getEncoding(), 0, + numSlices, sliceScaleWidth, needCopySliced) + .failed()) + return failure(); + } + + if (bScale) { + if (genLocalSliceScales(builder, bScale, + op.getBScale().getType().getEncoding(), 0, + numSlices, sliceScaleWidth, needCopySliced) + .failed()) + return failure(); + } + + Operation *prevDot = op; + for (int i = 0; i < numSlices; i++) { + IRMapping mapping; + mapping.map(op.getA(), loadSliceOps[0][i]->getResult(0)); + mapping.map(op.getB(), loadSliceOps[1][i]->getResult(0)); + + // Map scale tensors if they exist + if (aScale) + mapping.map(op.getAScale(), loadSliceOps[2][i]->getResult(0)); + if (bScale) + mapping.map(op.getBScale(), loadSliceOps[3][i]->getResult(0)); + + if (i > 0) + mapping.map(op.getC(), prevDot->getResult(0)); + + auto newOp = builder.clone(*op, mapping); + prevDot = newOp; + dotSliceOps.push_back(newOp); + } + + // Replace original op with the last slice and cleanup + op->replaceAllUsesWith(prevDot); + op->erase(); + for (auto loads : lLoadOps) + loads->erase(); + return success(); +} + +// Transform a loop into four Dot - Memory (ping - pong) clusters +// This transform is useful when the original dot tile is too large that there's +// not enough registers to hold data for a Dot cluster. This path slices the dot +// into four pieces and pair with four clusters of reordered memory operations. +// There are multiple guards at the boundary of each cluster. 
+// (1) sched.barrier : with mask0 to prevent compiler backed from reordering +// instructions across the boundary +// (2) gpu.barrier : ensures asymmetric synchronization at each point +// (3) setprio (1->0) : in order to avoid incoming warp overtaking resource +// while the other warp is actively using it. +// +// Here's overview of the instruction clusters +// mem0: global load A, local load A(1/4), local load B(1/4) +// dot0: dot A(1/4) * B(1/4) +// mem1: global load B, local load A(2/4), local load B(2/4) +// dot1: dot A(2/4) * B(2/4) +// mem2: local load A(3/4, 4/4), local load B(3/4, 4/4) +// dot2: dot A(3/4) * B(3/4) +// mem3: local store A and B +// dot3: dot A(4/4) * B(4/4) + +LogicalResult Pingponger::transformFourPPClusters(OpBuilder &builder, + Location loc) { + // First, slice local_loads and dot into 4 parts + if (sliceDot(builder, loc, dotOps[0], 4).failed()) + return failure(); + Operation *gLoadRhs = useAsyncCopy ? asyncCopyOps[1] : gLoadOps[1]; + builder.setInsertionPointAfter(gLoadRhs); + // Reorder operations into four mem/dot clusters + + // mem0: global load A, local load A(1/4), local load B(1/4) + // set insertion point at the last global_load where all the addresses are + // ready to be used. + updateOpInsertion(gLoadRhs); + appendSlicedLoadAB(/*slice=*/0); + appendClusterBarrier(builder, loc); + + // dot0 (1/4) + appendOpWithPrio(builder, dotSliceOps[0], loc); + appendClusterBarrier(builder, loc); + + // mem1: global load B, local load A(2/4), local load B(2/4) + appendOp(gLoadRhs); + if (useAsyncCopy) { + appendOp(asyncCommitOps[1]); + } + appendSlicedLoadAB(/*slice=*/1); + appendClusterBarrier(builder, loc); + + // dot1 (2/4) + appendOpWithPrio(builder, dotSliceOps[1], loc); + appendClusterBarrier(builder, loc); + + // mem2: local load A(3/4, 4/4), local load B(3/4, 4/4) + appendSlicedLoadAB(/*slice=*/2); + appendSlicedLoadAB(/*slice=*/3); + appendClusterBarrier(builder, loc); + + // dot2 (3/4) + appendOpWithPrio(builder, dotSliceOps[2], loc); + appendClusterBarrier(builder, loc); + + // mem3: local store A and B + // Matmul kernels may use the output of the dot product in another operation + // before the local store (e.g. persistent matmul epilogue). To accommodate + // such cases, we need to move the local store up in the loop. + if (!useAsyncCopy) { + moveOpAndPredecessorsUpSameBlock(lStoreOps[0]); + moveOpAndPredecessorsUpSameBlock(lStoreOps[1]); + appendClusterBarrier(builder, loc); + } else { + appendOp(asyncWaitOps[0]); + appendOp(asyncWaitOps[1]); + } // dot3 (4/4) appendOpWithPrio(builder, dotSliceOps[3], loc); @@ -625,6 +1480,141 @@ LogicalResult Pingponger::transformTwoPPClusters(OpBuilder &builder, return success(); } +// Fixme : document the scheduling. +// Assuming pipeliner already ordered the ops. +LogicalResult Pingponger::transformFAv3(OpBuilder &builder, Location loc) { + if (asyncWaitOps.size() != 2) { + return llvm::failure(); + } + + builder.setInsertionPointToStart(forOp.getBody()); + updateOpInsertion(dotOps[0]); + prependOp(builder.create(loc, lowPriority), false); + + // dot cluster 0 operations here. + + updateOpInsertion(asyncWaitOps[0]); + prependOp(builder.create(loc, highPriority), false); + appendOp(builder.create(loc, 0)); + + // mem cluster 0 operations here. + + updateOpInsertion(dotOps[1]); + // below ops are inserted backward + prependOp(builder.create(loc, lowPriority), true); + prependOp(builder.create(loc), true); + prependOp(builder.create(loc, 0), true); + + // dot cluster 1 operations here. 
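  // --- Illustrative sketch (not part of the pass) ---
  // A minimal, standalone model of the cluster layout transformFAv3 builds:
  // two dot clusters run at low priority and alternate with two memory
  // clusters (async wait + local reads) raised to high priority, with
  // sched.barrier(0)/gpu.barrier/setprio guards assumed at each boundary,
  // mirroring the guard pattern documented for the four-cluster transform.
  // The exact op types are elided in this hunk, so priorities, names, and the
  // guard text below are illustrative only.
  //
  // #include <iostream>
  // #include <string>
  // #include <vector>
  //
  // int main() {
  //   struct Cluster { std::string name; int priority; };
  //   const std::vector<Cluster> layout = {
  //       {"dot cluster 0", /*priority=*/0},
  //       {"mem cluster 0 (async_wait + local reads)", /*priority=*/1},
  //       {"dot cluster 1", /*priority=*/0},
  //       {"mem cluster 1 (async_wait + local reads)", /*priority=*/1},
  //   };
  //   for (const Cluster &c : layout)
  //     std::cout << "setprio(" << c.priority << ")  " << c.name
  //               << "  -- guarded by sched.barrier(0) at the boundary\n";
  // }
  // --- end sketch ---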
+ + updateOpInsertion(asyncWaitOps[1]); + prependOp(builder.create(loc, highPriority), false); + appendOp(builder.create(loc, 0)); + + // mem cluster 1 operations here. + + updateOpInsertion(lastInsertedOp->getBlock()->getTerminator()); + prependOp(builder.create(loc), true); + prependOp(builder.create(loc, 0), true); + + // Fixme: validate the case here? + return success(); +} + +LogicalResult Pingponger::transformFP4s(OpBuilder &builder, Location loc) { + // FIXME: support nonscale. + if (lLoadOps.size() != 4) + return failure(); + + auto tokens = asyncWaitOps[0].getAsyncToken(); + Operation *aWait = asyncWaitOps[0]; + builder.setInsertionPointToStart(forOp.getBody()); + asyncWaitOps.clear(); + for (int i = 0; i < 2; i++) { + auto newOp = builder.clone(*aWait); + newOp->eraseOperand(3 - i); + newOp->eraseOperand(1 - i); + asyncWaitOps.push_back(cast(newOp)); + } + lLoadOps[0]->replaceUsesOfWith(aWait->getResult(0), asyncWaitOps[0]); + lLoadOps[2]->replaceUsesOfWith(aWait->getResult(0), asyncWaitOps[0]); + lLoadOps[1]->replaceUsesOfWith(aWait->getResult(0), asyncWaitOps[1]); + lLoadOps[3]->replaceUsesOfWith(aWait->getResult(0), asyncWaitOps[1]); + aWait->erase(); + + builder.setInsertionPointAfter(dotSOps[0]); + updateOpInsertion(dotSOps[0]); + + appendOp(builder.create(loc, 0)); + appendOp(builder.create(loc)); + appendOp(builder.create(loc, 0)); + appendOp(lLoadOps[0]); + appendOp(lLoadOps[2]); + + appendOp(asyncWaitOps[1]); + + appendOp(asyncCopyOps[1]); + appendOp(asyncCopyOps[3]); + appendOp(asyncCommitOps[1]); + appendOp(asyncCommitOps[3]); + + appendOp(builder.create(loc, 0)); + appendOp(builder.create(loc)); + appendOp(builder.create(loc, 0)); + + appendOp(lLoadOps[1]); + appendOp(lLoadOps[3]); + appendOp(dotSOps[0]); + + return success(); +} + +LogicalResult Pingponger::transformFP4(OpBuilder &builder, Location loc) { + + builder.setInsertionPointAfter(forOp); + + // FIXME: This is duplicated code, need to refactorize. + auto i32ty = builder.getIntegerType(32); + auto workIDX = builder.create(loc, i32ty); + workIDX->moveBefore(forOp); + builder.setInsertionPointAfter(workIDX); + auto constZero = builder.create(loc, 0, 32); + auto constWarpSize = builder.create(loc, 256, 32); + auto warpIDX = builder.create(loc, workIDX, constWarpSize); + auto warpLow = builder.create(loc, arith::CmpIPredicate::eq, + warpIDX, constZero); + auto warpHigh = builder.create(loc, arith::CmpIPredicate::ne, + warpIDX, constZero); + + builder.setInsertionPointAfter(dotSOps[0]); + if (sliceDotScaled(builder, loc, dotSOps[0], 4).failed()) + return failure(); + + if (genAsyncCopySlices(builder).failed()) { + LDBG("failed to slice global-to-local async copies"); + return failure(); + } + + updateOpInsertion(dotSliceOps[0]); + + appendOp(builder.create(loc, 0)); + appendOp(builder.create(loc, warpLow)); + appendOp(builder.create(loc, 0)); + for (int j = 0; j < 4; j++) { + for (int i = 0; i < 4; i++) + appendOp(subViewOps[i][j]); + for (int i = 0; i < 4; i++) + appendOp(loadSliceOps[i][j]); + appendOp(builder.create(loc, 0)); + appendOp(dotSliceOps[j]); + } + + appendOp(builder.create(loc, 0)); + appendOp(builder.create(loc, warpHigh)); + + return success(); +} + // This function wraps forOp with cond_barrier. First, hold half of the warps // (warpHigh) in a block before the loop so the barriers in the loop synchronize // warps at the different point per the warp groups. 
After the loop, hold @@ -657,10 +1647,10 @@ void Pingponger::addAsymmetricSyncToLoop(OpBuilder &builder, Location loc) { } void Pingponger::getDotPingponged() { - if (numStages != 2) { + if (numStages != 2 && numStages != 4) { std::stringstream message; - message << "All ping pong scheduling requires 2 stages. Found " << numStages - << " stages"; + message << "All ping pong scheduling requires 2 or 4 stages. Found " + << numStages << " stages"; LDBG(message.str()); return; } @@ -669,9 +1659,12 @@ void Pingponger::getDotPingponged() { MLIRContext *ctx = forOp.getContext(); Location loc = forOp.getLoc(); + SmallVector asyncWaitsOps; forOp->walk([&](Operation *op) { if (auto gLoad = dyn_cast(op)) gLoadOps.push_back(gLoad); + if (auto asyncCopy = dyn_cast(op)) + asyncCopyOps.push_back(asyncCopy); else if (auto lLoad = dyn_cast(op)) { // This scheduling doesn't help hiding intra-warp latency. So, we only // collect local_load ops that are software pipelined, which means their @@ -683,29 +1676,191 @@ void Pingponger::getDotPingponged() { lLoadOps.push_back(lLoad); } else if (auto lStore = dyn_cast(op)) lStoreOps.push_back(lStore); - else if (auto pingpongDot = dyn_cast(op)) + else if (auto pingpongDot = dyn_cast(op)) { if (pingpongDot.getType().getRank() == 2) dotOps.push_back(pingpongDot); + } else if (auto pingpongDot = dyn_cast(op)) { + dotSOps.push_back(pingpongDot); + } else if (auto asyncCopy = dyn_cast(op)) { + asyncCopyOps.push_back(asyncCopy); + } else if (auto asyncCommitGroupOp = + dyn_cast(op)) { + asyncCommitOps.push_back(asyncCommitGroupOp); + } else if (auto wait = dyn_cast(op)) { + asyncWaitsOps.push_back(wait); + } }); + const bool isAsyncOpsInUse = !(asyncWaitsOps.empty()); + if (isAsyncOpsInUse && (asyncWaitsOps.size() != 1)) { + std::stringstream message; + message << "Unable to match ping pong scheduling pattern. " + << "Found " << asyncWaitsOps.size() + << " `AsyncWaitOp` in the scheduled region. Only one is allowed."; + LDBG(message.str()); + return; + } + + // Fixme : use proper condition to identify FAv3 + if (numStages == 4 && dotOps.size() == 2) { + if (transformFAv3(builder, loc).failed()) { + LDBG("Encountered failure when trying to execute the FAv3 ping pong " + "cluster transformation"); + return; + } + addAsymmetricSyncToLoop(builder, loc); + return; + } + // Currently, pingpong scheduling is known as helpful under limited condition. // Individual conditions are checked while collecting each operation such as // software pipelining and dot rank=2. Also only accept the for-loop with // supported combination of operations because this transformation is very // tightly scheduling the latencies. - if (gLoadOps.size() < 2 || lLoadOps.size() < 2 || dotOps.size() != 1) { + + // FIXME: get better condition to enable pingpong either for dot or for + // dot_scaled + int64_t numOfDotLikeOps = dotSOps.size() + dotOps.size(); + if (numOfDotLikeOps != 1) { + LDBG("Only handle a single of either dot or dot_scaled op"); + return; + } + int64_t gloadSize = useAsyncCopy ? asyncCopyOps.size() : gLoadOps.size(); + int64_t dotSize = dotSOps.size() > 0 ? dotSOps.size() : dotOps.size(); + if ((gloadSize < 2 || lLoadOps.size() < 2 || dotSize != 1)) { std::stringstream message; message << "Unable to match ping pong scheduling pattern. 
Details: " - << gLoadOps.size() << " global loads, " << lLoadOps.size() - << " local loads, " << dotOps.size() << " dot products"; + << gloadSize << " global loads, " << lLoadOps.size() + << " local loads, " << dotSize << " dot products"; LDBG(message.str()); return; } + // FIXME: place tile size restriction here and obtain kWidth + if (dotSOps.size() == 1 && numWarps == 8 && numStages == 2 && + asyncCopyOps.size() > 0) { + auto dotSType = dotSOps[0].getType(); + auto dotSShape = dotSType.getShape(); + auto aType = dotSOps[0].getA().getType(); + auto aShape = aType.getShape(); + auto elemWidth = aType.getElementTypeBitWidth(); + int64_t tileSize = dotSShape[0] * dotSShape[1] * aShape[1]; + + // 256x256x256 (128xi8) + if (tileSize == 8388608 && aShape[0] == 256 && aShape[1] == 128 && + elemWidth == 8) { + kWidth = 16; + if (transformFP4(builder, dotSOps[0]->getLoc()).failed()) { + LDBG("Encountered failure when trying to execute the two ping pong " + "cluster transformation"); + return; + } + + auto updateSignature = updateForOpSignature(builder); + if (llvm::failed(updateSignature)) { + LDBG("failed to update forOp signature"); + } + + // if (llvm::succeeded(updateSignature)) { + // if (llvm::failed(adjustRefinedAsyncTokens(builder))) { + // LDBG("failed to update forOp signature"); + // } + // } + + forOp->walk([](ttg::AsyncCommitGroupOp groupOp) { + auto users = groupOp.getResult().getUsers(); + if (users.empty()) { + SmallVector toDeleteVec; + for (auto token : groupOp.getInputTokens()) { + toDeleteVec.push_back(token.getDefiningOp()); + } + groupOp->erase(); + llvm::for_each(toDeleteVec, [](Operation *op) { op->erase(); }); + } + }); + } + // 128x128x512 (256xi8) + else if (tileSize == 4194304 && aShape[0] == 128 && aShape[1] == 256 && + elemWidth == 8) { + if (transformFP4s(builder, dotSOps[0]->getLoc()).failed()) { + LDBG("Encountered failure when trying to execute the two ping pong " + "cluster transformation"); + return; + } + } + + addAsymmetricSyncToLoop(builder, loc); + return; + } else if (dotSOps.size() == 1) + return; + // Determine if we have a persistent GEMM. This will decide how we interpret // any memory operations that we find in conditionals. auto assumeNotTaken = isPersistentGemm(dotOps.size()); + // Compute tile size, kWidth, and mfma type. + auto dotType = dotOps[0].getType(); + auto dotShape = dotType.getShape(); + auto aType = dotOps[0].getA().getType(); + auto aShape = aType.getShape(); + auto elemWidth = aType.getElementTypeBitWidth(); + int64_t tileSize = dotShape[0] * dotShape[1] * aShape[1] * elemWidth; + + const int64_t minTile = 262144; // e.g. 32x128x64x16bit + const int64_t smallTile = 16777216; // e.g. 128x128x64x16bit + const int64_t mediumTile = 33554432; // smallTile x 2 + const int64_t largeTile = 67108864; // e.g. 
256x256x64x16bit + + auto encoding = cast(aType).getEncoding(); + auto srcEncoding = cast(encoding); + kWidth = srcEncoding.getKWidth(); + auto mfmaEncoding = cast(srcEncoding.getParent()); + SmallVector intShape; + intShape.push_back(mfmaEncoding.getMDim()); + intShape.push_back(mfmaEncoding.getNDim()); + + if (dotOps.size() == 1 && useAsyncCopy) { + if (numWarps != 8) { + LDBG("Currently only support num_warp=8 for async PP"); + return; + } + if (tileSize != largeTile || aShape[1] != 64 || elemWidth != 16) { + LDBG("Only support tile size of 256x256x64 tile size for async PP"); + return; + } + + auto encoding = cast(aType).getEncoding(); + auto srcEncoding = cast(encoding); + kWidth = srcEncoding.getKWidth(); + auto mfmaEncoding = cast(srcEncoding.getParent()); + if (mfmaEncoding.getMDim() != 16 && mfmaEncoding.getNDim() != 16 && + kWidth != 8) { + LDBG("Only support 16x16 intrinsic and kWidth=8 for async PP"); + } + + DenseSet dotGlobalLoads; + DenseSet dotLocalLoads; + DenseSet dotAsyncCommitGroups; + DenseSet dotAsyncWaits; + determineDotAsyncMemoryOps(dotOps[0], dotGlobalLoads, dotLocalLoads, + dotAsyncCommitGroups, dotAsyncWaits); + if (failed(pruneDotAsyncMemoryOps(dotGlobalLoads, dotLocalLoads, + dotAsyncCommitGroups, dotAsyncWaits, + assumeNotTaken))) { + std::stringstream message; + message << "Failed to match ping pong scheduling pattern and prune async " + "memory ops."; + LDBG(message.str()); + return; + } + if (transformFourPPClusters(builder, dotOps[0]->getLoc()).failed()) { + LDBG("Encountered failure when trying to execute the four ping pong " + "cluster transformation"); + return; + } + addAsymmetricSyncToLoop(builder, loc); + return; + } // The existing code depends on the loads being targeted being safe to move, // which will not hold if we do not properly have a GEMM. As a result, we // filter the associated load operations to only those that are associated @@ -715,58 +1870,11 @@ void Pingponger::getDotPingponged() { DenseSet dotLocalStores; determineDotMemoryOps(dotOps[0], dotGlobalLoads, dotLocalLoads, dotLocalStores); - - // Prune Memory operations that may be moved to only those involved in dot - // computation. To understand the "cluster assumptions" we also estimate - // the impact of any additional loads/stores. - auto gLoadIt = std::stable_partition( - gLoadOps.begin(), gLoadOps.end(), - [&dotGlobalLoads](tt::LoadOp op) { return dotGlobalLoads.contains(op); }); - auto lLoadIt = std::stable_partition(lLoadOps.begin(), lLoadOps.end(), - [&dotLocalLoads](ttg::LocalLoadOp op) { - return dotLocalLoads.contains(op); - }); - auto lStoreIt = - std::stable_partition(lStoreOps.begin(), lStoreOps.end(), - [&dotLocalStores](ttg::LocalStoreOp op) { - return dotLocalStores.contains(op); - }); - if (estimateNonDotMemoryImpact(gLoadIt, gLoadOps.end(), - assumeNotTaken) != 0) { - std::stringstream message; - message << "Unable to match ping pong scheduling pattern. Details: " - << "Non-dot global loads found in non-persistent GEMM"; - LDBG(message.str()); - return; - } - if (estimateNonDotMemoryImpact(lLoadIt, lLoadOps.end(), - assumeNotTaken) != 0) { - std::stringstream message; - message << "Unable to match ping pong scheduling pattern. Details: " - << "Non-dot local loads found in non-persistent GEMM"; - LDBG(message.str()); - return; - } - if (estimateNonDotMemoryImpact(lStoreIt, lStoreOps.end(), - assumeNotTaken) != 0) { - std::stringstream message; - message << "Unable to match ping pong scheduling pattern. 
Details: " - << "Non-dot local stores found in non-persistent GEMM"; - LDBG(message.str()); - return; - } - - // Remove non-dot memory operations. - gLoadOps.erase(gLoadIt, gLoadOps.end()); - lLoadOps.erase(lLoadIt, lLoadOps.end()); - lStoreOps.erase(lStoreIt, lStoreOps.end()); - // All PingPong Scheduler assumes there are 2 movable global loads and 2 - // movable local loads. - if (gLoadOps.size() != 2 || lLoadOps.size() != 2) { + if (failed(pruneDotMemoryOps(dotGlobalLoads, dotLocalLoads, dotLocalStores, + assumeNotTaken))) { std::stringstream message; - message << "Unable to match ping pong slicing pattern. Details: " - << gLoadOps.size() << " global loads in dot computation, " - << lLoadOps.size() << " local loads in dot computation"; + message << "Failed to match ping pong scheduling pattern and prune " + "memory ops."; LDBG(message.str()); return; } @@ -799,26 +1907,6 @@ void Pingponger::getDotPingponged() { // N.B., Tile size smaller than 128x128x64_FP16 is likely not compute-bound // that pingpong scheduling doesn't help much. - auto dotType = dotOps[0].getType(); - auto dotShape = dotType.getShape(); - auto aType = dotOps[0].getA().getType(); - auto aShape = aType.getShape(); - auto elemWidth = aType.getElementTypeBitWidth(); - int64_t tileSize = dotShape[0] * dotShape[1] * aShape[1] * elemWidth; - - const int64_t minTile = 262144; // e.g. 32x128x64x16bit - const int64_t smallTile = 16777216; // e.g. 128x128x64x16bit - const int64_t mediumTile = 33554432; // smallTile x 2 - const int64_t largeTile = 67108864; // e.g. 256x256x64x16bit - - auto encoding = cast(aType).getEncoding(); - auto srcEncoding = cast(encoding); - kWidth = srcEncoding.getKWidth(); - auto mfmaEncoding = cast(srcEncoding.getParent()); - SmallVector intShape; - intShape.push_back(mfmaEncoding.getMDim()); - intShape.push_back(mfmaEncoding.getNDim()); - if (numWarps == 4) { // Pingpong between warps from different blocks // Transform a loop with small tile size. 
// We've observed that this small tile size spent almost equivalent cycle @@ -872,14 +1960,16 @@ class TritonAMDGPUBlockPingpongPass : public TritonAMDGPUBlockPingpongBase { public: TritonAMDGPUBlockPingpongPass() = default; - TritonAMDGPUBlockPingpongPass(int32_t numStages) { + TritonAMDGPUBlockPingpongPass(int32_t numStages, bool useAsyncCopy) { this->numStages = numStages; + this->useAsyncCopy = useAsyncCopy; } void runOnOperation() override { ModuleOp m = getOperation(); for (auto funcOp : m.getOps()) { funcOp.walk([&](scf::ForOp forOp) { - Pingponger pingponger(forOp, ttg::lookupNumWarps(forOp), numStages); + Pingponger pingponger(forOp, ttg::lookupNumWarps(forOp), numStages, + useAsyncCopy); pingponger.getDotPingponged(); }); } @@ -888,6 +1978,8 @@ class TritonAMDGPUBlockPingpongPass } // namespace std::unique_ptr -mlir::createTritonAMDGPUBlockPingpongPass(int32_t numStages) { - return std::make_unique(numStages); +mlir::createTritonAMDGPUBlockPingpongPass(int32_t numStages, + bool useAsyncCopy) { + return std::make_unique(numStages, + useAsyncCopy); } diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt index 836720b43901..fbb6a71df190 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt @@ -4,6 +4,7 @@ add_triton_library(TritonAMDGPUTransforms CanonicalizePointers.cpp CoalesceAsyncCopy.cpp ConvertToBufferOps.cpp + FourStagePipeliner.cpp OptimizeEpilogue.cpp HoistLayoutConversions.cpp ReorderInstructions.cpp diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index 51d77ca942e7..733944394538 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -13,7 +13,7 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/OneToNTypeConversion.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" @@ -1146,6 +1146,91 @@ class ConvertConvertLayoutOp } }; +/// slice integer offset, keep base +class ConvertExtractSliceOp + : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + + LogicalResult + matchAndRewrite_(tt::amdgpu::ExtractSliceOp extractSliceOp, + OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange remappedOperands = adaptor.getSource(); + if (remappedOperands.size() != 2) { + // some prior op materialized the fat ptr, e.g.: + // %3 = tt.bitcast %2 + // %4 = tt.splat %3 + return success(); + } + Value fatPtrBase = remappedOperands[0]; + Value fatPtrOffset = remappedOperands[1]; + if (!llvm::isa(fatPtrBase.getType())) { + return rewriter.notifyMatchFailure(extractSliceOp, + "non tt.ptr base unimplemented"); + } + auto offsetTensorTy = dyn_cast(fatPtrOffset.getType()); + if (!offsetTensorTy) { + return rewriter.notifyMatchFailure( + extractSliceOp, "non RankedTensorType offset unimplemented"); + } + + Location loc = extractSliceOp->getLoc(); + RankedTensorType resultType = extractSliceOp.getResult().getType(); + const FatPointers::FatPtrAttrs &fatPtrAttrs = + fatPtrs.at({fatPtrBase, fatPtrOffset}); + + Value newFatPtrOffset = nullptr; + auto 
origFatOffsetType = dyn_cast(fatPtrOffset.getType()); + auto slicedFatOffsetType = RankedTensorType::get( + resultType.getShape(), origFatOffsetType.getElementType(), + origFatOffsetType.getEncoding()); + + tt::amdgpu::ExtractSliceOp slicedFatPtrOffset = + rewriter.create( + loc, Type{slicedFatOffsetType}, Value{fatPtrOffset}, + extractSliceOp.getStaticOffsetsAttr()); + + auto newResultPtrType = + RankedTensorType::get(resultType.getShape(), fatPtrBase.getType(), + origFatOffsetType.getEncoding()); + + // Scalar case: we only need to `tt.addptr %basePtr, %offset` + if (!origFatOffsetType) { + auto addPtrOp = rewriter.create( + loc, newResultPtrType, fatPtrBase, slicedFatPtrOffset); + for (const auto &attribute : fatPtrAttrs.attributes) + addPtrOp->setAttr(attribute.getFirst(), attribute.getSecond()); + newFatPtrOffset = addPtrOp.getResult(); + } + + // Tensor case: splat the scalar pointer and add the (tensor) offset: + // ``` + // %tensorBasePtr = tt.splat %basePtr + // %tensorPtr = tt.addptr %tensorBasePtr, %offset + // ``` + if (fatPtrAttrs.canNarrow) + fatPtrOffset = createTruncIOffset(rewriter, loc, fatPtrOffset, + rewriter.getI32Type()); + + tt::SplatOp tensorPtr = + rewriter.create(loc, newResultPtrType, fatPtrBase); + tt::AddPtrOp addPtrOp = rewriter.create( + loc, newResultPtrType, tensorPtr, slicedFatPtrOffset); + + for (const auto &attribute : fatPtrAttrs.attributes) + addPtrOp->setAttr(attribute.getFirst(), attribute.getSecond()); + newFatPtrOffset = addPtrOp.getResult(); + + assert(newFatPtrOffset); + rewriter.replaceOp(extractSliceOp, newFatPtrOffset); + fatPtrs[{fatPtrBase, newFatPtrOffset}] = + fatPtrs.at({fatPtrBase, fatPtrOffset}); + + return success(); + } +}; + template class MaterializeFatPointer : public PointerCanonicalizationPattern { public: @@ -1508,6 +1593,8 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { target.addDynamicallyLegalDialect(isLegal); target.addDynamicallyLegalDialect(isLegal); target.addDynamicallyLegalDialect(isLegal); + target.addDynamicallyLegalDialect( + isLegal); // Rewrite the rest of the ops. 
// Note we *do not* declare unrealized_cast an illegal op here in order that @@ -1529,10 +1616,10 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { MaterializeFatPointerVariadic, MaterializeFatPointerVariadic, MaterializeFatPointerVariadic, ConvertSCFForOp, - ConvertExpandDims, ConvertSCFYieldOp, ConvertSCFIfOp, - ConvertSCFConditionOp, ConvertSCFWhileOp, ConvertCFCondBranch, - ConvertCFBranch, ConvertArithSelectOp, ConvertReturnOp>( - patterns.getContext(), opsToRewrite, fatPrs); + ConvertExpandDims, ConvertExtractSliceOp, ConvertSCFYieldOp, + ConvertSCFIfOp, ConvertSCFConditionOp, ConvertSCFWhileOp, + ConvertCFCondBranch, ConvertCFBranch, ConvertArithSelectOp, + ConvertReturnOp>(patterns.getContext(), opsToRewrite, fatPrs); if (failed(applyPartialConversion(func, target, std::move(patterns), config))) return signalPassFailure(); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp index 7aa0bf102ca8..3cfd5a5ccb2d 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp @@ -200,6 +200,10 @@ bool verifyNonNegativeExpr( return verifyNonSmallerByAssumption(op.getLhs(), assumptions, op.getRhs()); }) + .Case([&](auto op) { + return verifyNonNegativeExpr(op->getOperand(0), assumptions, + solver); + }) .Default([&](Operation *) { // Conservatively assume that the expression is negative LDBG(" Unhandled op, cannot assume non-negative"); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/FourStagePipeliner.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/FourStagePipeliner.cpp new file mode 100644 index 000000000000..4c515ba308a6 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUTransforms/FourStagePipeliner.cpp @@ -0,0 +1,946 @@ +#include "FourStagePipeliner.h" +#include "TritonAMDGPUTransforms/Passes.h" +#include "mlir/Support/LLVM.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/OpInterfaces.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" + +//===----------------------------------------------------------------------===// +// This file will create a schedule that will be handed over to the pipeline +// expander. +// Software pipeliners are usually separated into two pieces, one that create a +// modulo schedule and an expander that rewrites the loop and emits a prologue +// and epilogue. This pass first calls a helper that will pre-process the IR +// to create stream operations and create a modulo schedule. Then we call the +// expander to generate the prologue and new loop and epilogue. 
+//===----------------------------------------------------------------------===// + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h.inc" + +#define DEBUG_TYPE "tritonamdgpu-four-stage-pipeline" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +static Operation *streamPredication(RewriterBase &rewriter, Operation *op, + Value pred) { + // The epilogue peeling generates a select for the stage output. This causes + // too much register pressure with the loop result and the epilogue-dot in + // regs for the select. Conditionally executing the dot will allow the backend + // to optimize the select away as redundant. + if (auto dotOp = dyn_cast(op)) { + auto loc = dotOp->getLoc(); + auto ifOp = rewriter.create(loc, dotOp->getResult(0).getType(), + pred, /*withElseRegion=*/true); + auto thenB = ifOp.getThenBodyBuilder(); + auto yield = thenB.create(loc, dotOp->getResult(0)); + dotOp->moveBefore(yield); + ifOp.getElseBodyBuilder().create(loc, dotOp->getOperand(2)); + return ifOp; + } + return tt::predicateOp(rewriter, op, pred); +} + +FourStagePipeliner::FourStagePipeliner(scf::ForOp _forOp, int _numStages, + int _globalPrefetch, int _localPrefetch, + bool _useAsyncCopy) + : forOp(_forOp), numStages(_numStages), numBuffers(1), + useAsyncCopy(_useAsyncCopy), schedule(numStages), + axisInfoAnalysis(forOp->getParentOfType()) { + int lastStage = numStages - 1; + stages[SCHED_GLOBAL_LOAD] = 0; + stages[SCHED_LOCAL_STORE] = _globalPrefetch; + stages[SCHED_LOCAL_LOAD] = lastStage - _localPrefetch; + stages[SCHED_COMPUTE] = lastStage; + stages[SCHED_ASYNC_WAIT] = stages[SCHED_LOCAL_LOAD]; + + options.supportDynamicLoops = true; + options.peelEpilogue = true; + options.predicateFn = streamPredication; +} + +bool FourStagePipeliner::checkPrecondition(scf::ForOp forOp, int numStages) { + unsigned dotCount{}; + unsigned reduceCount{}; + + if (tt::getNumStagesOrDefault(forOp, numStages) != 4) + return false; + + if (!forOp.getBody()) + return false; + + for (auto &op : *forOp.getBody()) { + if (isa(op)) { + dotCount++; + } else if (isa(op)) { + reduceCount++; + } + } + return dotCount == 2 && reduceCount == 2; +} + +// Init Schedule Config based on settings and loop characteristics. +// Create clusters in order of ops in loop. This can interleave ops +// from different stages in the same cluster to achieve better backend +// scheduling. +// WARNING: Changing the order of schedule.clusters.newAtBack() calls +// can cause invalid schedules to be produced. +LogicalResult FourStagePipeliner::initSchedule(int maxIndirectionLevel) { + bool pairedGlobalLoadLocalStore = stages[SCHED_LOCAL_STORE] == 0; + stages[SCHED_LOCAL_STORE] += maxIndirectionLevel; + + LDBG( + "Stage schedule:" << " GLOBAL_LOAD stage = " << stages[SCHED_GLOBAL_LOAD] + << ", LOCAL_STORE stage = " << stages[SCHED_LOCAL_STORE] + << ", LOCAL_LOAD stage = " << stages[SCHED_LOCAL_LOAD] + << ", COMPUTE stage = " << stages[SCHED_COMPUTE] + << ", ASYNC_WAIT stage = " << stages[SCHED_ASYNC_WAIT] + << "; total = " << numStages); + + if (stages[SCHED_LOCAL_STORE] >= numStages || + stages[SCHED_LOCAL_STORE] > stages[SCHED_LOCAL_LOAD]) { + LDBG("Invalid stage schedule"); + return failure(); + } + + // Calculate the number of buffers needed for each load. + // TODO: Use the precise number of buffers needed by the particular load. 
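  // --- Illustrative sketch (not part of the pass) ---
  // Standalone model of the buffer-count rule used just below: the distance
  // between the LOCAL_STORE and LOCAL_LOAD stages determines how many shared
  // buffers are live at once, async copy needs one extra buffer because no
  // register staging buffer is used, and this four-stage schedule then pins
  // the result to 2. The prefetch/stage values are example inputs only.
  //
  // #include <algorithm>
  // #include <cassert>
  //
  // static int deduceNumBuffers(int localStoreStage, int localLoadStage,
  //                             bool useAsyncCopy) {
  //   int numBuffers = std::max(1, localLoadStage - localStoreStage);
  //   if (useAsyncCopy)
  //     numBuffers += 1; // no register buffer when copying straight to LDS
  //   return numBuffers;
  // }
  //
  // int main() {
  //   // e.g. global_prefetch = 0, local_prefetch = 0, numStages = 4:
  //   // LOCAL_STORE stage = 0, LOCAL_LOAD stage = 3.
  //   assert(deduceNumBuffers(0, 3, /*useAsyncCopy=*/false) == 3);
  //   assert(deduceNumBuffers(0, 3, /*useAsyncCopy=*/true) == 4);
  //   return 0;
  // }
  // --- end sketch ---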
+ numBuffers = + std::max(1, stages[SCHED_LOCAL_LOAD] - stages[SCHED_LOCAL_STORE]); + // If we use AsyncCopy we need one more buffer since we are not using a + // register buffer + if (useAsyncCopy) { + numBuffers += 1; + } + numBuffers = 2; + + LDBG("deduced max shared memory buffer number = " << numBuffers); + + // We place async wait as the first cluster because we want to have it being + // the first in the main loop after pipelining. + int asyncWaitCluster = 0; + + // If tt.load and ttg.local_store are in the same stage + // spread them apart to allow overlap with compute + // else + // Initiate ttg.local_store before tt.load + int globalLoadCluster = 1; + int localStoreCluster = 3; + if (!pairedGlobalLoadLocalStore) { + globalLoadCluster = 3; + localStoreCluster = 2; + } + + // If ttg.local_load and ttg.local_store are in the same stage + // spread them apart to allow overlap with compute + // else if they share the buffer + // ttg.local_load must come first + // else + // schedule ttg.local_load in the middle + int localLoadCluster = globalLoadCluster; + if (stages[SCHED_LOCAL_LOAD] == stages[SCHED_LOCAL_STORE]) { + localLoadCluster = std::max(3, localStoreCluster + 1); + } else if (numBuffers == 1 && localLoadCluster >= localStoreCluster) { + // For 1 buffer, ttg.local_load must occur before ttg.local_store + localLoadCluster = localStoreCluster - 1; + } + + // Schedule compute with ttg.local_load if paired + // otherwise, schedule in the middle + int computeCluster = 2; + if (stages[SCHED_LOCAL_LOAD] == stages[SCHED_COMPUTE]) { + computeCluster = localLoadCluster; + } + + // Create clusters in order of 4-stage pipeliner. You can swap lines below to + // change the schedule of the loop. Not all combination are valid, e.g. if a + // consumer and producer from the same stage are in the wrong cluster order + // the loop expander will silently fail + + // DOT1 + dotClusters[0] = schedule.clusters.newAtBack(); + // SM2, + softmaxClusters[0] = schedule.clusters.newAtBack(); + // Wait for V, LRV + localReadClusters[0] = schedule.clusters.newAtBack(); + // ACK + asyncCopyClusters[0] = schedule.clusters.newAtBack(); + // DOT2 + dotClusters[1] = schedule.clusters.newAtBack(); + // SM1 + softmaxClusters[1] = schedule.clusters.newAtBack(); + // Wait for K, LRK + localReadClusters[1] = schedule.clusters.newAtBack(); + // ACV + asyncCopyClusters[1] = schedule.clusters.newAtBack(); + + // ATTENTION 4-stage (not used) + clusters[SCHED_GLOBAL_LOAD] = softmaxClusters[1]; + clusters[SCHED_LOCAL_STORE] = asyncCopyClusters[0]; + clusters[SCHED_LOCAL_LOAD] = asyncCopyClusters[0]; + clusters[SCHED_ASYNC_WAIT] = asyncCopyClusters[0]; + clusters[SCHED_COMPUTE] = softmaxClusters[0]; + // Make assignments + // std::array clusterVec; + // std::generate(clusterVec.begin(), clusterVec.end(), + // [&]() { return schedule.clusters.newAtBack(); }); + + // clusters[SCHED_GLOBAL_LOAD] = clusterVec[globalLoadCluster]; + // clusters[SCHED_LOCAL_STORE] = clusterVec[localStoreCluster]; + // clusters[SCHED_LOCAL_LOAD] = clusterVec[localLoadCluster]; + // clusters[SCHED_COMPUTE] = clusterVec[computeCluster]; + // clusters[SCHED_ASYNC_WAIT] = clusterVec[asyncWaitCluster]; + + LDBG("Cluster schedule:" << " GLOBAL_LOAD cluster = " << globalLoadCluster + << ", LOCAL_STORE cluster = " << localStoreCluster + << ", LOCAL_LOAD cluster = " << localLoadCluster + << ", COMPUTE cluster = " << computeCluster + << ", ASYNC_WAIT cluster = " << asyncWaitCluster + << "; total = " << SCHED_SIZE); + + return success(); +} + +bool 
FourStagePipeliner::createAsyncCopy(tt::LoadOp loadOp, Value alloc, + Value extractIdx) { + assert(useAsyncCopy); + // If we have a single buffer we would require another barrier after the + // local_reads so instead we fall back to pipeline with registers + // Removing this check will create incorrect IR, see + // MembarUtility.h:membarFilter + if (numBuffers == 1) + return false; + + OpBuilder builder(loadOp); + Location loc = loadOp.getLoc(); + + Value src = loadOp.getPtr(); + auto srcTy = cast(src.getType()); + + ttg::MemDescType allocTy = cast(alloc.getType()); + auto sharedEncodingAttr = + cast(allocTy.getEncoding()); + + // Extract local subview from shared allocation + Value zero = builder.create(forOp.getLoc(), 0, 32); + SmallVector loadOffsets(allocTy.getRank(), zero); + loadOffsets[0] = extractIdx; + auto sharedMemorySpace = ttg::SharedMemorySpaceAttr::get(forOp.getContext()); + auto subviewTy = ttg::MemDescType::get( + allocTy.getShape().drop_front(), allocTy.getElementType(), + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); + auto viewLoad = + builder.create(loc, subviewTy, alloc, loadOffsets); + + // If the load is used by an existing local allocation we replace it with the + // new subview + SmallVector allocsToErase; + for (Operation *user : loadOp->getUsers()) { + if (auto alloc = dyn_cast(user)) { + tt::replaceUsesAndPropagateType(builder, alloc, viewLoad); + allocsToErase.push_back(alloc); + } + } + for (auto alloc : allocsToErase) + alloc.erase(); + + auto copyOp = builder.create( + loadOp.getLoc(), src, viewLoad, loadOp.getMask(), loadOp.getOther(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + + // Insert synchronization primitives to create barriers during lowering + auto commitOp = + builder.create(loc, copyOp->getResult(0)); + + ttg::AsyncWaitOp waitOp = + builder.create(loc, commitOp->getResult(0), 0); + + // Create local load which consumes the async token from the AsyncWait + auto sharedLoad = + builder.create(loc, loadOp.getType(), viewLoad, waitOp); + + auto [loadStage, loadCluster] = schedule[loadOp]; + // Schedule new ops + schedule.insert(copyOp, loadStage, loadCluster); + // Place ttg.async_commit_group op following AsyncCopyGlobalToLocal so the + // later UpdateAsyncWaitCount pass can deduce better waitcnts + schedule.insert(commitOp, loadStage, loadCluster); + // If the LocalLoads are scheduled to a later stage than AsyncCopy we need to + // place the AsyncCopy prefetches after the AsyncWaits which create a barrier + // to ensure all warps are finished reading the shared buffer we will write + // into. This is done by scheduling AsyncWait as the first cluster. + // If AsyncCopy and LocalLoads are in the same stage we do not assign a + // schdule so they are placed before the LocalLoads + // Disable for FA + // if (loadStage != stages[SCHED_LOCAL_LOAD]) + // scheduleOp(waitOp, SCHED_ASYNC_WAIT); + + // if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE]) + // scheduleOp(sharedLoad, SCHED_LOCAL_LOAD); + + loadOp->replaceAllUsesWith(ValueRange{sharedLoad}); + + // 4-stage pipeliner scheduleing + auto localLoadStage = loadStage == 0 ? 1 : 3; + auto localLoadCluster = loadStage == 0 ? 
1 : 0; + schedule.insert(sharedLoad, localLoadStage, + localReadClusters[localLoadCluster]); + schedule.insert(waitOp, localLoadStage, localReadClusters[localLoadCluster]); + + // Make sure that a possible cvt is in the same stage or otherwise it will not + // get folded + if (sharedLoad->hasOneUse()) { + if (auto cvt = + dyn_cast(*sharedLoad->getUsers().begin())) { + LDBG("Change cvt layout stage and cluster"); + schedule.insert(cvt, localLoadStage, localReadClusters[localLoadCluster]); + } + } + + if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE] && + sharedLoad->hasOneUse()) { + if (auto cvt = + dyn_cast(*sharedLoad->getUsers().begin())) + scheduleOp(cvt, SCHED_LOCAL_LOAD); + } + + // Delete old loadOp + schedule.erase(loadOp); + loadOp.erase(); + return true; +} + +void FourStagePipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, + Value extractIdx) { + OpBuilder builder(forOp); + Value zero = builder.create(forOp.getLoc(), 0, 32); + // Replace the load with insert/extract slice. + builder.setInsertionPoint(loadOp); + Location loc = loadOp.getLoc(); + Value src = loadOp.getPtr(); + Value mask = loadOp.getMask(); + Value other = loadOp.getOther(); + + ttg::MemDescType allocTy = cast(alloc.getType()); + SmallVector copyOffsets(allocTy.getRank(), zero); + Operation *copy = builder.clone(*loadOp); + + auto [stage, cluster] = schedule[loadOp]; + schedule.erase(loadOp); + schedule.insert(copy, stage, cluster); + + // Extract part. + SmallVector loadOffsets(allocTy.getRank(), zero); + loadOffsets[0] = extractIdx; + auto sharedMemorySpace = ttg::SharedMemorySpaceAttr::get(forOp.getContext()); + auto subviewTy = ttg::MemDescType::get( + allocTy.getShape().drop_front(), allocTy.getElementType(), + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); + auto viewLoad = + builder.create(loc, subviewTy, alloc, loadOffsets); + // Clean up old local caches. + SmallVector allocsToErase; + for (Operation *user : loadOp->getUsers()) { + if (auto alloc = dyn_cast(user)) { + tt::replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); + allocsToErase.push_back(alloc); + } + } + for (auto alloc : allocsToErase) + alloc.erase(); + + // Prefetch load ahead of the dot stage if is used by the dot. + auto storeOp = + builder.create(loc, copy->getResult(0), viewLoad); + scheduleOp(viewLoad, SCHED_LOCAL_STORE); + scheduleOp(storeOp, SCHED_LOCAL_STORE); + + // Create local load + auto sharedLoad = + builder.create(loc, loadOp.getType(), viewLoad); + Value result = sharedLoad.getResult(); + if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE]) + scheduleOp(sharedLoad, SCHED_LOCAL_LOAD); + + // If the currently processed `LoadOp` is labeled with an index regarding + // to which `DotOp` operand the corresponding data belongs to, then label the + // expanded `LocalStoreOp` with the same index. This is required for + // instruction scheduling hints to correctly count the emitted `ds_write` + // instructions for each GEMM tile. + if (auto attr = loadOp->getAttr(tt::amdgpu::OpIdxAttr::getMnemonic())) { + storeOp->setAttr(tt::amdgpu::OpIdxAttr::getMnemonic(), attr); + } + + loadOp->replaceAllUsesWith(ValueRange{result}); + + if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE] && result.hasOneUse()) { + if (auto cvt = dyn_cast(*result.getUsers().begin())) + scheduleOp(cvt, SCHED_LOCAL_LOAD); + } + + loadOp.erase(); +} + +// Returns the given |inputValue|'s dot user result encoding and updates |opIdx| +// with which dot operand |inputValue| is fed into if possible. 
+static ttg::AMDMfmaEncodingAttr getDotEncoding(Value inputValue, + unsigned *opIdx) { + if (!llvm::hasSingleElement(inputValue.getUses())) + return nullptr; + + Operation *user = *inputValue.getUsers().begin(); + if (user->getNumResults() != 1 || + user->getBlock() != inputValue.getParentBlock()) + return nullptr; + + if (auto dotOp = dyn_cast(user)) { + OpOperand &use = *inputValue.getUses().begin(); + *opIdx = use.getOperandNumber(); + auto dotType = cast(dotOp->getResult(0).getType()); + return dyn_cast(dotType.getEncoding()); + } + return getDotEncoding(user->getResult(0), opIdx); +} + +// If all the transitive uses of the given value have are used by a convert to +// the same dot operand encoding, return true and get the shared encoding that +// needs to be used to be compatible with users' layouts. +static std::optional +getSharedEncIfAllUsersAreDotEnc(Value loadedValue) { + ttg::SwizzledSharedEncodingAttr attr; + for (Operation *user : loadedValue.getUsers()) { + LDBG(" getSharedEncIfAllUsersAreDotEnc current user: " << *user); + if (user->getNumResults() != 1) + return std::nullopt; + + ttg::SwizzledSharedEncodingAttr tempAttr; + Value userResult = user->getResult(0); + Type userResType = userResult.getType(); + if (auto memDesc = dyn_cast(userResType)) { + // First time we find a shared encoding in the chain, save it and try to + // use it if it is compatible with the other users. + tempAttr = cast(memDesc.getEncoding()); + if (!getSharedEncIfAllUsersAreDotEnc(userResult).has_value()) + return std::nullopt; + } else { + if (!isa(user)) + return std::nullopt; + + auto srcTy = cast(loadedValue.getType()); + auto ctaLayout = ttg::getCTALayout(srcTy.getEncoding()); + auto order = getOrderForMemory(srcTy); + unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + SmallVector sharedOrder; + int rank = order.size(); + // TODO rework this when shared -> dotOperand conversions support + // arbitrary shared memory ordering + if (rank == 3) { + // Move the batch dimension (dim #0) to be the last so that it will be + // the slowest varying dimension. + for (unsigned i = 0; i < rank; ++i) + if (order[i] != 0) + sharedOrder.emplace_back(order[i]); + sharedOrder.emplace_back(0); + } else { + sharedOrder = order; + } + + auto userResEnc = cast(userResType).getEncoding(); + if (auto dotOpEnc = dyn_cast(userResEnc)) { + tempAttr = ttg::SwizzledSharedEncodingAttr::get( + loadedValue.getContext(), dotOpEnc, srcTy.getShape(), sharedOrder, + ctaLayout, bitWidth, /*needTrans=*/false); + } else if (auto llEnc = dyn_cast(userResEnc)) { + // We use linear layout directly for scaled dot fp8 operands. For such + // cases, we need to look further down the def-use chain to find the dot + // op for the mfma layout to deduce operand index and other information. + unsigned opIdx; + if (auto dotEnc = getDotEncoding(userResult, &opIdx)) { + unsigned vecSize = llEnc.getLinearLayout().getNumConsecutiveInOut(); + LDBG("deduced opIdx: " << opIdx << "; deduced vecSize: " << vecSize); + tempAttr = dotEnc.composeSharedLayoutForOperand( + ctaLayout, opIdx, srcTy.getShape(), order, vecSize, bitWidth, + /*needTrans=*/false); + } + } + } + // Check that the shared encodings needed by the users are compatible. + if (!tempAttr || (attr != nullptr && attr != tempAttr)) + return std::nullopt; + attr = tempAttr; + } + return attr; +} + +// Create a map from load ops to their indirection levels and the final uses +// of the load op (another load op, or a dot op). 
+// +// Indirection level is "0" for the load op directly used by the dot op, +// "1" for the load op used by the load op used by the dot op, and so on. +void FourStagePipeliner::computeLoadOpsToIndirectionLevelAndUse() { + DenseSet seen; + + // Recursively visit the given op and its operands to discover all load ops + // and collect their indirection levels and uses. + std::function dfs = + [&](Operation *op, int distance, Operation *use) { + // Skip previously visited load ops. + if (!seen.insert(op).second) + return; + + if (isa(op)) { + // TODO: What if there are multiple uses at different distances? + loadOpToIndLevelAndUse.emplace_back(op, distance, use); + use = op; + ++distance; + } + for (Value operand : op->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + dfs(defOp, distance, use); + } + } + }; + + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + continue; + seen.clear(); + dfs(&op, 0, &op); + } + + // If the loop has numStages attribute, also consider pipelining other loads + // that are not directly used by dot ops. + if (forOp->hasAttr(tt::kNumStagesAttrName)) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + dfs(&op, 0, &op); + } + } +} + +// Goes through all load ops to identify those that can be pipelined and assign +// layout to them. +void FourStagePipeliner::assignMemoryLayouts() { + for (auto &[op, dist, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(op)) + // TODO: We'd need to verify that the distance is the same. + continue; + + auto loadOp = cast(op); + assert(!isLoadFromTensorPtr(loadOp) && + "Block ptr should have been lowered before this pass."); + auto ptr = loadOp.getPtr(); + unsigned vec = axisInfoAnalysis.getContiguity(ptr); + if (auto mask = loadOp.getMask()) + vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) { + LDBG("Skip non-tensor load " << loadOp); + continue; + } + + auto pointeeTy = + cast(tensorTy.getElementType()).getPointeeType(); + unsigned width = vec * pointeeTy.getIntOrFloatBitWidth(); + + LDBG("assign memory layouts (width=" << width << ") for load " << loadOp); + LoadInfo loadInfo; + if (isa(use)) { + // Only use shared memory when feeding into a dot op. + loadInfo.usedByDot = true; + // If the max continugous bits we can read is < 32, buffer in registers. + if (width >= 32) { + loadInfo.sharedEncoding = + getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr); + } + } else if (auto useOp = dyn_cast(use)) { + // The use of this loadOp is another loadOp. If the use is not in the + // loadToInfo already, it means that the use is not valid for pipelining + // for some reason. We should skip this loadOp, too. + // + // Note that we have an assumption that the use of this loadOp has already + // be processed in a previous loop iteration. This assumption is held by + // how loadOpsToIndirectionLevelAndUse recursively collects + // loadOpToIndLevelAndUse using DFS. + if (loadToInfo.count(useOp) == 0) { + continue; + } + } + + loadToInfo[op] = loadInfo; + } +} + +LogicalResult +FourStagePipeliner::scheduleLoads(DenseSet &rootUsers) { + // Get all loads that are (transitively) used by dot ops and their distance + // to the dot op. 
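  // --- Illustrative sketch (not part of the pass) ---
  // Minimal standalone model of the indirection-level bookkeeping described
  // above: walk backwards from each dot-like op, and every load found at
  // distance d reaches the dot through d other loads. The toy string graph
  // (op -> defining ops) is hypothetical; the real pass walks MLIR operands.
  //
  // #include <functional>
  // #include <iostream>
  // #include <map>
  // #include <set>
  // #include <string>
  // #include <vector>
  //
  // int main() {
  //   std::map<std::string, std::vector<std::string>> operands = {
  //       {"dot", {"loadA", "loadB"}}, // both loads feed the dot directly
  //       {"loadB", {"loadIdx"}},      // loadB's pointer comes from loadIdx
  //   };
  //   std::set<std::string> seen;
  //   std::function<void(const std::string &, int)> dfs =
  //       [&](const std::string &op, int distance) {
  //         if (!seen.insert(op).second)
  //           return;
  //         if (op.rfind("load", 0) == 0) { // this op is a load
  //           std::cout << op << " indirection level " << distance << "\n";
  //           ++distance;
  //         }
  //         for (const auto &def : operands[op])
  //           dfs(def, distance);
  //       };
  //   dfs("dot", 0);
  //   // Prints: loadA level 0, loadB level 0, loadIdx level 1.
  // }
  // --- end sketch ---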
+ computeLoadOpsToIndirectionLevelAndUse(); + LLVM_DEBUG({ + LDBG("Found " << loadOpToIndLevelAndUse.size() << " loads to pipeline:"); + for (const auto &[l, i, u] : loadOpToIndLevelAndUse) { + LDBG(" - load: " << *l); + LDBG(" at indirection level: " << i); + LDBG(" used by op: " << *u); + } + }); + if (loadOpToIndLevelAndUse.empty()) + return failure(); + + // Check which loads are good for pipelining, and assign them memory layouts. + assignMemoryLayouts(); + if (loadToInfo.empty()) + return failure(); + + // Filter out load ops that cannot be pipelined. + int resize = 0; + for (int i = 0, e = loadOpToIndLevelAndUse.size(); i < e; ++i) { + auto [loadOp, distance, use] = loadOpToIndLevelAndUse[i]; + if (loadToInfo.count(loadOp) != 0) + loadOpToIndLevelAndUse[resize++] = loadOpToIndLevelAndUse[i]; + } + loadOpToIndLevelAndUse.resize(resize); + + // Calculate the stage distance between applicable loads. + int maxIndirectionLevel = -1; + for (auto [loadOp, dist, use] : loadOpToIndLevelAndUse) + maxIndirectionLevel = std::max(maxIndirectionLevel, dist); + + LDBG("maxIndirectionLevel = " << maxIndirectionLevel); + if (maxIndirectionLevel >= numStages) + return failure(); + + if (failed(initSchedule(maxIndirectionLevel))) + return failure(); + + // The stage gap between chained loads--this allows us to "spread" loads + // with a non-one step in case the number of stages given by the user is + // large. + assert(numStages >= 2 && "requires num_stages=2 at least"); + unsigned stagesBetweenLoads = + llvm::divideCeil(numStages - 2, maxIndirectionLevel + 1); + LDBG("stagesBetweenLoads = " << stagesBetweenLoads); + + // Assign stages to the loads. + // FA: + // Load1: Stage=0, cluster=1 + // Load2: Stage=1, cluster=3 + int i{}; + for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { + int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; + if (schedule.count(loadOp) > 0) + continue; + schedule.insert(loadOp, i, asyncCopyClusters[i == 0 ? 0 : 1]); + i++; + } + + // Put the root uses of the loads in the last stage. + for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { + // Non-LoadOp(s) are the (final) root uses of all LoadOp(s). + if (!isa(use)) { + auto loadStage = schedule[loadOp].first; + schedule.insert(use, loadStage + 2, dotClusters[loadStage == 0 ? 0 : 1]); + // scheduleOp(use, SCHED_COMPUTE); + rootUsers.insert(use); + } + } + + // Calculate distance from the load to the use. + for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) { + loadToInfo[loadOp].distToUse = schedule[use].first - schedule[loadOp].first; + } + + LLVM_DEBUG({ + LDBG("Chosen loads to pipeline:"); + for (const auto &[load, info] : loadToInfo) { + LDBG(" - load: " << *load); + LDBG(" distToUse: " << info.distToUse); + LDBG(" usedByDot: " << info.usedByDot); + } + }); + + return success(); +} + +// Add dependencies of anchor ops to the coarse schedule. Schedule them to +// the same stage and ordering cluster as the anchor op. +void FourStagePipeliner::scheduleDependencies() { + SmallVector> + opsInOrder = schedule.getOpsInOrder(forOp); + // Schedule dependencies stage by stage. 
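  // --- Illustrative sketch (not part of the pass) ---
  // Standalone model of the dependency-placement rule implemented by the
  // moveStages lambda in the loop below: a dependency inherits the anchor
  // op's stage, and normally its cluster too; when the anchor matches the
  // special stage-3 case (its exact op kind is elided in this hunk),
  // dependencies that do not themselves match are steered into the first
  // softmax cluster. Cluster ids here are examples only.
  //
  // #include <iostream>
  // #include <utility>
  //
  // static std::pair<int, int>
  // placeDependency(int anchorStage, int anchorCluster, int softmaxCluster0,
  //                 bool anchorIsSpecialStage3Kind, bool depIsSameKind) {
  //   // Dependencies of the matching kind always keep the anchor's cluster.
  //   if (depIsSameKind)
  //     return {anchorStage, anchorCluster};
  //   // Otherwise stage-3 anchors of the special kind push their
  //   // dependencies into the first softmax cluster.
  //   if (anchorIsSpecialStage3Kind)
  //     return {anchorStage, softmaxCluster0};
  //   return {anchorStage, anchorCluster};
  // }
  //
  // int main() {
  //   auto p = placeDependency(/*anchorStage=*/3, /*anchorCluster=*/5,
  //                            /*softmaxCluster0=*/1,
  //                            /*anchorIsSpecialStage3Kind=*/true,
  //                            /*depIsSameKind=*/false);
  //   std::cout << "dependency scheduled to stage " << p.first
  //             << ", cluster " << p.second << "\n"; // stage 3, cluster 1
  // }
  // --- end sketch ---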
+ for (int stage = 0; stage < numStages; ++stage) { + for (auto [op, stage_, cluster] : opsInOrder) { + if (stage_ != stage) + continue; + auto depCluster = cluster; + bool override = false; + if (llvm::isa(op) && stage == 3) { + depCluster = softmaxClusters[0]; + override = true; + } + + auto moveStages = [this, stage, cluster = cluster, + depCluster = depCluster, override](Operation *op) { + if (llvm::isa(op)) { + return std::make_pair(stage, cluster); + } + return std::make_pair(stage, depCluster); + }; + schedule.insertDepsOfOp(op, false, false, moveStages); + } + } +} + +// Find dependencies with distance of 1. They will go to the next stage, +// but in the cluster before the current op. +void FourStagePipeliner::scheduleDistanceOneDependencies() { + auto getNestedOperands = [](Operation *op) { + SmallVector operands; + op->walk([&](Operation *nestedOp) { + for (Value operand : nestedOp->getOperands()) { + if (operand.getParentBlock()->getParentOp()->isAncestor(nestedOp)) + operands.push_back(operand); + } + }); + return operands; + }; + + // Mapping from the cluster to the cluster before it. + DenseMap + dist1Cluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) + continue; + auto [stage, cluster] = schedule[&op]; + // Can't schedule past the last stage. + if (stage == numStages - 1) + continue; + for (Value operand : getNestedOperands(&op)) { + auto arg = dyn_cast(operand); + if (!arg || arg.getArgNumber() == 0 || arg.getOwner() != op.getBlock()) + continue; + auto yieldOp = op.getBlock()->getTerminator(); + Value v = yieldOp->getOperand(arg.getArgNumber() - 1); + Operation *defOp = v.getDefiningOp(); + if (!defOp || schedule.count(defOp) != 0) + continue; + if (isa(defOp)) { + // Exception: schedule loads with a distance of 1 together with the + // current op. + schedule.insertIfAbsent(defOp, stage, cluster); + schedule.insertDepsOfOp(defOp, stage, cluster, true); + } else { + if (dist1Cluster.count(&cluster) == 0) { + dist1Cluster[&cluster] = schedule.clusters.newBefore(cluster); + } + schedule.insertIfAbsent(defOp, stage + 1, dist1Cluster[&cluster]); + schedule.insertDepsOfOp(defOp, stage + 1, dist1Cluster[&cluster], true); + } + } + } +} + +void FourStagePipeliner::scheduleRemainingToLastStage() { + int lastStage = numStages - 1; + // Assign the rest of the ops to the last stage. + // Take care of the ordering of the ops - uses cannot be scheduled to the + // cluster before the definition. + auto cluster = clusters[SCHED_COMPUTE]; + DenseMap opToCluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) + opToCluster[&op] = cluster; + } + SmallVector queue; + for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) { + // We really only care about the producers from the last stage. + // Others will be scheduled before these ops anyway. + if (stage == lastStage) { + queue.push_back(op); + } + } + while (!queue.empty()) { + Operation *op = queue.pop_back_val(); + for (auto user : op->getUsers()) { + if (opToCluster.count(user)) { + tt::CoarseSchedule::Cluster userCluster = opToCluster[user]; + tt::CoarseSchedule::Cluster opCluster = schedule[op].second; + if (*userCluster < *opCluster) { + opToCluster[user] = opCluster; + queue.push_back(user); + } + } + } + } + for (auto [op, cluster] : opToCluster) { + schedule.insert(op, lastStage, cluster); + } +} + +// Create an allocation that can hold distance number of loadOp shapes. 
+Value FourStagePipeliner::createAlloc( + Operation *loadOp, ttg::SwizzledSharedEncodingAttr sharedEnc) { + OpBuilder builder(forOp); + Attribute sharedMemorySpace = + ttg::SharedMemorySpaceAttr::get(forOp.getContext()); + auto ty = cast(loadOp->getResultTypes()[0]); + SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), numBuffers); + Type memdescType = ttg::MemDescType::get(bufferShape, ty.getElementType(), + sharedEnc, sharedMemorySpace, + /*mutableMemory=*/true); + auto alloc = builder.create(loadOp->getLoc(), memdescType); + sharedMemAllocs.push_back(alloc); + return alloc; +} + +// Convert load ops into shared memory allocation loads and apply +// multi-buffering based on the required number of buffers. +void FourStagePipeliner::createStreamOps() { + SmallVector> loadToAllocs; + for (auto &[loadOp, info] : loadToInfo) { + if (!info.sharedEncoding || info.isAsync) + continue; + + Value alloc = createAlloc(loadOp, info.sharedEncoding); + assert(alloc && "Failed to create alloc for the async load."); + loadToAllocs.emplace_back(loadOp, alloc); + } + + IRRewriter builder(forOp.getContext()); + builder.setInsertionPoint(forOp); + + Location loc = forOp.getLoc(); + Value minusOne = builder.create(loc, -1, 32); + Value zero = builder.create(loc, 0, 32); + Value one = builder.create(loc, 1, 32); + Value extractIdx = minusOne; + Value numBuffersVal = + builder.create(loc, numBuffers, 32); + + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + // Patch the loop to add the new loop carried dependencies. + (void)addIterArgsToLoop(builder, forOp, {extractIdx}); + + // Create one counter for the extract indices to avoid creating long + // live range. + extractIdx = forOp.getBody()->getArgument(newOperandIndex); + + builder.setInsertionPoint(forOp.getBody(), forOp.getBody()->begin()); + extractIdx = builder.create(loc, extractIdx, one); + Value cndExt = builder.create(loc, arith::CmpIPredicate::slt, + extractIdx, numBuffersVal); + extractIdx = builder.create(loc, cndExt, extractIdx, zero); + + // Replace tt.loads with async copies or stream copies + for (auto &[op, alloc] : loadToAllocs) { + if (auto loadOp = dyn_cast(op)) { + if (useAsyncCopy && createAsyncCopy(loadOp, alloc, extractIdx)) + continue; + createStreamCopy(loadOp, alloc, extractIdx); + } + } + // Patch the yield with the updated counters. + appendToForOpYield(forOp, {extractIdx}); +} + +LogicalResult FourStagePipeliner::preprocessLoopAndBuildSchedule() { + // Schedule the loads and root ops (dot ops) in the loop. This will give us + // a scaffold for the final schedule. + DenseSet rootUsers; + if (failed(scheduleLoads(rootUsers))) + return failure(); + if (loadToInfo.empty()) + return failure(); + + LLVM_DEBUG({ + LDBG("Coarse schedule loads only:"); + schedule.dump(); + }); + + // Convert the loads into shared memory allocations and loads from them. + createStreamOps(); + LLVM_DEBUG({ + LDBG("Coarse schedule with replaced laod ops:"); + schedule.dump(); + }); + + // Schedule reductions + int c = 2; + for (auto reduceOp : forOp.getBody()->getOps()) { + schedule.insert(reduceOp, c, softmaxClusters[c == 2 ? 
1 : 0]); + c++; + } + + for (auto exp2Op : forOp.getBody()->getOps()) { + schedule.insert(exp2Op, 2, softmaxClusters[1]); + } + LLVM_DEBUG({ + LDBG("Coarse schedule after schedule reduction:"); + schedule.dump(); + }); + + scheduleDependencies(); + LLVM_DEBUG({ + LDBG("Coarse schedule with dependencies:"); + schedule.dump(); + }); + + scheduleDistanceOneDependencies(); + LLVM_DEBUG({ + LDBG("Coarse schedule with dist 1:"); + schedule.dump(); + }); + + scheduleRemainingToLastStage(); + LLVM_DEBUG({ + LDBG("Final coarse schedule:"); + schedule.dump(); + }); + + // Create the final schedule for the kernel loop. This will dictate the + // stages and order of operations to the pipeline expander. + std::vector> coarseSchedule = + schedule.createFinalSchedule(forOp); + + // Fill out the pipeline options. + options.getScheduleFn = + [coarseSchedule](scf::ForOp, + std::vector> &s) { + s = std::move(coarseSchedule); + }; + + OpBuilder builder(forOp); + builder.setInsertionPointAfter(forOp); + // Explicitly deallocate created allocations. + for (auto alloc : sharedMemAllocs) + builder.create(forOp.getLoc(), alloc); + + return success(); +} + +LogicalResult FourStagePipeliner::pipelineLoop() { + if (failed(preprocessLoopAndBuildSchedule())) + return failure(); + LDBG("Loop before sending to expander:\n" << *forOp); + + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + return tt::pipelineForLoop(rewriter, forOp, options); +} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/FourStagePipeliner.h b/third_party/amd/lib/TritonAMDGPUTransforms/FourStagePipeliner.h new file mode 100644 index 000000000000..dd01de342cdb --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUTransforms/FourStagePipeliner.h @@ -0,0 +1,168 @@ +#ifndef TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTRANSFORMS_FOURSTAGEPIPELINE_H_ +#define TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTRANSFORMS_FOURSTAGEPIPELINE_H_ + +#include "TritonAMDGPUTransforms/Passes.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Support/LLVM.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/OpInterfaces.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Software pipelining generally works by anchoring on global load ops in the +// main loop and rotating the loop to schedule global load ops for future loop +// iterations together with compute for the current iteration. In this way, we +// can 1) issue memory operations earlier to hide the latency and 2) break the +// strong dependency inside on loop iteration to give backends flexibility to +// better interleave instructions for better instruction-level parallelism. +// +// This FourStagePipeliner class creates the pipelining schedule and calls the +// PipelineExpander to rewrite the `scf.for` loop accordingly. 
A schedule +// consists of multiple stages, where ops from different stages can overlap +// executions because the dependencies are loop carried. +// +// The general flow of this process is: +// +// 1. The user provides a `num_stages` that specifies how many stages the +// pipeline will have. The number of stages must be larger than the distance +// from the first independent load to the compute in order to pipeline. +// 1.a. User may also specify `global_prefetch` to set the number of +// stages between tt.load and ttg.local_store ops. +// 1.b. User may also specify `local_prefetch` to set the number of +// stages between ttg.local_load and compute. +// 2. A schedule is created based on the distance between the global loads +// in the first stages and the compute that uses the loaded values in the +// last stage (num_stages - 1). Each operation will be clustered in the +// order to best overlap with other operations (see details below in the +// initSchedule method). +// 3. When the compute is a tt.dot, the scheduler will insert a shared +// memory allocation between the global load and tt.dot. The ttg.local_store +// will save the global load value to shared memory and the ttg.local_load +// will load the relevant tiles for the tt.dot. These operations will be +// scheduled according to various scheduling schemes outlined below in the +// initSchedule method (see details there). +// 4. Finally the schedule will be passed to the PipelineExpander to rewrite +// accordingly. The new implementation will consist of: +// a. Prologue: containing the ramp-up of num_stages-1 stages for +// iterations i=[0, num_stages-1). +// b. New loop: ordered by cluster and iterated on each operation by +// `i + (num_stages-op_stage)`. +// c. Epilogue: ramp-down of the last `num_stages-1` iterations for the +// ops in stages 1 to last_stage. This must consider that the loop +// bounds may be shorter than num_stages. In this case, the epilogue +// iterations must align with the prologue. +// +class FourStagePipeliner { + // Define categories of scheduling details per Operation type. + // The FourStagePipeliner schedules 5 types of operations: + // 1. GLOBAL_LOAD: tt.load / ttg.async_copy_global_to_local + // 2. LOCAL_STORE: ttg.local_store + // 3. LOCAL_LOAD: ttg.local_load + // 4. COMPUTE: ops that use the loaded data + // 5. ASYNC_WAIT: ttg.async_wait + // Note that ttg ops mentioned in the above list are created in this pass.
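To make items 4.a-4.c above concrete, the sketch below hand-pipelines a tiny 2-stage load/compute loop with double buffering in plain C++. The wrap-around index update mirrors the `extractIdx` arithmetic built in `createStreamOps` above (increment, compare against the buffer count, select zero on overflow); everything else (the `mem` array, the `load`/`compute` lambdas, the accumulator) is invented for the example and is not how the pass itself represents the loop.

```cpp
#include <array>
#include <cstddef>
#include <cstdio>
#include <vector>

// Hand-pipelined version of: for (i) { x = load(i); acc += compute(x); }
// with numStages == 2 and numBuffers == 2 (double buffering). The names and
// data here are made up for illustration only.
int main() {
  std::vector<int> mem = {1, 2, 3, 4, 5, 6, 7};
  auto load = [&](std::size_t i) { return mem[i]; };
  auto compute = [](int x) { return 2 * x; };

  const int numBuffers = 2;
  std::array<int, 2> buf{}; // stand-in for the multi-buffered shared-memory alloc
  int acc = 0;

  // Prologue: ramp up num_stages - 1 = 1 iteration worth of loads.
  int insertIdx = 0;
  buf[insertIdx] = load(0);

  // Kernel loop: stage 0 of iteration i + 1 (the load) overlaps with stage 1
  // of iteration i (the compute on the value loaded one iteration earlier).
  int extractIdx = 0;
  for (std::size_t i = 0; i + 1 < mem.size(); ++i) {
    insertIdx = (insertIdx + 1 == numBuffers) ? 0 : insertIdx + 1; // wrap around
    buf[insertIdx] = load(i + 1);    // stage 0 of iteration i + 1
    acc += compute(buf[extractIdx]); // stage 1 of iteration i
    extractIdx = (extractIdx + 1 == numBuffers) ? 0 : extractIdx + 1;
  }

  // Epilogue: drain the last in-flight iteration.
  acc += compute(buf[extractIdx]);

  std::printf("acc = %d\n", acc); // 2 * (1 + 2 + ... + 7) = 56
  return 0;
}
```

With more stages the prologue and epilogue simply grow to `num_stages - 1` ramp-up and ramp-down iterations, which is why short loop bounds need the special handling described in item 4.c.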
+ enum SchedType { + SCHED_GLOBAL_LOAD, + SCHED_LOCAL_STORE, + SCHED_LOCAL_LOAD, + SCHED_COMPUTE, + SCHED_ASYNC_WAIT, + SCHED_SIZE + }; + +public: + FourStagePipeliner(scf::ForOp _forOp, int _numStages, int _globalPrefetch, + int _localPrefetch, bool _useAsyncCopy); + + static bool checkPrecondition(scf::ForOp forOp, int numStages); + + LogicalResult pipelineLoop(); + +private: + LogicalResult initSchedule(int maxIndirectionLevel); + + void computeLoadOpsToIndirectionLevelAndUse(); + void assignMemoryLayouts(); + LogicalResult scheduleLoads(DenseSet &rootUsers); + void scheduleDependencies(); + void scheduleDistanceOneDependencies(); + void scheduleRemainingToLastStage(); + + LogicalResult preprocessLoopAndBuildSchedule(); + + Value createAlloc(Operation *loadOp, + triton::gpu::SwizzledSharedEncodingAttr sharedEnc); + bool createAsyncCopy(triton::LoadOp loadOp, Value alloc, Value extractIdx); + void createStreamCopy(triton::LoadOp loadOp, Value alloc, Value extractIdx); + void createStreamOps(); + + void scheduleOp(Operation *op, SchedType type, int stage = -1) { + if (stage < 0) + stage = stages[type]; + schedule.insert(op, stage, clusters[type]); + } + +private: + // Data members + scf::ForOp forOp; + + // User settings + int numStages; + + // Computed number of buffers + int numBuffers; + + // Directly store to shared memory with AsyncCopy when pipelining tt.loads + bool useAsyncCopy; + + // Stage for each SchedType Op + int stages[SCHED_SIZE]; + // (not used anymore) Cluster for each SchedType Op + std::array clusters; + + // Clusters to hold the different Ops for the 4-stage pipeliner + std::array localReadClusters; + std::array softmaxClusters; + std::array asyncCopyClusters; + std::array dotClusters; + + // Scheduling clusters + triton::CoarseSchedule schedule; + + // Mapping and indirection level for each `tt.load` to its use. + SmallVector> loadOpToIndLevelAndUse; + + struct LoadInfo { + // Shared layout is used for loads feeding into dot ops. + triton::gpu::SwizzledSharedEncodingAttr sharedEncoding = nullptr; + // The distance of this load's stage to its use' stage. + int distToUse = 0; + bool usedByDot = false; + bool isAsync = false; + }; + + // Mapping for each pipelined load to scheduling details. + llvm::MapVector loadToInfo; + + // Lookup alignment/contiguity mappings for the current module. + triton::ModuleAxisInfoAnalysis axisInfoAnalysis; + + // Capture list of new shared memory buffers. + SmallVector sharedMemAllocs; + + // Pipelining options for the PipelineExpander + triton::PipeliningOption options; +}; + +#endif diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp index a14e42d2d93f..0c3bb3e44966 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp @@ -59,27 +59,23 @@ bool isOneOperandElementwiseOp(Operation *op) { return false; } -static triton::StoreOp convertMfmaLayoutForCDNA4(PatternRewriter &rewriter, - Value ptr, Value val, - Value mask, - triton::StoreOp oldStOp) { +// Tries to optimize oldStoreOp with v_permlane*_swap instruction when possible. +// Returns null store op if not suitable. 
+static triton::StoreOp +usePermlaneSwapToOptimizeStore(PatternRewriter &rewriter, Value ptr, Value val, + Value mask, triton::StoreOp oldStoreOp) { auto ptrType = cast(ptr.getType()); auto valType = cast(val.getType()); - auto mfmaLayout = - cast(valType.getEncoding()); - // Create a new layout where each thread holds 8 consecutive elements, in // order to enable wide 128-bit global stores. - std::optional mfma8Layout = + std::optional storeLL = triton::gpu::chooseMfmaLikeStoreLayout(valType); + if (!storeLL) + return nullptr; - if (!mfma8Layout) - return rewriter.create(oldStOp.getLoc(), ptr, val, mask, - oldStOp.getCache(), - oldStOp.getEvict()); Attribute newEncoding = triton::gpu::LinearEncodingAttr::get( - mfmaLayout.getContext(), mfma8Layout.value()); + oldStoreOp.getContext(), storeLL.value()); auto newPtrType = RankedTensorType::get( ptrType.getShape(), ptrType.getElementType(), newEncoding); Value newPtr = rewriter.create(ptr.getLoc(), @@ -99,9 +95,9 @@ static triton::StoreOp convertMfmaLayoutForCDNA4(PatternRewriter &rewriter, newMaskType, mask); } - return rewriter.create(oldStOp.getLoc(), newPtr, newVal, - newMask, oldStOp.getCache(), - oldStOp.getEvict()); + return rewriter.create(oldStoreOp.getLoc(), newPtr, newVal, + newMask, oldStoreOp.getCache(), + oldStoreOp.getEvict()); } // convert(val) : xmma -> blocked @@ -195,12 +191,9 @@ class BypassEpilogueSMEM : public mlir::OpRewritePattern { newMask = rewriter.create( mask.getLoc(), newMaskType, mask); } - triton::StoreOp newStoreOp; - if (auto mfmaLayout = - dyn_cast(newEncoding)) { - newStoreOp = - convertMfmaLayoutForCDNA4(rewriter, newPtr, newVal, newMask, stOp); - } else { + triton::StoreOp newStoreOp = + usePermlaneSwapToOptimizeStore(rewriter, newPtr, newVal, newMask, stOp); + if (!newStoreOp) { newStoreOp = rewriter.create( stOp.getLoc(), newPtr, newVal, newMask, stOp.getCache(), stOp.getEvict()); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index fc9aa0a7cafa..14a35f1e61ef 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -1,3 +1,4 @@ +#include "FourStagePipeliner.h" #include "TritonAMDGPUTransforms/Passes.h" #include "mlir/Support/LLVM.h" #include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" @@ -10,6 +11,7 @@ #include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" #include "triton/Dialect/TritonGPU/Transforms/Schedule.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" @@ -119,9 +121,11 @@ class StreamPipeliner { public: StreamPipeliner(scf::ForOp _forOp, int _numStages, int _globalPrefetch, - int _localPrefetch, bool _useAsyncCopy) + int _localPrefetch, bool _useAsyncCopy, + bool _useF16BlockPingpong, bool _useAsyncCopyOverlap) : forOp(_forOp), numStages(_numStages), numBuffers(1), - useAsyncCopy(_useAsyncCopy), schedule(numStages), + useAsyncCopy(_useAsyncCopy), useF16BlockPingpong(_useF16BlockPingpong), + useAsyncCopyOverlap(_useAsyncCopyOverlap), schedule(numStages), axisInfoAnalysis(forOp->getParentOfType()) { int lastStage = numStages - 1; stages[SCHED_GLOBAL_LOAD] = 0; @@ -174,6 +178,12 @@ class StreamPipeliner { // Directly store to shared memory with AsyncCopy when pipelining tt.loads bool useAsyncCopy; + // Whether or not we are intend to ping-pong. 
+ bool useF16BlockPingpong; + + // Move AsyncCopy before AsyncWait. + bool useAsyncCopyOverlap; + // Stage for each SchedType Op int stages[SCHED_SIZE]; // Cluster for each SchedType Op @@ -219,6 +229,15 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) { bool pairedGlobalLoadLocalStore = stages[SCHED_LOCAL_STORE] == 0; stages[SCHED_LOCAL_STORE] += maxIndirectionLevel; + // In useAsyncCopy + PingPong case, we'd want to hoist out first async_wait + // out of the loop, and async_wait within the loop be towards the end. + // This is beneficial for maximizing hiding of latency, while ensuring + // 2 barriers between asyncWait and localLoad at start of loop S.T + // we do not hit race conditions between warp-lo and warp-hi. + if (useAsyncCopy && useF16BlockPingpong) { + stages[SCHED_ASYNC_WAIT] = std::max(0, stages[SCHED_LOCAL_LOAD] - 1); + } + LDBG( "Stage schedule:" << " GLOBAL_LOAD stage = " << stages[SCHED_GLOBAL_LOAD] << ", LOCAL_STORE stage = " << stages[SCHED_LOCAL_STORE] @@ -246,9 +265,9 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) { LDBG("deduced max shared memory buffer number = " << numBuffers); // We place async wait as the first cluster because we want to have it being - // the first in the main loop after pipelining. - int asyncWaitCluster = 0; - + // the first in the main loop after pipelining. However if we intend on doing + // PP then we set it near the end of the loop for reasons state above. + int asyncWaitCluster = useF16BlockPingpong ? 4 : 0; // If tt.load and ttg.local_store are in the same stage // spread them apart to allow overlap with compute // else @@ -281,6 +300,14 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) { computeCluster = localLoadCluster; } + if (useAsyncCopyOverlap) { + globalLoadCluster = 0; + localStoreCluster = 1; + asyncWaitCluster = 2; + localLoadCluster = 3; + computeCluster = 3; + } + // Make assignments std::array clusterVec; std::generate(clusterVec.begin(), clusterVec.end(), @@ -1052,6 +1079,13 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase { return signalPassFailure(); } + // TODO: Replace this with more stable argument/env, once we unify strategy + // between MXFP4 and FP16. + bool useF16BlockPingpong = + triton::tools::getBoolEnv("TRITON_HIP_ENABLE_F16_ASYNC_PINGPONG"); + bool useAsyncCopyOverlap = + triton::tools::getBoolEnv("TRITON_HIP_ASYNC_COPY_OVERLAP") & + useAsyncCopy; SmallVector loops; getOperation()->walk([&](scf::ForOp forOp) { labelLoadOpsForTritonDot(forOp); @@ -1063,12 +1097,24 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase { for (scf::ForOp forOp : loops) { if (!checkPrecondition(forOp)) continue; - StreamPipeliner sp(forOp, tt::getNumStagesOrDefault(forOp, numStages), - globalPrefetch, localPrefetch, useAsyncCopy); - (void)sp.pipelineLoop(); + + if (FourStagePipeliner::checkPrecondition(forOp, numStages)) { + FourStagePipeliner fsp(forOp, + tt::getNumStagesOrDefault(forOp, numStages), + globalPrefetch, localPrefetch, useAsyncCopy); + (void)fsp.pipelineLoop(); + } else { + StreamPipeliner sp(forOp, tt::getNumStagesOrDefault(forOp, numStages), + globalPrefetch, localPrefetch, useAsyncCopy, + useF16BlockPingpong, useAsyncCopyOverlap); + (void)sp.pipelineLoop(); + } } - if (useAsyncCopy) { + // This removes additional barrier but pingpong will add the barrier again. + // So we should just not do it to get a better vmcnt in front of each + // AsyncCopy. 
+ if (useAsyncCopy && !useF16BlockPingpong) { llvm::SmallSetVector waitOps; moduleOp.walk([&](ttg::AsyncWaitOp waitOp) { waitOps.insert(waitOp); }); tt::combineRedundantWaitOps(waitOps); diff --git a/third_party/amd/python/test/test_extract_slice.py b/third_party/amd/python/test/test_extract_slice.py deleted file mode 100644 index c52d5d3a6e5a..000000000000 --- a/third_party/amd/python/test/test_extract_slice.py +++ /dev/null @@ -1,112 +0,0 @@ -import pytest -import torch - -import triton - -from triton._internal_testing import is_hip - -num_ctas_list = [1] - -GPU_DIALECT = "ttg" - -if is_hip(): - THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size -else: - THREADS_PER_WARP = 32 - - -class BlockedLayout: - - def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): - self.sz_per_thread = size_per_thread - self.threads_per_warp = threads_per_warp - self.warps_per_cta = warps_per_cta - self.order = order - self.ctas_per_cga = ctas_per_cga - self.cta_split_num = cta_split_num - self.cta_order = cta_order - - def __str__(self): - return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" - - -# ----------------------- -# test extract slice -# ----------------------- - -extract_layout = [ - BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([2, 2], [64, 1], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([2, 2], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - # FIXME(Lezcano): This layout errors out - #BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), -] -blocked_layout = [ - BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), -] - - -@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset", - [[256, 256, 256, 32, 0, 32], [128, 128, 128, 64, 0, 64]]) -@pytest.mark.parametrize("dtype", [torch.float16]) -@pytest.mark.parametrize("extract_layout", extract_layout) -@pytest.mark.parametrize("blocked_layout", blocked_layout) -def test_extract_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, blocked_layout, - extract_layout, device='cuda'): - if not is_hip(): - pytest.skip("extract_slice is AMD specific instruction.") - - ir = f""" - #blocked = {blocked_layout} - #extract_layout = {extract_layout} - module attributes {{"ttg.num-ctas" = 1, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {str(64)} : i32}} {{ - tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ - %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> - %cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #blocked> - %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> - %42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #ttg.slice<{{dim = 
1, parent = #blocked}}>> - %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> - %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #blocked> - %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> - %43 = tt.expand_dims %42 {{axis = 1 : i32}} : tensor<{M_tile_size}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M_tile_size}x1xi32, #blocked> - %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #blocked> - %44 = arith.muli %43, %cst_n : tensor<{M_tile_size}x1xi32, #blocked> - %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{M}xi32, #blocked> - %7 = tt.broadcast %6 : tensor<1x{M}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> - %8 = tt.broadcast %5 : tensor<{M}x1xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> - %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #blocked> - %33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> - %34 = tt.splat %arg1 : !tt.ptr -> tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> - %37 = tt.expand_dims %33 {{axis = 0 : i32}} : tensor<{N_tile_size}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N_tile_size}xi32, #blocked> - %38 = tt.broadcast %37 : tensor<1x{N_tile_size}xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> - %39 = tt.broadcast %44 : tensor<{M_tile_size}x1xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> - %40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> - %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> - %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked> - %12 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #extract_layout> - %13 = amdgpu.extract_slice %12 [{M_tile_offset}, {N_tile_offset}] : tensor<{M}x{N}xf16, #extract_layout> to tensor<{M_tile_size}x{N_tile_size}xf16, #extract_layout> - %14 = ttg.convert_layout %13 : tensor<{M_tile_size}x{N_tile_size}xf16, #extract_layout> -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> - %15 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked>, tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> - tt.store %15, %14 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> - tt.return - }} - }} - """ - x = torch.randn((M, N), device=device, dtype=torch.float16) - import tempfile - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) - - extract_slice = torch.empty((M_tile_size, N_tile_size), device=device, dtype=torch.float16) - - kernel[(1, 1, 1)](x.data_ptr(), extract_slice) - test_result = torch.equal(x[M_tile_offset:M_tile_size + M_tile_offset, N_tile_offset:N_tile_offset + N_tile_size], - extract_slice) - assert test_result diff --git a/third_party/amd/python/test/test_extract_slice_concat_op.py b/third_party/amd/python/test/test_extract_slice_concat_op.py new file mode 100644 index 000000000000..b403a69ebf29 --- /dev/null +++ b/third_party/amd/python/test/test_extract_slice_concat_op.py @@ -0,0 +1,227 @@ +import pytest +import torch + +import triton + +from triton._internal_testing import is_hip + +num_ctas_list = [1] + +GPU_DIALECT = "ttg" + +if is_hip(): + THREADS_PER_WARP = 
triton.runtime.driver.active.get_current_target().warp_size +else: + THREADS_PER_WARP = 32 + + +class LinearLayout: + + def __init__(self, register, lane, warp, block): + self.register = register + self.lane = lane + self.warp = warp + self.block = block + + def __str__(self): + return f"#{GPU_DIALECT}.linear<{{register={self.register}, lane={self.lane}, warp={self.warp}, block={self.block}}}>" + + +class BlockedLayout: + + def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): + self.sz_per_thread = size_per_thread + self.threads_per_warp = threads_per_warp + self.warps_per_cta = warps_per_cta + self.order = order + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + + def __str__(self): + return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + + +# ----------------------- +# test extract slice +# ----------------------- + +extract_layout = [ + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [64, 1], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + # FIXME(Lezcano): This layout errors out + #BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), +] +blocked_layout = [ + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), +] + + +@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset", + [[256, 256, 256, 32, 0, 32], [128, 128, 128, 64, 0, 64]]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("extract_layout", extract_layout) +@pytest.mark.parametrize("blocked_layout", blocked_layout) +def test_extract_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, blocked_layout, + extract_layout, device='cuda'): + if not is_hip(): + pytest.skip("extract_slice is AMD specific instruction.") + + ir = f""" + #blocked = {blocked_layout} + #extract_layout = {extract_layout} + module attributes {{"ttg.num-ctas" = 1, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {str(64)} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> + %cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #blocked> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> + %42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #blocked> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = 
#blocked}}>> -> tensor<{M}x1xi32, #blocked> + %43 = tt.expand_dims %42 {{axis = 1 : i32}} : tensor<{M_tile_size}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M_tile_size}x1xi32, #blocked> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #blocked> + %44 = arith.muli %43, %cst_n : tensor<{M_tile_size}x1xi32, #blocked> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{M}xi32, #blocked> + %7 = tt.broadcast %6 : tensor<1x{M}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #blocked> + %33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> + %34 = tt.splat %arg1 : !tt.ptr -> tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> + %37 = tt.expand_dims %33 {{axis = 0 : i32}} : tensor<{N_tile_size}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N_tile_size}xi32, #blocked> + %38 = tt.broadcast %37 : tensor<1x{N_tile_size}xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %39 = tt.broadcast %44 : tensor<{M_tile_size}x1xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked> + %12 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #extract_layout> + %13 = amdgpu.extract_slice %12 [{M_tile_offset}, {N_tile_offset}] : tensor<{M}x{N}xf16, #extract_layout> to tensor<{M_tile_size}x{N_tile_size}xf16, #extract_layout> + %14 = ttg.convert_layout %13 : tensor<{M_tile_size}x{N_tile_size}xf16, #extract_layout> -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> + %15 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked>, tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + tt.store %15, %14 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> + tt.return + }} + }} + """ + x = torch.randn((M, N), device=device, dtype=torch.float16) + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + extract_slice = torch.empty((M_tile_size, N_tile_size), device=device, dtype=torch.float16) + + kernel[(1, 1, 1)](x.data_ptr(), extract_slice) + test_result = torch.equal(x[M_tile_offset:M_tile_size + M_tile_offset, N_tile_offset:N_tile_offset + N_tile_size], + extract_slice) + assert test_result + + +# ----------------------- +# test concat op +# ----------------------- + +src_layout = [ + LinearLayout(register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [64, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], + [16, 0], [0, 4]], warp=[[0, 32], + [32, 0]], + block=[]), + LinearLayout(register=[[1, 0], [2, 0], [4, 0]], lane=[[0, 1], [0, 2], [0, 4], [0, 8], [8, 0], [16, 0]], + warp=[[0, 16]], block=[]), +] + +dst_layout = [ + LinearLayout(register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], + lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]), + LinearLayout(register=[[1, 0], [2, 0], [4, 0], [32, 0], [0, 32]], lane=[[0, 1], [0, 2], [0, 4], [0, 8], [8, 0], + [16, 0]], warp=[[0, 16]], block=[]), +] + + 
+@pytest.mark.parametrize( + "src_layout, dst_layout, M, N, M_tile_size, N_tile_size", + [[src_layout[0], dst_layout[0], 128, 128, 256, 256], [src_layout[1], dst_layout[1], 32, 32, 64, 64]]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_concat_op(dtype, M, N, M_tile_size, N_tile_size, src_layout, dst_layout, device='cuda'): + if not is_hip(): + pytest.skip("concat op is AMD specific instruction.") + + ir = f""" + #blocked = #ttg.blocked<{{sizePerThread=[1, 8], threadsPerWarp=[16, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}}> + #src_layout = {src_layout} + #dst_layout = {dst_layout} + + module attributes {{"ttg.num-ctas" = 1, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {str(64)} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg3: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg4: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> + %cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #blocked> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> + %42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #blocked> + %100 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #blocked> + %101 = tt.splat %arg2 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #blocked> + %102 = tt.splat %arg3 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #blocked> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %43 = tt.expand_dims %42 {{axis = 1 : i32}} : tensor<{M_tile_size}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M_tile_size}x1xi32, #blocked> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #blocked> + %44 = arith.muli %43, %cst_n : tensor<{M_tile_size}x1xi32, #blocked> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{M}xi32, #blocked> + %7 = tt.broadcast %6 : tensor<1x{M}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #blocked> + %33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> + %34 = tt.splat %arg4 : !tt.ptr -> tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> + %37 = tt.expand_dims %33 {{axis = 0 : i32}} : tensor<{N_tile_size}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N_tile_size}xi32, #blocked> + %38 = tt.broadcast %37 : tensor<1x{N_tile_size}xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %39 = tt.broadcast %44 : tensor<{M_tile_size}x1xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %200 = tt.addptr %100, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %201 = tt.addptr 
%101, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %202 = tt.addptr %102, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked> + %300 = tt.load %200 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked> + %301 = tt.load %201 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked> + %302 = tt.load %202 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked> + + %12 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #src_layout> + %400 = ttg.convert_layout %300 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #src_layout> + %401 = ttg.convert_layout %301 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #src_layout> + %402 = ttg.convert_layout %302 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #src_layout> + + %13 = amdgpu.concat %12, %400, %401, %402 : tensor<{M}x{N}xf16, #src_layout>, tensor<{M}x{N}xf16, #src_layout>, tensor<{M}x{N}xf16, #src_layout>, tensor<{M}x{N}xf16, #src_layout> -> tensor<{M_tile_size}x{N_tile_size}xf16, #dst_layout> + %14 = ttg.convert_layout %13 : tensor<{M_tile_size}x{N_tile_size}xf16, #dst_layout> -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> + %15 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked>, tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + tt.store %15, %14 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> + tt.return + }} + }} + """ + x1 = torch.randn((M, N), device=device, dtype=torch.float16) + x2 = torch.randn((M, N), device=device, dtype=torch.float16) + x3 = torch.randn((M, N), device=device, dtype=torch.float16) + x4 = torch.randn((M, N), device=device, dtype=torch.float16) + + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + concat = torch.empty((M_tile_size, N_tile_size), device=device, dtype=torch.float16) + kernel[(1, 1, 1)](x1.data_ptr(), x2.data_ptr(), x3.data_ptr(), x4.data_ptr(), concat) + + top = torch.cat([x1, x2], dim=1) + bottom = torch.cat([x3, x4], dim=1) + result = torch.cat([top, bottom], dim=0) + + test_result = torch.equal(result, concat) + assert test_result diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index a4095a27ae75..24cabe6bb70a 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -76,8 +76,8 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { mlir::createTritonAMDGPUReorderInstructionsPass); ADD_PASS_WRAPPER_0("add_fold_true_cmpi", mlir::createTritonAMDGPUFoldTrueCmpIPass); - ADD_PASS_WRAPPER_1("add_block_pingpong", - mlir::createTritonAMDGPUBlockPingpongPass, int32_t); + ADD_PASS_WRAPPER_2("add_block_pingpong", + mlir::createTritonAMDGPUBlockPingpongPass, int32_t, bool); ADD_PASS_WRAPPER_4("add_stream_pipeline", mlir::createTritonAMDGPUStreamPipelinePass, int, int, int, bool); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp index 6b9659d9d4c2..577db1c0b543 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp @@ -241,8 +241,8 @@ LogicalResult lowerDistributedToSharedStmatrix( for (int i = 0; i < 
srcVals.size(); i += step) { auto regIdx = reps.apply({{kReg, i}, {kLane, 0}, {kWarp, 0}})[0].second; Value offset = b.xor_(regBase, b.i32_val(regIdx)); - auto vecAddr = b.gep(smemPtrTy, llvmElemTy, smemBase, offset, - LLVM::GEPNoWrapFlags::inbounds); + auto vecAddr = b.gep(smemPtrTy, llvmElemTy, smemBase, offset); + vecAddr.setInbounds(true); SmallVector inValsVec; for (int j = 0; j < step; j++) inValsVec.push_back(srcVals[i + j]); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index e3b5ef77b7cf..a4673738dc67 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -230,7 +230,7 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, SmallVector vals = unpackLLVector(loc, val, rewriter); for (int i = 0; i < vec / maxVec; i++) { auto newPtr = b.gep(ptr.getType(), elemTy, ptr, b.i32_val(i * maxVec), - LLVM::GEPNoWrapFlags::inbounds); + /*inbounds=*/true); storeDShared( rewriter, loc, newPtr, ctaId, packLLVector(loc, ArrayRef(vals).slice(i * maxVec, maxVec), rewriter), @@ -343,7 +343,7 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, SmallVector vals; for (int i = 0; i < vec / maxVec; i++) { auto newPtr = b.gep(ptr.getType(), elemTy, ptr, b.i32_val(i * maxVec), - LLVM::GEPNoWrapFlags::inbounds); + /*inbounds=*/true); auto newVal = loadDShared(rewriter, loc, newPtr, ctaId, vec_ty(elemTy, maxVec), pred); for (Value v : unpackLLVector(loc, newVal, rewriter)) {