Skip to content

Commit 8466576

Browse files
committed
added tests
1 parent 96d8bca commit 8466576

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

test/ortex_test.exs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
defmodule OrtexTest do
2+
use ExUnit.Case
3+
doctest Ortex
4+
5+
test "resnet50" do
6+
model = Ortex.load("./models/resnet50.onnx")
7+
8+
input = Nx.broadcast(0.0, {1, 3, 224, 224})
9+
{output} = Ortex.run(model, {input})
10+
argmax = output |> Nx.backend_transfer() |> Nx.argmax(axis: 1)
11+
12+
assert argmax == Nx.tensor([499])
13+
end
14+
15+
test "transfer to Ortex.Backend" do
16+
assert true
17+
end
18+
19+
test "transfer from Ortex.Backend" do
20+
assert true
21+
end
22+
23+
test "Nx.Serving with resnet50" do
24+
model = Ortex.load("./models/resnet50.onnx")
25+
26+
serving = Nx.Serving.new(Ortex.Serving, model)
27+
batch = Nx.Batch.stack([{Nx.broadcast(0.0, {3, 224, 224})}])
28+
{result} = Nx.Serving.run(serving, batch)
29+
assert result |> Nx.backend_transfer() |> Nx.argmax(axis: 1) == Nx.tensor([499])
30+
end
31+
32+
test "Nx.Serving with tinymodel" do
33+
model = Ortex.load("./models/tinymodel.onnx")
34+
35+
serving = Nx.Serving.new(Ortex.Serving, model)
36+
37+
# Create a batch of size 3 with {int32, float32} inputs
38+
batch =
39+
Nx.Batch.stack([
40+
{Nx.broadcast(0, {100}) |> Nx.as_type(:s32),
41+
Nx.broadcast(0.0, {100}) |> Nx.as_type(:f32)},
42+
{Nx.broadcast(1, {100}) |> Nx.as_type(:s32),
43+
Nx.broadcast(1.0, {100}) |> Nx.as_type(:f32)},
44+
{Nx.broadcast(2, {100}) |> Nx.as_type(:s32), Nx.broadcast(2.0, {100}) |> Nx.as_type(:f32)}
45+
])
46+
47+
{%Nx.Tensor{shape: {3, 10}}, %Nx.Tensor{shape: {3, 10}}, %Nx.Tensor{shape: {3, 10}}} =
48+
Nx.Serving.run(serving, batch)
49+
end
50+
end

test/test_helper.exs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ExUnit.start()

0 commit comments

Comments
 (0)