|
| 1 | +import torch |
| 2 | +import triton |
| 3 | +import triton.language as tl |
| 4 | + |
| 5 | +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy |
| 6 | +from liger_kernel.ops.utils import ensure_contiguous |
| 7 | +from liger_kernel.ops.utils import get_npu_core_count |
| 8 | + |
| 9 | + |
| 10 | +@triton.jit |
| 11 | +def embedding_forward_kernel( |
| 12 | + embeddings_ptr, |
| 13 | + indices_ptr, |
| 14 | + output_ptr, |
| 15 | + n_elements, |
| 16 | + embedding_dim: tl.constexpr, |
| 17 | + BLOCK_SIZE_M: tl.constexpr, |
| 18 | + BLOCK_SIZE_N: tl.constexpr, |
| 19 | + NUM_STAGES: tl.constexpr, |
| 20 | +): |
| 21 | + pid = tl.program_id(0) |
| 22 | + num_progs = tl.num_programs(0) |
| 23 | + |
| 24 | + grid_m = tl.cdiv(n_elements, BLOCK_SIZE_M) |
| 25 | + grid_n = tl.cdiv(embedding_dim, BLOCK_SIZE_N) |
| 26 | + total_2d_blocks = grid_m * grid_n |
| 27 | + |
| 28 | + for block_idx in tl.range(pid, total_2d_blocks, num_progs, num_stages=NUM_STAGES): |
| 29 | + block_m = block_idx // grid_n |
| 30 | + block_n = block_idx % grid_n |
| 31 | + |
| 32 | + start_m = block_m * BLOCK_SIZE_M |
| 33 | + start_n = block_n * BLOCK_SIZE_N |
| 34 | + |
| 35 | + offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M) |
| 36 | + mask_m = offsets_m < n_elements |
| 37 | + |
| 38 | + indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0) |
| 39 | + |
| 40 | + offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N) |
| 41 | + mask_n = offsets_n < embedding_dim |
| 42 | + |
| 43 | + block_mask = mask_m[:, None] & mask_n[None, :] |
| 44 | + |
| 45 | + embedding_offsets = indices[:, None] * embedding_dim + offsets_n[None, :] |
| 46 | + embeddings = tl.load( |
| 47 | + embeddings_ptr + embedding_offsets, |
| 48 | + mask=block_mask, |
| 49 | + other=0.0, |
| 50 | + ) |
| 51 | + |
| 52 | + output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :] |
| 53 | + tl.store( |
| 54 | + output_ptr + output_offsets, |
| 55 | + embeddings, |
| 56 | + mask=block_mask, |
| 57 | + ) |
| 58 | + |
| 59 | + |
| 60 | +@triton.jit |
| 61 | +def embedding_backward_kernel( |
| 62 | + grad_output_ptr, |
| 63 | + grad_weight_ptr, |
| 64 | + indices_ptr, |
| 65 | + n_elements, |
| 66 | + embedding_dim: tl.constexpr, |
| 67 | + BLOCK_SIZE_M: tl.constexpr, |
| 68 | + BLOCK_SIZE_N: tl.constexpr, |
| 69 | + NUM_STAGES: tl.constexpr, |
| 70 | +): |
| 71 | + pid = tl.program_id(0) |
| 72 | + num_progs = tl.num_programs(0) |
| 73 | + |
| 74 | + grid_m = tl.cdiv(n_elements, BLOCK_SIZE_M) |
| 75 | + grid_n = tl.cdiv(embedding_dim, BLOCK_SIZE_N) |
| 76 | + total_2d_blocks = grid_m * grid_n |
| 77 | + |
| 78 | + for block_idx in tl.range(pid, total_2d_blocks, num_progs, num_stages=NUM_STAGES): |
| 79 | + block_m = block_idx // grid_n |
| 80 | + block_n = block_idx % grid_n |
| 81 | + |
| 82 | + start_m = block_m * BLOCK_SIZE_M |
| 83 | + start_n = block_n * BLOCK_SIZE_N |
| 84 | + |
| 85 | + offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M) |
| 86 | + mask_m = offsets_m < n_elements |
| 87 | + |
| 88 | + indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0) |
| 89 | + |
| 90 | + offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N) |
| 91 | + mask_n = offsets_n < embedding_dim |
| 92 | + |
| 93 | + block_mask = mask_m[:, None] & mask_n[None, :] |
| 94 | + |
| 95 | + grad_output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :] |
| 96 | + grad_output = tl.load( |
| 97 | + grad_output_ptr + grad_output_offsets, |
| 98 | + mask=block_mask, |
| 99 | + other=0.0, |
| 100 | + ) |
| 101 | + |
| 102 | + grad_weight_offsets = indices[:, None] * embedding_dim + offsets_n[None, :] |
| 103 | + tl.atomic_add( |
| 104 | + grad_weight_ptr + grad_weight_offsets, |
| 105 | + grad_output, |
| 106 | + mask=block_mask, |
| 107 | + ) |
| 108 | + |
| 109 | + |
| 110 | +def get_optimal_block_size(total_elements, dtype_size, BLOCK_SIZE_N: tl.constexpr): |
| 111 | + # 1. Set Memory Multiplier |
| 112 | + # 3.0 are empirical values based on 910B UB (192KB) |
| 113 | + # embedding_offsets, embedding_offsets : BLOCK_SIZE_N * BLOCK_SIZE_M (total 2 * BLOCK_SIZE_N * BLOCK_SIZE_M) |
| 114 | + # Reserve a unit of space for the remaining one-dimensional ub to occupy. |
| 115 | + # A conservative estimate of the total space occupation is 3 * BLOCK_SIZE_N * BLOCK_SIZE_M |
| 116 | + multiplier = 3.0 |
| 117 | + |
| 118 | + # 2. Call calculation function |
| 119 | + # Treat input as 1D (total_elements,), only tiling on dim 0 |
| 120 | + tile_shapes = compute_default_tiling_strategy( |
| 121 | + safety_margin=0.9, |
| 122 | + dtype_size=dtype_size, |
| 123 | + memory_multiplier=multiplier, |
| 124 | + shapes=((total_elements, BLOCK_SIZE_N),), |
| 125 | + tiling_dims=(0,), |
| 126 | + ) |
| 127 | + |
| 128 | + # 3. Parse result |
| 129 | + if tile_shapes and len(tile_shapes) > 0: |
| 130 | + block_size = tile_shapes[0][0] |
| 131 | + return block_size |
| 132 | + else: |
| 133 | + return triton.next_power_of_2(min(128, total_elements)) |
| 134 | + |
| 135 | + |
| 136 | +def embedding_forward(embeddings, indices): |
| 137 | + ori_shape = indices.shape |
| 138 | + indices = indices.view(-1) |
| 139 | + |
| 140 | + n_elements = indices.numel() |
| 141 | + embedding_dim = embeddings.shape[1] |
| 142 | + output = torch.empty( |
| 143 | + indices.shape[0], |
| 144 | + embeddings.shape[1], |
| 145 | + device=indices.device, |
| 146 | + dtype=embeddings.dtype, |
| 147 | + ) |
| 148 | + |
| 149 | + # Due to the involvement of two-dimensional partitioning, |
| 150 | + # the sizes of block_m and block_n in the ub space will influence each other. |
| 151 | + # Considering that embedding_dim is usually relatively smaller in most cases, |
| 152 | + # a value is first assigned to block_n, and then the largest possible block_m is used. |
| 153 | + BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim)) |
| 154 | + BLOCK_SIZE_M = get_optimal_block_size(n_elements, embeddings.element_size(), BLOCK_SIZE_N) |
| 155 | + num_cores = get_npu_core_count() |
| 156 | + total_blocks = triton.cdiv(n_elements, BLOCK_SIZE_M) * triton.cdiv(embedding_dim, BLOCK_SIZE_N) |
| 157 | + grid = min(num_cores, total_blocks) |
| 158 | + |
| 159 | + embedding_forward_kernel[(grid,)]( |
| 160 | + embeddings, |
| 161 | + indices, |
| 162 | + output, |
| 163 | + n_elements, |
| 164 | + embedding_dim=embedding_dim, |
| 165 | + BLOCK_SIZE_M=BLOCK_SIZE_M, |
| 166 | + BLOCK_SIZE_N=BLOCK_SIZE_N, |
| 167 | + NUM_STAGES=3, |
| 168 | + ) |
| 169 | + |
| 170 | + return output.view(*ori_shape, -1) |
| 171 | + |
| 172 | + |
| 173 | +def embedding_backward(embeddings, indices, grad_output): |
| 174 | + grad_output = grad_output.contiguous().view(-1, embeddings.shape[1]) |
| 175 | + |
| 176 | + grad_weight = torch.zeros_like(embeddings) |
| 177 | + |
| 178 | + n_elements = indices.numel() |
| 179 | + embedding_dim = embeddings.shape[1] |
| 180 | + BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim)) |
| 181 | + BLOCK_SIZE_M = get_optimal_block_size(n_elements, embeddings.element_size(), BLOCK_SIZE_N) |
| 182 | + num_cores = get_npu_core_count() |
| 183 | + total_blocks = triton.cdiv(n_elements, BLOCK_SIZE_M) * triton.cdiv(embedding_dim, BLOCK_SIZE_N) |
| 184 | + grid = min(num_cores, total_blocks) |
| 185 | + |
| 186 | + embedding_backward_kernel[(grid,)]( |
| 187 | + grad_output, |
| 188 | + grad_weight, |
| 189 | + indices, |
| 190 | + n_elements, |
| 191 | + embedding_dim=embedding_dim, |
| 192 | + BLOCK_SIZE_M=BLOCK_SIZE_M, |
| 193 | + BLOCK_SIZE_N=BLOCK_SIZE_N, |
| 194 | + NUM_STAGES=3, |
| 195 | + ) |
| 196 | + |
| 197 | + return grad_weight |
| 198 | + |
| 199 | + |
| 200 | +class LigerEmbeddingFunction(torch.autograd.Function): |
| 201 | + @staticmethod |
| 202 | + @ensure_contiguous |
| 203 | + def forward(ctx, embeddings: torch.Tensor, indices: torch.Tensor): |
| 204 | + output = embedding_forward(embeddings, indices) |
| 205 | + ctx.save_for_backward(indices, embeddings) |
| 206 | + return output |
| 207 | + |
| 208 | + @staticmethod |
| 209 | + @ensure_contiguous |
| 210 | + def backward(ctx, grad_output: torch.Tensor): |
| 211 | + indices, embeddings = ctx.saved_tensors |
| 212 | + grad_weight = embedding_backward(embeddings, indices, grad_output) |
| 213 | + |
| 214 | + return grad_weight, None |
0 commit comments