@@ -9,23 +9,38 @@ Fast Style Transfer
9
9
*/
10
10
11
11
import * as tf from '@tensorflow/tfjs' ;
12
+ import Video from './../utils/Video' ;
12
13
import CheckpointLoader from '../utils/checkpointLoader' ;
13
14
import { array3DToImage } from '../utils/imageUtilities' ;
14
15
15
- class StyleTransfer {
16
- constructor ( model , callback ) {
16
+ const IMAGE_SIZE = 200 ;
17
+
18
+ class StyleTransfer extends Video {
19
+ constructor ( model , videoOrCallback , cb = ( ) => { } ) {
20
+ super ( videoOrCallback , IMAGE_SIZE ) ;
17
21
this . ready = false ;
18
22
this . variableDictionary = { } ;
19
23
this . timesScalar = tf . scalar ( 150 ) ;
20
24
this . plusScalar = tf . scalar ( 255.0 / 2 ) ;
21
25
this . epsilonScalar = tf . scalar ( 1e-3 ) ;
22
- this . loadCheckpoints ( model ) . then ( ( ) => {
23
- this . ready = true ;
24
- if ( callback ) {
25
- callback ( ) ;
26
- }
26
+ this . video = null ;
27
+
28
+ let callback = cb ;
29
+ if ( typeof videoOrCallback === 'function' ) {
30
+ callback = videoOrCallback ;
31
+ }
32
+
33
+ this . loadVideo ( ) . then ( ( ) => {
34
+ this . videoReady = true ;
35
+ this . loadCheckpoints ( model ) . then ( ( ) => {
36
+ this . ready = true ;
37
+ if ( callback ) {
38
+ callback ( ) ;
39
+ }
40
+ } ) ;
27
41
} ) ;
28
42
}
43
+
29
44
async loadCheckpoints ( path ) {
30
45
const checkpointLoader = new CheckpointLoader ( path ) ;
31
46
this . variables = await checkpointLoader . getAllVariables ( ) ;
@@ -70,7 +85,19 @@ class StyleTransfer {
70
85
return y3 ;
71
86
}
72
87
73
- transfer ( input ) {
88
+ async transfer ( inputOrCallback , cb = ( ) => { } ) {
89
+ let input ;
90
+ let callback = cb ;
91
+
92
+ if ( inputOrCallback instanceof HTMLVideoElement || inputOrCallback instanceof HTMLImageElement ) {
93
+ input = inputOrCallback ;
94
+ } else if ( typeof inputOrCallback === 'object' && ( inputOrCallback . elt instanceof HTMLVideoElement || inputOrCallback . elt instanceof HTMLImageElement ) ) {
95
+ input = inputOrCallback ;
96
+ } else if ( typeof inputOrCallback === 'function' ) {
97
+ input = this . video ;
98
+ callback = inputOrCallback ;
99
+ }
100
+
74
101
const image = tf . fromPixels ( input ) ;
75
102
const result = tf . tidy ( ( ) => {
76
103
const conv1 = this . convLayer ( image , 1 , true , 0 ) ;
@@ -91,7 +118,8 @@ class StyleTransfer {
91
118
const normalized = tf . div ( clamped , tf . scalar ( 255.0 ) ) ;
92
119
return normalized ;
93
120
} ) ;
94
- return array3DToImage ( result ) ;
121
+ await tf . nextFrame ( ) ;
122
+ callback ( array3DToImage ( result ) ) ;
95
123
}
96
124
97
125
// Static Methods
0 commit comments