|
| 1 | +import torch |
| 2 | +import triton |
| 3 | +import triton.language as tl |
| 4 | + |
| 5 | +from liger_kernel.ops.utils import ensure_contiguous |
| 6 | +from liger_kernel.ops.utils import get_npu_core_count |
| 7 | + |
| 8 | + |
@triton.jit
def _softmax_multi_block_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    n_rows,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Multi-block softmax forward kernel using two-pass algorithm.

    First pass computes max and sum for numerical stability.
    Second pass normalizes and writes output.

    Rows wider than BLOCK_SIZE are handled by tiling the columns in the
    inner loops; the max/sum of the first pass is maintained with the
    "online softmax" rescaling update, so it is exact regardless of how
    many column tiles a row spans.

    Args:
        Y_ptr: Output tensor pointer
        Y_row_stride: Stride for output rows
        X_ptr: Input tensor pointer
        X_row_stride: Stride for input rows
        n_rows: Number of rows to process
        n_cols: Number of columns per row
        BLOCK_SIZE: Block size for column processing
    """
    # Grid-stride loop over rows: program i handles rows i, i+num_programs, ...
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)

    for row_idx in tl.range(row_start, n_rows, row_step):
        row_start_ptr = X_ptr + row_idx * X_row_stride
        col_offsets = tl.arange(0, BLOCK_SIZE)
        # Per-row running statistics, reset for every row:
        # m = running max, d = running sum of exp(x - m).
        m = float("-inf")
        d = 0.0

        # Pass 1: online max/sum. When a tile raises the max from m to new_m,
        # the previously accumulated d is rescaled by exp(m - new_m).
        for start in tl.range(0, n_cols, BLOCK_SIZE):
            idx = start + col_offsets
            mask = idx < n_cols
            # other=-inf so masked lanes contribute exp(-inf)=0 to the sum.
            # NOTE(review): eviction_policy/cache_modifier look NPU-tuned —
            # confirm they are honored by the target Triton backend.
            xblk = tl.load(
                row_start_ptr + idx, mask=mask, other=float("-inf"), eviction_policy="evict_first", cache_modifier=".ca"
            )
            blk_max = tl.max(xblk, axis=0)
            new_m = tl.maximum(m, blk_max)
            d = d * tl.exp(m - new_m) + tl.sum(tl.exp(xblk - new_m), axis=0)
            m = new_m

        # Pass 2: reload each tile and write the normalized probabilities.
        for start in tl.range(0, n_cols, BLOCK_SIZE):
            idx = start + col_offsets
            mask = idx < n_cols
            xblk = tl.load(
                row_start_ptr + idx, mask=mask, other=float("-inf"), eviction_policy="evict_first", cache_modifier=".ca"
            )
            yblk = tl.exp(xblk - m) / d
            tl.store(Y_ptr + row_idx * Y_row_stride + idx, yblk, mask=mask, cache_modifier=".cs")
| 62 | + |
| 63 | + |
@triton.jit
def _softmax_multi_block_backward_kernel(
    dy_ptr,
    dy_stride,
    y_ptr,
    y_stride,
    dx_ptr,
    dx_stride,
    n_rows,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Multi-block softmax backward kernel using two-pass algorithm.

    Computes gradient: dx = y * (dy - sum(dy * y))

    First pass accumulates the per-row dot product sum(dy * y); second pass
    reloads each tile and writes the gradient.

    Args:
        dy_ptr: Gradient output pointer
        dy_stride: Stride for gradient output rows
        y_ptr: Forward output pointer
        y_stride: Stride for forward output rows
        dx_ptr: Gradient input pointer
        dx_stride: Stride for gradient input rows
        n_rows: Number of rows to process
        n_cols: Number of columns per row
        BLOCK_SIZE: Block size for column processing
    """
    # Grid-stride loop over rows, matching the forward kernel.
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)

    for row_idx in tl.range(row_start, n_rows, row_step):
        dy_start_ptr = dy_ptr + row_idx * dy_stride
        y_start_ptr = y_ptr + row_idx * y_stride
        # BUGFIX: the accumulator must be reset for every row. Previously it
        # was initialized once before the row loop, so sum(dy * y) leaked
        # across the rows a program processed and every row after the first
        # received a wrong gradient.
        acc = 0.0

        # Pass 1: accumulate the row's dot product sum(dy * y).
        for start in tl.range(0, n_cols, BLOCK_SIZE):
            idx = start + col_offsets
            mask = idx < n_cols
            # other=0.0 so masked lanes contribute nothing to the sum.
            dy_blk = tl.load(dy_start_ptr + idx, mask=mask, other=0.0, eviction_policy="evict_first")
            y_blk = tl.load(
                y_start_ptr + idx, mask=mask, other=0.0, eviction_policy="evict_first", cache_modifier=".ca"
            )
            acc += tl.sum(dy_blk * y_blk, axis=0)

        # Pass 2: dx = y * (dy - acc), tile by tile.
        for start in tl.range(0, n_cols, BLOCK_SIZE):
            idx = start + col_offsets
            mask = idx < n_cols
            dy_blk = tl.load(dy_start_ptr + idx, mask=mask, other=0.0)
            y_blk = tl.load(y_start_ptr + idx, mask=mask, other=0.0, cache_modifier=".ca")
            dx_blk = y_blk * (dy_blk - acc)
            tl.store(dx_ptr + row_idx * dx_stride + idx, dx_blk, mask=mask, cache_modifier=".wb")
| 117 | + |
| 118 | + |
def softmax_forward(x):
    """
    Launch the multi-block softmax forward kernel over the last dimension.

    Args:
        x: Input tensor of arbitrary leading shape; softmax is taken along
            the final dimension.

    Returns:
        Tuple of (softmax output with the same shape as ``x``, the
        BLOCK_SIZE used, so backward can reuse it).
    """
    *leading_dims, n_cols = x.shape
    flat_in = x.contiguous().view(-1, n_cols)
    n_rows = flat_in.shape[0]

    # Column tile: next power of two, capped so that very wide rows are
    # processed in several blocks by the kernel's inner loops.
    MAX_FUSED_BLOCK_SIZE = 8192
    BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_FUSED_BLOCK_SIZE)

    flat_out = torch.empty_like(flat_in)
    # One program per core, but never more programs than there are rows.
    num_programs = min(get_npu_core_count(), n_rows)

    _softmax_multi_block_forward_kernel[(num_programs,)](
        flat_out,
        flat_out.stride(0),
        flat_in,
        flat_in.stride(0),
        n_rows,
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return flat_out.view(*leading_dims, n_cols), BLOCK_SIZE
| 137 | + |
| 138 | + |
def softmax_backward(
    dy: torch.Tensor,
    y: torch.Tensor,
    BLOCK_SIZE: int,
) -> torch.Tensor:
    """
    Launch the multi-block softmax backward kernel.

    Args:
        dy: Upstream gradient, same shape as the forward output.
        y: Softmax output saved from the forward pass.
        BLOCK_SIZE: Column tile size chosen by :func:`softmax_forward`.

    Returns:
        Gradient with respect to the softmax input, same shape as ``dy``.
    """
    *leading_dims, n_cols = dy.shape
    flat_dy = dy.contiguous().view(-1, n_cols)
    flat_y = y.contiguous().view(-1, n_cols)
    n_rows = flat_dy.shape[0]
    flat_dx = torch.empty_like(flat_dy)

    # Mirror the forward launch: one program per core, capped at n_rows.
    num_programs = min(get_npu_core_count(), n_rows)

    _softmax_multi_block_backward_kernel[(num_programs,)](
        flat_dy,
        flat_dy.stride(0),
        flat_y,
        flat_y.stride(0),
        flat_dx,
        flat_dx.stride(0),
        n_rows,
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return flat_dx.view(*leading_dims, n_cols)
| 166 | + |
| 167 | + |
class LigerSoftmaxFunction(torch.autograd.Function):
    """Autograd wrapper around the multi-block softmax Triton kernels."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, input_: torch.Tensor):
        """Compute softmax over the last dim and stash what backward needs."""
        probs, block_size = softmax_forward(input_)
        # Backward needs the probabilities themselves plus the tile size
        # the forward launch selected.
        ctx.save_for_backward(probs)
        ctx.BLOCK_SIZE = block_size
        return probs

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        """Return grad wrt the input, computed from the saved softmax output."""
        (probs,) = ctx.saved_tensors
        return softmax_backward(grad_output, probs, ctx.BLOCK_SIZE)
0 commit comments