@@ -11,72 +11,47 @@ Pix2pix
11
11
import * as tf from '@tensorflow/tfjs' ;
12
12
import CheckpointLoaderPix2pix from '../utils/checkpointLoaderPix2pix' ;
13
13
import { array3DToImage } from '../utils/imageUtilities' ;
14
+ import callCallback from '../utils/callcallback' ;
14
15
15
16
class Pix2pix {
16
17
constructor ( model , callback ) {
17
- this . ready = false ;
18
-
19
- this . loadCheckpoints ( model ) . then ( ( ) => {
20
- this . ready = true ;
21
- if ( callback ) {
22
- callback ( ) ;
23
- }
24
- } ) ;
18
+ this . ready = callCallback ( this . loadCheckpoints ( model ) , callback ) ;
25
19
}
26
20
27
21
async loadCheckpoints ( path ) {
28
22
const checkpointLoader = new CheckpointLoaderPix2pix ( path ) ;
29
- this . weights = await checkpointLoader . fetchWeights ( ) ;
23
+ this . variables = await checkpointLoader . getAllVariables ( ) ;
24
+ return this ;
30
25
}
31
26
32
- async transfer ( inputElement , callback = ( ) => { } ) {
27
+ async transfer ( inputElement , cb ) {
28
+ return callCallback ( this . transferInternal ( inputElement ) , cb ) ;
29
+ }
30
+
31
+ async transferInternal ( inputElement ) {
33
32
const input = tf . fromPixels ( inputElement ) ;
34
33
const inputData = input . dataSync ( ) ;
35
34
const floatInput = tf . tensor3d ( inputData , input . shape ) ;
36
35
const normalizedInput = tf . div ( floatInput , tf . scalar ( 255 ) ) ;
37
36
38
- function preprocess ( inputPreproc ) {
39
- return tf . sub ( tf . mul ( inputPreproc , tf . scalar ( 2 ) ) , tf . scalar ( 1 ) ) ;
40
- }
41
-
42
- function deprocess ( inputDeproc ) {
43
- return tf . div ( tf . add ( inputDeproc , tf . scalar ( 1 ) ) , tf . scalar ( 2 ) ) ;
44
- }
45
-
46
- function batchnorm ( inputBat , scale , offset ) {
47
- const moments = tf . moments ( inputBat , [ 0 , 1 ] ) ;
48
- const varianceEpsilon = 1e-5 ;
49
- return tf . batchNormalization ( inputBat , moments . mean , moments . variance , varianceEpsilon , scale , offset ) ;
50
- }
51
-
52
- function conv2d ( inputCon , filterCon ) {
53
- return tf . conv2d ( inputCon , filterCon , [ 2 , 2 ] , 'same' ) ;
54
- }
55
-
56
- function deconv2d ( inputDeconv , filterDeconv , biasDecon ) {
57
- const convolved = tf . conv2dTranspose ( inputDeconv , filterDeconv , [ inputDeconv . shape [ 0 ] * 2 , inputDeconv . shape [ 1 ] * 2 , filterDeconv . shape [ 2 ] ] , [ 2 , 2 ] , 'same' ) ;
58
- const biased = tf . add ( convolved , biasDecon ) ;
59
- return biased ;
60
- }
61
-
62
- const result = tf . tidy ( ( ) => {
63
- const preprocessedInput = preprocess ( normalizedInput ) ;
37
+ const result = array3DToImage ( tf . tidy ( ( ) => {
38
+ const preprocessedInput = Pix2pix . preprocess ( normalizedInput ) ;
64
39
const layers = [ ] ;
65
- let filter = this . weights [ 'generator/encoder_1/conv2d/kernel' ] ;
66
- let bias = this . weights [ 'generator/encoder_1/conv2d/bias' ] ;
67
- let convolved = conv2d ( preprocessedInput , filter , bias ) ;
40
+ let filter = this . variables [ 'generator/encoder_1/conv2d/kernel' ] ;
41
+ let bias = this . variables [ 'generator/encoder_1/conv2d/bias' ] ;
42
+ let convolved = Pix2pix . conv2d ( preprocessedInput , filter , bias ) ;
68
43
layers . push ( convolved ) ;
69
44
70
45
for ( let i = 2 ; i <= 8 ; i += 1 ) {
71
46
const scope = `generator/encoder_${ i . toString ( ) } ` ;
72
- filter = this . weights [ `${ scope } /conv2d/kernel` ] ;
73
- const bias2 = this . weights [ `${ scope } /conv2d/bias` ] ;
47
+ filter = this . variables [ `${ scope } /conv2d/kernel` ] ;
48
+ const bias2 = this . variables [ `${ scope } /conv2d/bias` ] ;
74
49
const layerInput = layers [ layers . length - 1 ] ;
75
50
const rectified = tf . leakyRelu ( layerInput , 0.2 ) ;
76
- convolved = conv2d ( rectified , filter , bias2 ) ;
77
- const scale = this . weights [ `${ scope } /batch_normalization/gamma` ] ;
78
- const offset = this . weights [ `${ scope } /batch_normalization/beta` ] ;
79
- const normalized = batchnorm ( convolved , scale , offset ) ;
51
+ convolved = Pix2pix . conv2d ( rectified , filter , bias2 ) ;
52
+ const scale = this . variables [ `${ scope } /batch_normalization/gamma` ] ;
53
+ const offset = this . variables [ `${ scope } /batch_normalization/beta` ] ;
54
+ const normalized = Pix2pix . batchnorm ( convolved , scale , offset ) ;
80
55
layers . push ( normalized ) ;
81
56
}
82
57
@@ -90,33 +65,60 @@ class Pix2pix {
90
65
}
91
66
const rectified = tf . relu ( layerInput ) ;
92
67
const scope = `generator/decoder_${ i . toString ( ) } ` ;
93
- filter = this . weights [ `${ scope } /conv2d_transpose/kernel` ] ;
94
- bias = this . weights [ `${ scope } /conv2d_transpose/bias` ] ;
95
- convolved = deconv2d ( rectified , filter , bias ) ;
96
- const scale = this . weights [ `${ scope } /batch_normalization/gamma` ] ;
97
- const offset = this . weights [ `${ scope } /batch_normalization/beta` ] ;
98
- const normalized = batchnorm ( convolved , scale , offset ) ;
68
+ filter = this . variables [ `${ scope } /conv2d_transpose/kernel` ] ;
69
+ bias = this . variables [ `${ scope } /conv2d_transpose/bias` ] ;
70
+ convolved = Pix2pix . deconv2d ( rectified , filter , bias ) ;
71
+ const scale = this . variables [ `${ scope } /batch_normalization/gamma` ] ;
72
+ const offset = this . variables [ `${ scope } /batch_normalization/beta` ] ;
73
+ const normalized = Pix2pix . batchnorm ( convolved , scale , offset ) ;
99
74
layers . push ( normalized ) ;
100
75
}
101
76
102
77
const layerInput = tf . concat ( [ layers [ layers . length - 1 ] , layers [ 0 ] ] , 2 ) ;
103
78
let rectified2 = tf . relu ( layerInput ) ;
104
- filter = this . weights [ 'generator/decoder_1/conv2d_transpose/kernel' ] ;
105
- const bias3 = this . weights [ 'generator/decoder_1/conv2d_transpose/bias' ] ;
106
- convolved = deconv2d ( rectified2 , filter , bias3 ) ;
79
+ filter = this . variables [ 'generator/decoder_1/conv2d_transpose/kernel' ] ;
80
+ const bias3 = this . variables [ 'generator/decoder_1/conv2d_transpose/bias' ] ;
81
+ convolved = Pix2pix . deconv2d ( rectified2 , filter , bias3 ) ;
107
82
rectified2 = tf . tanh ( convolved ) ;
108
83
layers . push ( rectified2 ) ;
109
84
110
85
const output = layers [ layers . length - 1 ] ;
111
- const deprocessedOutput = deprocess ( output ) ;
86
+ const deprocessedOutput = Pix2pix . deprocess ( output ) ;
112
87
return deprocessedOutput ;
113
- } ) ;
88
+ } ) ) ;
114
89
115
90
await tf . nextFrame ( ) ;
116
- callback ( array3DToImage ( result ) ) ;
91
+ return result ;
92
+ }
93
+
94
+ static preprocess ( inputPreproc ) {
95
+ return tf . sub ( tf . mul ( inputPreproc , tf . scalar ( 2 ) ) , tf . scalar ( 1 ) ) ;
96
+ }
97
+
98
+ static deprocess ( inputDeproc ) {
99
+ return tf . div ( tf . add ( inputDeproc , tf . scalar ( 1 ) ) , tf . scalar ( 2 ) ) ;
100
+ }
101
+
102
+ static batchnorm ( inputBat , scale , offset ) {
103
+ const moments = tf . moments ( inputBat , [ 0 , 1 ] ) ;
104
+ const varianceEpsilon = 1e-5 ;
105
+ return tf . batchNormalization ( inputBat , moments . mean , moments . variance , varianceEpsilon , scale , offset ) ;
106
+ }
107
+
108
+ static conv2d ( inputCon , filterCon ) {
109
+ return tf . conv2d ( inputCon , filterCon , [ 2 , 2 ] , 'same' ) ;
110
+ }
111
+
112
+ static deconv2d ( inputDeconv , filterDeconv , biasDecon ) {
113
+ const convolved = tf . conv2dTranspose ( inputDeconv , filterDeconv , [ inputDeconv . shape [ 0 ] * 2 , inputDeconv . shape [ 1 ] * 2 , filterDeconv . shape [ 2 ] ] , [ 2 , 2 ] , 'same' ) ;
114
+ const biased = tf . add ( convolved , biasDecon ) ;
115
+ return biased ;
117
116
}
118
117
}
119
118
120
- const pix2pix = ( model , callback = ( ) => { } ) => new Pix2pix ( model , callback ) ;
119
+ const pix2pix = ( model , callback ) => {
120
+ const instance = new Pix2pix ( model , callback ) ;
121
+ return callback ? instance : instance . ready ;
122
+ } ;
121
123
122
124
export default pix2pix ;
0 commit comments