|
# Load the ONNX model from disk.
model = Ortex.load("./models/stability-lm-3b/stability-lm-tuned-3b.onnx")

# System preamble followed by a single user turn; the text ends right after the
# assistant tag (plus a newline) so generation continues from there.
prompt = """
<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
<|USER|>How are you feeling? <|ASSISTANT|>
"""

# Fetch the matching tokenizer and encode the prompt; a failure in either step
# crashes the script via the assertive `{:ok, _}` match.
{:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-3b")
{:ok, encoding} = Tokenizers.Tokenizer.encode(tokenizer, prompt)

# Wrap ids and attention mask in an outer list so both tensors get a leading
# axis of size 1.
token_ids = Tokenizers.Encoding.get_ids(encoding)
attention_mask = Tokenizers.Encoding.get_attention_mask(encoding)

input = Nx.tensor([token_ids])
mask = Nx.tensor([attention_mask])
| 16 | + |
defmodule M do
  @moduledoc """
  Greedy (argmax) token-by-token generation on top of an Ortex model.
  """

  # Upper bound on the step counter; generation stops once `iter` reaches it.
  @max_iterations 500

  # Predicting any of these ids ends generation.
  # NOTE(review): presumably the special / end-of-text ids of the StableLM
  # tokenizer - confirm against the tokenizer configuration.
  @stop_tokens [50278, 50279, 50277, 1, 0]

  @doc """
  Extends `input` one predicted token at a time.

  ## Parameters

    * `model` - the model handed to `Ortex.run/2`.
    * `input` - tensor of token ids; each step appends the predicted id along axis 1.
    * `mask`  - attention mask tensor; each step appends a `1` along axis 1.
    * `iter`  - current step counter (callers start at `0`).

  ## Returns

  The `input` tensor extended with every generated token. Generation stops when
  `iter` reaches the step limit or when the predicted id is one of the stop
  tokens; the stop token itself is not appended.

  Only element `[0][0]` of the prediction is checked against the stop tokens and
  the mask is extended with a single `[[1]]` row, so this assumes a batch of one.
  """
  # `>=` instead of an exact match so a call that starts above the limit still terminates.
  def generate(_model, input, _mask, iter) when iter >= @max_iterations, do: input

  def generate(model, input, mask, iter) do
    # Only the first element of the model's output tuple is used.
    [output | _] = model |> Ortex.run({input, mask}) |> Tuple.to_list()

    # Slice out the last sequence position before the argmax so the reduction
    # runs over one position rather than the whole sequence.
    # Assumes `output` is {batch, sequence, vocab} - TODO confirm against the model.
    logits = Nx.backend_transfer(output)
    last = logits[[.., -1]] |> Nx.argmax(axis: 1) |> Nx.new_axis(0)

    # Print each predicted id as it is produced.
    token_id = last[0][0] |> Nx.to_number() |> IO.inspect()

    if token_id in @stop_tokens do
      input
    else
      generate(
        model,
        Nx.concatenate([input, last], axis: 1),
        Nx.concatenate([mask, Nx.tensor([[1]])], axis: 1),
        iter + 1
      )
    end
  end
end
| 48 | + |
# Generate from the encoded prompt, starting the step counter at 0, and print
# the resulting token-id tensor.
result = M.generate(model, input, mask, 0)
IO.inspect(result)

# Split the result into batches of one row and flatten each into a plain list of
# ids, giving a list of id lists for the decoder.
sequences =
  for row <- result |> Nx.backend_transfer() |> Nx.to_batched(1) do
    Nx.to_flat_list(row)
  end

# Decode back to text and print whatever the tokenizer returns.
tokenizer
|> Tokenizers.Tokenizer.decode(sequences)
|> IO.inspect()
0 commit comments