Commit 3c5d277

feat: add softmax examples (#55)
1 parent 794d337 commit 3c5d277

File tree

2 files changed: +76 −0 lines changed

softmax_nki_kernels.py
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
import math
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import torch_xla.core.xla_model as xm


@nki.jit
def nki_softmax_kernel(a_tensor):
  # Compute out_tensor = softmax(a_tensor) along the last dimension,
  # where softmax(x) = exp(x - max(x)) / sum(exp(x - max(x)))
  out_tensor = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype,
                          buffer=nl.shared_hbm)

  # Generate tensor indices to index the input tensor
  ix = nl.arange(128)[:, None]
  iy = nl.arange(a_tensor.shape[1])[None, :]

  num_rows = a_tensor.shape[0]

  # Process 128 rows at a time due to the 128-partition tile size limitation.
  # Since we're not reducing across the first dimension,
  # tiles can be processed independently.
  for i in nl.affine_range(math.ceil(a_tensor.shape[0] / 128)):

    # Load input data from external memory to on-chip memory; the mask
    # guards the final partial tile when num_rows is not a multiple of 128
    a_tile = nl.load(a_tensor[i * 128 + ix, iy],
                     mask=(i * 128 + ix < num_rows))

    # Find the row max and subtract it from each value to ensure numerical stability
    max_vals = nl.max(a_tile, axis=[1], keepdims=True, mask=(i * 128 + ix < num_rows))
    shifted = nl.subtract(a_tile, max_vals, mask=(i * 128 + ix < num_rows))

    # Compute the element-wise exp of the shifted values
    numerator = nl.exp(shifted)

    # Sum the exponentials along the last dimension
    denominator = nl.sum(numerator, axis=[1])

    # Normalize: divide each exponential by its row sum
    sm = numerator / denominator

    # Store the results back to external memory (out_tensor)
    nl.store(out_tensor[i * 128 + ix, iy], value=sm,
             mask=(i * 128 + ix < num_rows))

  return out_tensor
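
A host-side reference can be handy when checking the kernel's results offline. Below is a minimal sketch of the same numerically stable softmax in plain NumPy; the softmax_reference name and the use of NumPy are illustrative assumptions, not part of this commit.

import numpy as np

def softmax_reference(x: np.ndarray) -> np.ndarray:
  # Subtract the row max before exponentiating, as the kernel does,
  # so large logits do not overflow in exp()
  shifted = x - x.max(axis=-1, keepdims=True)
  e = np.exp(shifted)
  return e / e.sum(axis=-1, keepdims=True)

# A 300 x 512 input is the kind of shape where the kernel's mask would handle
# a final partial tile (300 is not a multiple of 128); the host reference
# needs no such handling
x = np.random.randn(300, 512).astype(np.float32)
probs = softmax_reference(x)
assert np.allclose(probs.sum(axis=-1), 1.0, atol=1e-5)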
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from softmax_nki_kernels import nki_softmax_kernel


class NaiveSoftmax(nn.Module):
  def __init__(self):
    super(NaiveSoftmax, self).__init__()

  def forward(self, x):
    # Plain softmax over the last dimension: exponentiate, then normalize
    numerator = torch.exp(x)
    denominator = torch.sum(numerator, dim=-1, keepdim=True)
    sm = numerator / denominator
    return sm


def naive_softmax(logits: torch.Tensor) -> torch.Tensor:
  softmax = NaiveSoftmax()
  probs = softmax(logits)
  return probs


from torch_xla.core import xla_model as xm

device = xm.xla_device()

logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [5.0, 4.0, 3.0, 2.0, 1.0]]).to(device)

# Compare the naive PyTorch softmax against the NKI kernel
sm_naive = naive_softmax(logits)
sm_nki = nki_softmax_kernel(logits)

assert torch.allclose(sm_naive, sm_nki, rtol=1e-5, atol=1e-5)
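
The naive module can also be cross-checked against PyTorch's built-in softmax. A minimal sketch, reusing the F import and the logits tensor from the script above (sm_builtin is just an illustrative name, not part of the commit):

# Built-in softmax over the last dimension, matching the kernel's reduction axis
sm_builtin = F.softmax(logits, dim=-1)
assert torch.allclose(sm_naive, sm_builtin, rtol=1e-5, atol=1e-5)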
