Skip to content

Commit 57e98d3

Browse files
authored
[NPU]: Add NPU support for the embedding (#1028)
## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> Add NPU support for the embedding. - Implements a flattened, grid-stride Triton kernel for embedding forward/backward to improve scalability and reduce launch overhead on Ascend NPUs. - Uses UB-aware tiling (compute_default_tiling_strategy) and NPU vector core count to dynamically select block size and grid size for better performance stability. ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> I tested the embedding with the following methods and all cases passed: - `python benchmark/scripts/benchmark_embedding.py` - `pytest -v test/transformers/test_embedding.py` <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: Ascend NPU 910B4 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence
1 parent 393efae commit 57e98d3

File tree

2 files changed

+220
-0
lines changed

2 files changed

+220
-0
lines changed

src/liger_kernel/ops/backends/_ascend/ops/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
If __all__ is not defined, all public symbols will be auto-discovered.
1515
"""
1616

17+
from liger_kernel.ops.backends._ascend.ops.embedding import LigerEmbeddingFunction
18+
from liger_kernel.ops.backends._ascend.ops.embedding import embedding_backward
19+
from liger_kernel.ops.backends._ascend.ops.embedding import embedding_forward
1720
from liger_kernel.ops.backends._ascend.ops.geglu import LigerGELUMulFunction
1821
from liger_kernel.ops.backends._ascend.ops.geglu import geglu_backward
1922
from liger_kernel.ops.backends._ascend.ops.geglu import geglu_forward
@@ -31,6 +34,9 @@
3134
from liger_kernel.ops.backends._ascend.ops.tvd import tvd_backward_triton
3235

3336
__all__ = [
37+
"LigerEmbeddingFunction",
38+
"embedding_forward",
39+
"embedding_backward",
3440
"LigerGELUMulFunction",
3541
"geglu_forward",
3642
"geglu_backward",
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
import torch
2+
import triton
3+
import triton.language as tl
4+
5+
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
6+
from liger_kernel.ops.utils import ensure_contiguous
7+
from liger_kernel.ops.utils import get_npu_core_count
8+
9+
10+
@triton.jit
def embedding_forward_kernel(
    embeddings_ptr,  # pointer to the (num_embeddings, embedding_dim) weight table
    indices_ptr,  # pointer to the flat index vector of length n_elements
    output_ptr,  # pointer to the (n_elements, embedding_dim) output buffer
    n_elements,  # number of indices to look up
    embedding_dim: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,  # tile height: indices handled per tile
    BLOCK_SIZE_N: tl.constexpr,  # tile width: embedding columns per tile
    NUM_STAGES: tl.constexpr,  # pipelining depth passed to tl.range
):
    """Gather embedding rows: output[m, :] = embeddings[indices[m], :].

    The (n_elements, embedding_dim) output space is split into
    BLOCK_SIZE_M x BLOCK_SIZE_N tiles; each program walks the flattened
    tile list with a grid-stride loop, so any launch grid size covers
    all tiles. Offsets are computed as `row * embedding_dim + col`, so
    both `embeddings` and `output` are assumed contiguous row-major
    with row stride == embedding_dim.
    """
    pid = tl.program_id(0)
    num_progs = tl.num_programs(0)

    grid_m = tl.cdiv(n_elements, BLOCK_SIZE_M)
    grid_n = tl.cdiv(embedding_dim, BLOCK_SIZE_N)
    total_2d_blocks = grid_m * grid_n

    # Grid-stride loop over the flattened 2-D tile index space.
    for block_idx in tl.range(pid, total_2d_blocks, num_progs, num_stages=NUM_STAGES):
        # Recover the 2-D tile coordinates from the flat tile index.
        block_m = block_idx // grid_n
        block_n = block_idx % grid_n

        start_m = block_m * BLOCK_SIZE_M
        start_n = block_n * BLOCK_SIZE_N

        offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
        mask_m = offsets_m < n_elements

        # Out-of-range lanes read index 0 (`other=0`); their results are
        # discarded by `block_mask` below, so this is safe.
        indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)

        offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
        mask_n = offsets_n < embedding_dim

        # 2-D mask guarding partial tiles on both the row and column edges.
        block_mask = mask_m[:, None] & mask_n[None, :]

        embedding_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
        embeddings = tl.load(
            embeddings_ptr + embedding_offsets,
            mask=block_mask,
            other=0.0,
        )

        output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
        tl.store(
            output_ptr + output_offsets,
            embeddings,
            mask=block_mask,
        )
58+
59+
60+
@triton.jit
def embedding_backward_kernel(
    grad_output_ptr,  # pointer to the (n_elements, embedding_dim) upstream gradient
    grad_weight_ptr,  # pointer to the zero-initialized weight gradient table
    indices_ptr,  # pointer to the flat index vector of length n_elements
    n_elements,  # number of indices that were looked up in forward
    embedding_dim: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,  # tile height: indices handled per tile
    BLOCK_SIZE_N: tl.constexpr,  # tile width: embedding columns per tile
    NUM_STAGES: tl.constexpr,  # pipelining depth passed to tl.range
):
    """Scatter-accumulate: grad_weight[indices[m], :] += grad_output[m, :].

    Mirrors the forward kernel's tiling: the (n_elements, embedding_dim)
    gradient is split into BLOCK_SIZE_M x BLOCK_SIZE_N tiles walked with
    a grid-stride loop. `tl.atomic_add` is used for the accumulation
    because the same index value may appear in multiple rows (and in
    tiles processed by different programs).
    """
    pid = tl.program_id(0)
    num_progs = tl.num_programs(0)

    grid_m = tl.cdiv(n_elements, BLOCK_SIZE_M)
    grid_n = tl.cdiv(embedding_dim, BLOCK_SIZE_N)
    total_2d_blocks = grid_m * grid_n

    # Grid-stride loop over the flattened 2-D tile index space.
    for block_idx in tl.range(pid, total_2d_blocks, num_progs, num_stages=NUM_STAGES):
        # Recover the 2-D tile coordinates from the flat tile index.
        block_m = block_idx // grid_n
        block_n = block_idx % grid_n

        start_m = block_m * BLOCK_SIZE_M
        start_n = block_n * BLOCK_SIZE_N

        offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
        mask_m = offsets_m < n_elements

        # Out-of-range lanes read index 0 (`other=0`); their updates are
        # suppressed by `block_mask` below.
        indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)

        offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
        mask_n = offsets_n < embedding_dim

        # 2-D mask guarding partial tiles on both the row and column edges.
        block_mask = mask_m[:, None] & mask_n[None, :]

        grad_output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
        grad_output = tl.load(
            grad_output_ptr + grad_output_offsets,
            mask=block_mask,
            other=0.0,
        )

        grad_weight_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
        tl.atomic_add(
            grad_weight_ptr + grad_weight_offsets,
            grad_output,
            mask=block_mask,
        )
108+
109+
110+
def get_optimal_block_size(total_elements, dtype_size, BLOCK_SIZE_N):
    """Pick the tile height (BLOCK_SIZE_M) for a BLOCK_SIZE_M x BLOCK_SIZE_N tile.

    Note: the original signature annotated BLOCK_SIZE_N as `tl.constexpr`,
    but that marker is only meaningful on Triton kernel parameters; this is
    a plain host-side function, so the annotation was dropped.

    Args:
        total_elements: number of index rows the kernel will process.
        dtype_size: element size in bytes of the embedding dtype.
        BLOCK_SIZE_N: already-chosen tile width along the embedding dim.

    Returns:
        The block size along dim 0 to use as BLOCK_SIZE_M.
    """
    # 1. Set memory multiplier.
    # 3.0 is an empirical value based on the 910B UB (192KB):
    # the input tile and the output tile each occupy
    # BLOCK_SIZE_M * BLOCK_SIZE_N elements (2x total), plus one more unit
    # reserved for the remaining one-dimensional UB buffers, giving a
    # conservative total of 3 * BLOCK_SIZE_M * BLOCK_SIZE_N.
    multiplier = 3.0

    # 2. Call the calculation function.
    # Treat the input as 1-D (total_elements,), tiling only on dim 0.
    tile_shapes = compute_default_tiling_strategy(
        safety_margin=0.9,
        dtype_size=dtype_size,
        memory_multiplier=multiplier,
        shapes=((total_elements, BLOCK_SIZE_N),),
        tiling_dims=(0,),
    )

    # 3. Parse the result; an empty/None result falls back to a small
    # power-of-2 block so the kernel can still launch.
    if tile_shapes:
        return tile_shapes[0][0]
    return triton.next_power_of_2(min(128, total_elements))
134+
135+
136+
def embedding_forward(embeddings, indices):
    """Look up rows of `embeddings` for every entry of `indices` on the NPU.

    Args:
        embeddings: (num_embeddings, embedding_dim) weight table.
        indices: integer tensor of arbitrary shape.

    Returns:
        Tensor shaped like `indices` with a trailing embedding_dim axis.
    """
    original_shape = indices.shape
    flat_indices = indices.view(-1)

    num_indices = flat_indices.numel()
    embedding_dim = embeddings.shape[1]
    output = torch.empty(
        num_indices,
        embedding_dim,
        device=flat_indices.device,
        dtype=embeddings.dtype,
    )

    # The 2-D tiling couples the UB footprints of block_m and block_n.
    # Since embedding_dim is usually the smaller axis, fix block_n first
    # and then grow block_m as far as the UB budget allows.
    block_n = triton.next_power_of_2(min(128, embedding_dim))
    block_m = get_optimal_block_size(num_indices, embeddings.element_size(), block_n)

    # Launch at most one program per vector core; the grid-stride loop in
    # the kernel covers any remaining tiles.
    total_tiles = triton.cdiv(num_indices, block_m) * triton.cdiv(embedding_dim, block_n)
    grid_size = min(get_npu_core_count(), total_tiles)

    embedding_forward_kernel[(grid_size,)](
        embeddings,
        flat_indices,
        output,
        num_indices,
        embedding_dim=embedding_dim,
        BLOCK_SIZE_M=block_m,
        BLOCK_SIZE_N=block_n,
        NUM_STAGES=3,
    )

    return output.view(*original_shape, -1)
171+
172+
173+
def embedding_backward(embeddings, indices, grad_output):
    """Compute the gradient of the embedding weight for a lookup.

    Args:
        embeddings: (num_embeddings, embedding_dim) weight table (for
            shape/dtype of the gradient; values are not read).
        indices: integer tensor of the indices used in the forward pass.
        grad_output: upstream gradient, reshaped to (-1, embedding_dim).

    Returns:
        grad_weight with the same shape as `embeddings`.
    """
    embedding_dim = embeddings.shape[1]
    grad_output = grad_output.contiguous().view(-1, embedding_dim)

    # Zero-initialized so the kernel can accumulate with atomic adds.
    grad_weight = torch.zeros_like(embeddings)

    num_indices = indices.numel()
    # Same tiling policy as the forward pass: fix block_n first, then
    # grow block_m to fill the UB budget.
    block_n = triton.next_power_of_2(min(128, embedding_dim))
    block_m = get_optimal_block_size(num_indices, embeddings.element_size(), block_n)

    # One program per vector core at most; grid-stride covers the rest.
    total_tiles = triton.cdiv(num_indices, block_m) * triton.cdiv(embedding_dim, block_n)
    grid_size = min(get_npu_core_count(), total_tiles)

    embedding_backward_kernel[(grid_size,)](
        grad_output,
        grad_weight,
        indices,
        num_indices,
        embedding_dim=embedding_dim,
        BLOCK_SIZE_M=block_m,
        BLOCK_SIZE_N=block_n,
        NUM_STAGES=3,
    )

    return grad_weight
198+
199+
200+
class LigerEmbeddingFunction(torch.autograd.Function):
    """Autograd wrapper around the NPU embedding forward/backward kernels."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, embeddings: torch.Tensor, indices: torch.Tensor):
        """Gather embedding rows and stash what the backward pass needs."""
        ctx.save_for_backward(indices, embeddings)
        return embedding_forward(embeddings, indices)

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor):
        """Return (grad for embeddings, None) — indices get no gradient."""
        indices, embeddings = ctx.saved_tensors
        return embedding_backward(embeddings, indices, grad_output), None

0 commit comments

Comments
 (0)