Skip to content

Commit 3c907e9

Browse files
committed
feat: implement WebNNProgram class for graph management and tensor operations in WebNNView
1 parent 0ab38b9 commit 3c907e9

File tree

2 files changed

+88
-53
lines changed

2 files changed

+88
-53
lines changed

app/web-nn/web-nn.js

Lines changed: 22 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
11
import { EventsManager } from "../../src/system/events-manager.js";
2-
import {
3-
WebNNModule,
4-
OperandDescriptorBuilder } from "../../src/modules/webnn.js";
2+
import { WebNNProgram } from "../../src/modules/webnn/webnn-program.js";
53

64
export default class WebNNView extends HTMLElement {
75
static tag = "webnn-view";
86

97
#eventsManager = new EventsManager();
10-
#graphBuilder
11-
#graph;
12-
#context;
13-
#tensors;
8+
#program = new WebNNProgram();
149

1510
constructor() {
1611
super();
@@ -22,65 +17,39 @@ export default class WebNNView extends HTMLElement {
2217
url: import.meta.url,
2318
});
2419

25-
requestAnimationFrame(() => {
20+
requestAnimationFrame(async () => {
2621
const button = this.shadowRoot.querySelector("button");
2722
this.#eventsManager.addEvent(button, "click", this.#submit.bind(this));
28-
})
29-
30-
this.#context = await WebNNModule.createContext();
31-
this.#graphBuilder = await WebNNModule.createGraph( { context: this.#context } );
32-
const operands = await this.#createGraph();
33-
await this.#createTensors(operands.A, operands.B, operands.C);
23+
24+
// 1. initialize
25+
await this.#program.init();
26+
27+
// 2. build graph
28+
const descriptor = {dataType: 'float32', shape: [1]};
29+
const A = this.#program.addToGraph("input", "A", descriptor);
30+
const B = this.#program.addToGraph("input", "B", descriptor);
31+
const C = this.#program.addToGraph("add", A, B);
32+
this.#program.build({C});
33+
34+
// 3. add input and output tensors
35+
await this.#program.addInputTensor("A", A);
36+
await this.#program.addInputTensor("B", B);
37+
await this.#program.addOutputTensor("C", C);
38+
})
3439
}
3540

3641
async disconnectedCallback() {
3742
this.#eventsManager = this.#eventsManager.dispose();
3843
}
3944

40-
async #createGraph() {
41-
const descriptor = new OperandDescriptorBuilder().build();
42-
const A = this.#graphBuilder.input("A", descriptor);
43-
const B = this.#graphBuilder.input("B", descriptor);
44-
const C = this.#graphBuilder.add(A, B);
45-
46-
this.#graph = await this.#graphBuilder.build({C});
47-
return { A, B, C };
48-
}
49-
50-
async #createTensors(inputA, inputB, outputC) {
51-
this.#tensors = {
52-
inputA: await this.#createTensor(this.#context, inputA, true, true),
53-
inputB: await this.#createTensor(this.#context, inputB, true, true),
54-
outputC: await this.#createTensor(this.#context, outputC, false, true)
55-
};
56-
}
57-
58-
async #createTensor(context, operand, writable, readable) {
59-
return context.createTensor({
60-
dataType: operand.dataType, shape: operand.shape, writable, readable
61-
})
62-
}
63-
6445
async #submit(event) {
6546
const value1 = this.shadowRoot.querySelector("#value1").value;
6647
const value2 = this.shadowRoot.querySelector("#value2").value;
6748

68-
this.#context.writeTensor(this.#tensors.inputA, new Float32Array([value1]));
69-
this.#context.writeTensor(this.#tensors.inputB, new Float32Array([value2]));
70-
71-
const inputs = {
72-
'A': this.#tensors.inputA,
73-
'B': this.#tensors.inputB
74-
};
75-
76-
const outputs = {
77-
'C': this.#tensors.outputC
78-
};
49+
await this.#program.set("A", [value1]);
50+
await this.#program.set("B", [value2]);
7951

80-
this.#context.dispatch(this.#graph, inputs, outputs);
81-
const output = await this.#context.readTensor(this.#tensors.outputC);
82-
const result = new Float32Array(output)[0];
83-
this.shadowRoot.querySelector("#result").textContent = result;
52+
this.shadowRoot.querySelector("#result").textContent = await this.#program.run();
8453
}
8554
}
8655

src/modules/webnn/webnn-program.js

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import DeviceType from "./device-type.js";
2+
import PowerOptions from "./power-options.js";
3+
4+
export class WebNNProgram {
5+
#context;
6+
#graphBuilder;
7+
#graph;
8+
#inputTensors = {};
9+
#outputTensors = {};
10+
11+
async init(deviceType = DeviceType.GPU, powerPreference = PowerOptions.DEFAULT) {
12+
const contextOptions = {
13+
deviceType,
14+
powerPreference
15+
};
16+
17+
this.#context = await navigator.ml.createContext(contextOptions);
18+
this.#graphBuilder = new MLGraphBuilder(this.#context);
19+
}
20+
21+
/**
22+
* Create tensor on context
23+
* @param {*} context
24+
* @param {*} operand
25+
* @param {*} writable
26+
* @param {*} readable
27+
* @returns
28+
*/
29+
async #createTensor(operand, writable, readable) {
30+
return await this.#context.createTensor({
31+
dataType: operand.dataType, shape: operand.shape, writable, readable
32+
})
33+
}
34+
35+
async addInputTensor(name, operand) {
36+
this.#inputTensors[name] = await this.#createTensor(operand, true, false);
37+
}
38+
39+
async addOutputTensor(name, operand) {
40+
this.#outputTensors[name] = await this.#createTensor(operand, false, true);
41+
}
42+
43+
addToGraph(action, ...args) {
44+
return this.#graphBuilder[action](...args);
45+
}
46+
47+
async set(name, values) {
48+
await this.#context.writeTensor(this.#inputTensors[name], new Float32Array(values));
49+
}
50+
51+
async build(args) {
52+
this.#graph = await this.#graphBuilder.build(args);
53+
}
54+
55+
async run() {
56+
this.#context.dispatch(this.#graph, this.#inputTensors, this.#outputTensors);
57+
58+
const outputKey = Object.keys(this.#outputTensors)[0];
59+
const output = await this.#context.readTensor(this.#outputTensors[outputKey]);
60+
return new Float32Array(output)[0];
61+
}
62+
}
63+
64+
// const program = new WebNNProgram();
65+
// program.buildGraph();
66+
// const result = program.run({a: 1, b: 2});

0 commit comments

Comments
 (0)