Would it be better to change return x @ wte.T # [n_seq, n_embd] -> [n_seq, n_vocab] to return x[-1] @ wte.T # [n_embd] -> [n_vocab]? Then we could use next_id = np.argmax(logits) directly.
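
For reference, here is a rough sketch of why the two versions should pick the same token. The sizes and random values are made up for illustration; only x, wte, logits and next_id come from the code in question. It assumes the caller only ever needs the logits for the last position, which would no longer hold if something else reads the full [n_seq, n_vocab] output.

```python
import numpy as np

# Hypothetical sizes and random data, just to compare the two projections.
rng = np.random.default_rng(0)
n_seq, n_embd, n_vocab = 5, 8, 11
x = rng.standard_normal((n_seq, n_embd))      # final hidden states
wte = rng.standard_normal((n_vocab, n_embd))  # token embedding matrix

# Current version: project every position, then take the last row.
logits_full = x @ wte.T                       # [n_seq, n_embd] -> [n_seq, n_vocab]
next_id_full = int(np.argmax(logits_full[-1]))

# Proposed version: project only the last position.
logits_last = x[-1] @ wte.T                   # [n_embd] -> [n_vocab]
next_id_last = int(np.argmax(logits_last))

print(logits_full.shape, logits_last.shape)
print(next_id_full == next_id_last)
```

The proposed version skips the matrix product for the first n_seq - 1 positions, so it does less work per generated token, at the cost of no longer returning logits for the whole sequence.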