
Commit 074b98b

added stablelm example

1 parent: e570112

File tree: 4 files changed, +90 −0 lines

README.md

Lines changed: 2 additions & 0 deletions

```diff
@@ -26,6 +26,8 @@ iex> output |> Nx.backend_transfer(Nx.BinaryBackend) |> Nx.argmax
 Inspecting a model shows the expected inputs, outputs, data types, and shapes. Axes with
 `nil` represent a dynamic size.
 
+To see more real-world examples, see the `examples` directory.
+
 ### Serving
 `Ortex` also implements the `Nx.Serving` behaviour. To use it in your application's
 supervision tree, consult the `Nx.Serving` docs.
```
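For context on the `### Serving` sentence in the diff above, here is a minimal sketch of running an Ortex model through `Nx.Serving`. It assumes `Ortex.Serving` is the module implementing the behaviour (as the README states); the model path, serving name, and batch size are placeholders, and the child-spec shape follows the `Nx.Serving` docs:

```elixir
# Sketch only: model path, serving name, and batch options are placeholders.
model = Ortex.load("./models/resnet50.onnx")

# Ortex.Serving implements the Nx.Serving behaviour.
serving = Nx.Serving.new(Ortex.Serving, model)

# One-off inference: stack inputs into a batch and run it directly.
batch = Nx.Batch.stack([{Nx.broadcast(0.0, {3, 224, 224})}])
{result} = Nx.Serving.run(serving, batch)

# In an application's supervision tree, start the serving as a named
# process instead and call it with Nx.Serving.batched_run/2.
children = [
  {Nx.Serving, serving: serving, name: MyServing, batch_size: 8}
]
Supervisor.start_link(children, strategy: :one_for_one)
```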

examples/stablelm/README.md

Lines changed: 5 additions & 0 deletions

# StableLM

Run `python export.py` to create the ONNX model for stablelm-3b. Note that `export.py`
writes the model to `output/`, while `stablelm.exs` loads it from
`./models/stability-lm-3b/`, so copy it there (or change where `stablelm.exs` loads the
model from). Then run the `stablelm.exs` script with `mix run`.

examples/stablelm/export.py

Lines changed: 23 additions & 0 deletions

```python
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
import torch

# Download the tokenizer and model weights from the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-3b")
model = AutoModelForCausalLM.from_pretrained("stabilityai/stablelm-tuned-alpha-3b")
print(model)

# A minimal prompt is enough to trace the model's forward pass for export.
prompt = "<|ASSISTANT|>"

inputs = tokenizer(prompt, return_tensors="pt")

# Export to ONNX, marking the batch and sequence dimensions as dynamic so
# the exported graph accepts prompts of any length.
torch.onnx.export(
    model,
    (inputs["input_ids"].cpu(), inputs["attention_mask"].cpu()),
    "output/stability-lm-tuned-3b.onnx",
    input_names=["input_ids", "attention_mask"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
    },
)
```
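After exporting, a quick sanity check (a sketch; it assumes the ONNX file has already been copied to the path the Elixir script below expects) is to load the model with Ortex and inspect it. Per the README change above, the dynamic batch and sequence axes should show up as `nil`:

```elixir
# Sketch: verify the export loads and has the expected dynamic axes.
# Assumes the ONNX file was copied from output/ to the path below.
model = Ortex.load("./models/stability-lm-3b/stability-lm-tuned-3b.onnx")

# Inspecting a model prints its inputs, outputs, dtypes, and shapes;
# the dynamic batch and sequence axes appear as `nil`.
IO.inspect(model)
```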

examples/stablelm/stablelm.exs

Lines changed: 60 additions & 0 deletions

```elixir
model = Ortex.load("./models/stability-lm-3b/stability-lm-tuned-3b.onnx")

prompt = "<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
<|USER|>How are you feeling? <|ASSISTANT|>
"

{:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-3b")
{:ok, encoding} = Tokenizers.Tokenizer.encode(tokenizer, prompt)

# Wrap the token ids and attention mask in a leading batch dimension.
input = Nx.tensor([Tokenizers.Encoding.get_ids(encoding)])
mask = Nx.tensor([Tokenizers.Encoding.get_attention_mask(encoding)])

defmodule M do
  # Stop unconditionally after 500 iterations.
  def generate(_model, input, _mask, 500) do
    input
  end

  def generate(model, input, mask, iter) do
    # The model may return several outputs; the first holds the logits.
    [output | _] =
      Ortex.run(model, {
        input,
        mask
      })
      |> Tuple.to_list()

    # Greedy decoding: take the highest-scoring token id at each position,
    # then keep only the prediction for the last position.
    x = output |> Nx.backend_transfer() |> Nx.argmax(axis: 2)
    last = x[[.., -1]] |> Nx.new_axis(0)
    IO.inspect(last[0][0] |> Nx.to_number())

    # Stop when the model emits one of StableLM's special/stop token ids.
    case Enum.member?([50278, 50279, 50277, 1, 0], last[0][0] |> Nx.to_number()) do
      true ->
        input

      false ->
        # Append the new token, extend the mask, and keep generating.
        generate(
          model,
          Nx.concatenate([input, last], axis: 1),
          Nx.concatenate([mask, Nx.tensor([[1]])], axis: 1),
          iter + 1
        )
    end
  end
end

result = M.generate(model, input, mask, 0)
IO.inspect(result)

# Decode the generated ids (a batch of one sequence) back into text.
IO.inspect(
  Tokenizers.Tokenizer.decode(
    tokenizer,
    result
    |> Nx.backend_transfer()
    |> Nx.to_batched(1)
    |> Enum.map(&Nx.to_flat_list/1)
  )
)
```
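To make the tensor slicing in the loop concrete, here is a toy illustration with made-up logits (the values and vocabulary size are hypothetical, not from StableLM): `Nx.argmax(axis: 2)` collapses the vocabulary axis of the `{batch, sequence, vocab}` logits, and `[[.., -1]]` keeps only the prediction for the final position so it can be appended to the running input:

```elixir
# Hypothetical logits: batch_size = 1, sequence_length = 2, vocab_size = 3.
logits = Nx.tensor([[[0.1, 0.9, 0.0], [0.2, 0.1, 0.7]]])

# Highest-scoring token id at each position: shape {1, 2}, values [[1, 2]].
ids = Nx.argmax(logits, axis: 2)

# Keep only the last position's prediction and restore a leading axis so it
# can be concatenated onto the input ids: shape {1, 1}, value [[2]].
last = ids[[.., -1]] |> Nx.new_axis(0)
```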
