Skip to content

Commit 6b36c63

Browse files
committed
added next frame to style transfer, improving performance
1 parent b470b8b commit 6b36c63

File tree

1 file changed

+37
-9
lines changed

1 file changed

+37
-9
lines changed

src/StyleTransfer/index.js

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,38 @@ Fast Style Transfer
99
*/
1010

1111
import * as tf from '@tensorflow/tfjs';
12+
import Video from './../utils/Video';
1213
import CheckpointLoader from '../utils/checkpointLoader';
1314
import { array3DToImage } from '../utils/imageUtilities';
1415

15-
class StyleTransfer {
16-
constructor(model, callback) {
16+
const IMAGE_SIZE = 200;
17+
18+
class StyleTransfer extends Video {
19+
constructor(model, videoOrCallback, cb = () => {}) {
20+
super(videoOrCallback, IMAGE_SIZE);
1721
this.ready = false;
1822
this.variableDictionary = {};
1923
this.timesScalar = tf.scalar(150);
2024
this.plusScalar = tf.scalar(255.0 / 2);
2125
this.epsilonScalar = tf.scalar(1e-3);
22-
this.loadCheckpoints(model).then(() => {
23-
this.ready = true;
24-
if (callback) {
25-
callback();
26-
}
26+
this.video = null;
27+
28+
let callback = cb;
29+
if (typeof videoOrCallback === 'function') {
30+
callback = videoOrCallback;
31+
}
32+
33+
this.loadVideo().then(() => {
34+
this.videoReady = true;
35+
this.loadCheckpoints(model).then(() => {
36+
this.ready = true;
37+
if (callback) {
38+
callback();
39+
}
40+
});
2741
});
2842
}
43+
2944
async loadCheckpoints(path) {
3045
const checkpointLoader = new CheckpointLoader(path);
3146
this.variables = await checkpointLoader.getAllVariables();
@@ -70,7 +85,19 @@ class StyleTransfer {
7085
return y3;
7186
}
7287

73-
transfer(input) {
88+
async transfer(inputOrCallback, cb = () => {}) {
89+
let input;
90+
let callback = cb;
91+
92+
if (inputOrCallback instanceof HTMLVideoElement || inputOrCallback instanceof HTMLImageElement) {
93+
input = inputOrCallback;
94+
} else if (typeof inputOrCallback === 'object' && (inputOrCallback.elt instanceof HTMLVideoElement || inputOrCallback.elt instanceof HTMLImageElement)) {
95+
input = inputOrCallback;
96+
} else if (typeof inputOrCallback === 'function') {
97+
input = this.video;
98+
callback = inputOrCallback;
99+
}
100+
74101
const image = tf.fromPixels(input);
75102
const result = tf.tidy(() => {
76103
const conv1 = this.convLayer(image, 1, true, 0);
@@ -91,7 +118,8 @@ class StyleTransfer {
91118
const normalized = tf.div(clamped, tf.scalar(255.0));
92119
return normalized;
93120
});
94-
return array3DToImage(result);
121+
await tf.nextFrame();
122+
callback(array3DToImage(result));
95123
}
96124

97125
// Static Methods

0 commit comments

Comments
 (0)