Skip to content

Commit c1ff124

Browse files
geekbeastMatthew Tamayo-Rios
andauthored
Support loading models by name (WebAssembly#38)
Co-authored-by: Matthew Tamayo-Rios <[email protected]>
1 parent f47f35c commit c1ff124

File tree

4 files changed

+111
-1
lines changed

4 files changed

+111
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ detailed examples.
123123

124124
### Detailed design discussion
125125

126-
For the details of the API, see [wasi-nn.wit.md](wasi-nn.wit.md).
126+
For the details of the API, see [wasi-nn.wit](wit/wasi-nn.wit).
127127

128128
<!--
129129
This section should mostly refer to the .wit.md file that specifies the API. This section is for
File renamed without changes.

wasi-nn.witx

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
;; version for the official specification and documentation.
33

44
(typename $buffer_size u32)
5+
(typename $status u32)
6+
57
(typename $nn_errno
68
(enum (@witx tag u16)
79
$success
@@ -10,6 +12,9 @@
1012
$missing_memory
1113
$busy
1214
$runtime_error
15+
$unsupported_operation
16+
$model_too_large
17+
$model_not_found
1318
)
1419
)
1520
(typename $tensor_dimensions (list u32))
@@ -39,6 +44,7 @@
3944
$tensorflow
4045
$pytorch
4146
$tensorflowlite
47+
$autodetect
4248
)
4349
)
4450
(typename $execution_target
@@ -58,6 +64,17 @@
5864
(param $target $execution_target)
5965
(result $error (expected $graph (error $nn_errno)))
6066
)
67+
;;; Load an opaque sequence of bytes to use for inference.
68+
;;;
69+
;;; This allows runtime implementations to support multiple graph encoding formats. For unsupported graph encodings,
70+
;;; return `errno::inval`.
71+
(@interface func (export "load_by_name")
72+
;;; The name of the model to load from the model registry
73+
(param $model_name string)
74+
75+
(result $error (expected $graph (error $nn_errno)))
76+
)
77+
6178
(@interface func (export "init_execution_context")
6279
(param $graph $graph)
6380
(result $error (expected $graph_execution_context (error $nn_errno)))

wit/wasi-nn.wit

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package wasi:nn
2+
3+
world inference {
4+
import tensor
5+
import graph
6+
import execution
7+
import errors
8+
}
9+
10+
interface tensor {
11+
type tensor-dimensions = list<u32>
12+
type tensor-data = list<u8>
13+
14+
enum tensor-type {
15+
fp16,
16+
fp32,
17+
bf16,
18+
up8,
19+
ip32
20+
}
21+
22+
record tensor {
23+
// Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor
24+
// containing a single value, use `[1]` for the tensor dimensions.
25+
dimensions: tensor-dimensions,
26+
27+
// Describe the type of element in the tensor (e.g., f32).
28+
tensor-type: tensor-type,
29+
30+
// Contains the tensor data.
31+
data: tensor-data,
32+
}
33+
}
34+
35+
interface graph {
36+
use errors.{error}
37+
type graph-builder = list<u8>
38+
type graph-builder-array = list<graph-builder>
39+
use tensor.{tensor}
40+
41+
type graph = u32
42+
43+
enum graph-encoding {
44+
openvino,
45+
onnx,
46+
tensorflow,
47+
pytorch,
48+
tensorflowlite,
49+
autodetect,
50+
}
51+
52+
enum execution-target {
53+
cpu,
54+
gpu,
55+
tpu
56+
}
57+
58+
load: func(builder: graph-builder-array, encoding: graph-encoding, target: execution-target) -> result<graph, error>
59+
load-named-model: func(name: string) -> result<graph, error>
60+
}
61+
62+
interface execution {
63+
use errors.{error}
64+
use tensor.{tensor, tensor-data}
65+
use graph.{graph}
66+
67+
type graph-execution-context = u32
68+
init-execution-context: func(graph: graph) -> result<graph-execution-context, error>
69+
set-input: func(ctx: graph-execution-context, index: u32, tensor: tensor) -> result<_, error>
70+
set-input-by-name: func(ctx: graph-execution-context, name: string, tensor: tensor) -> result<_, error>
71+
compute: func(ctx: graph-execution-context) -> result<_, error>
72+
get-output: func(ctx: graph-execution-context, index: u32) -> result<list<tensor-data>, error>
73+
eval: func(tensors: list<tensor>) -> result<list<tensor>, error>
74+
75+
}
76+
77+
interface errors {
78+
enum error {
79+
// Caller module passed an invalid argument.
80+
invalid-argument,
81+
// Invalid encoding.
82+
invalid-encoding,
83+
busy,
84+
// Runtime Error.
85+
runtime-error,
86+
// Unsupported operation
87+
unsupported-operation,
88+
// Model too large
89+
model-too-large,
90+
// Model not found
91+
model-not-found
92+
}
93+
}

0 commit comments

Comments
 (0)