import math
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import torch_xla.core.xla_model as xm

@nki.jit
def nki_softmax_kernel(a_tensor):
  # Compute out_tensor = softmax(a_tensor) along the last dimension,
  # where softmax(x) = exp(x - max(x)) / sum(exp(x - max(x)))
  out_tensor = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype,
                          buffer=nl.shared_hbm)

  # Generate tensor indices to index the input tensor
  ix = nl.arange(128)[:, None]
  iy = nl.arange(a_tensor.shape[1])[None, :]

  num_rows = a_tensor.shape[0]

  # Process 128 rows at a time due to the 128-partition tile size limitation.
  # Since we are not reducing across the first dimension,
  # the tiles can be processed independently.
  for i in nl.affine_range(math.ceil(a_tensor.shape[0] / 128)):

    # Load input data from external memory to on-chip memory
    a_tile = nl.load(a_tensor[i * 128 + ix, iy],
                     mask=(i * 128 + ix < num_rows))

    # Find the row max and subtract it from each value for numerical stability
    max_vals = nl.max(a_tile, axis=[1], keepdims=True,
                      mask=(i * 128 + ix < num_rows))
    shifted = nl.subtract(a_tile, max_vals,
                          mask=(i * 128 + ix < num_rows))

    # Compute the element-wise exponential of the shifted tile
    numerator = nl.exp(shifted)

    # Sum the exponentials along the last dimension
    denominator = nl.sum(numerator, axis=[1])

    # Normalize: divide each exponential by its row sum
    sm = numerator / denominator

    # Store the results back to external memory (out_tensor)
    nl.store(out_tensor[i * 128 + ix, iy], value=sm,
             mask=(i * 128 + ix < num_rows))

  return out_tensor
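

# A minimal usage sketch, not part of the kernel above: run the kernel on a
# Neuron device through torch_xla and compare it against the PyTorch softmax.
# The input shape (256, 512) and the comparison below are illustrative
# assumptions only.
if __name__ == "__main__":
  import torch
  from torch.nn import functional as F

  device = xm.xla_device()
  a = torch.rand((256, 512), dtype=torch.float32).to(device=device)

  out_nki = nki_softmax_kernel(a)   # NKI kernel on the Neuron device
  out_ref = F.softmax(a, dim=-1)    # PyTorch reference implementation

  # Move both results to host memory before comparing them
  max_diff = (out_nki.cpu() - out_ref.cpu()).abs().max().item()
  print(f"max abs diff vs torch softmax: {max_diff}")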