Skip to content

Commit ffe70d8

Browse files
committed
refactor pix2pix to support promises
1 parent 8d37fb4 commit ffe70d8

File tree

3 files changed

+61
-59
lines changed

3 files changed

+61
-59
lines changed

src/Pix2pix/index.js

Lines changed: 60 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -11,72 +11,47 @@ Pix2pix
1111
import * as tf from '@tensorflow/tfjs';
1212
import CheckpointLoaderPix2pix from '../utils/checkpointLoaderPix2pix';
1313
import { array3DToImage } from '../utils/imageUtilities';
14+
import callCallback from '../utils/callcallback';
1415

1516
class Pix2pix {
1617
constructor(model, callback) {
17-
this.ready = false;
18-
19-
this.loadCheckpoints(model).then(() => {
20-
this.ready = true;
21-
if (callback) {
22-
callback();
23-
}
24-
});
18+
this.ready = callCallback(this.loadCheckpoints(model), callback);
2519
}
2620

2721
async loadCheckpoints(path) {
2822
const checkpointLoader = new CheckpointLoaderPix2pix(path);
29-
this.weights = await checkpointLoader.fetchWeights();
23+
this.variables = await checkpointLoader.getAllVariables();
24+
return this;
3025
}
3126

32-
async transfer(inputElement, callback = () => {}) {
27+
async transfer(inputElement, cb) {
28+
return callCallback(this.transferInternal(inputElement), cb);
29+
}
30+
31+
async transferInternal(inputElement) {
3332
const input = tf.fromPixels(inputElement);
3433
const inputData = input.dataSync();
3534
const floatInput = tf.tensor3d(inputData, input.shape);
3635
const normalizedInput = tf.div(floatInput, tf.scalar(255));
3736

38-
function preprocess(inputPreproc) {
39-
return tf.sub(tf.mul(inputPreproc, tf.scalar(2)), tf.scalar(1));
40-
}
41-
42-
function deprocess(inputDeproc) {
43-
return tf.div(tf.add(inputDeproc, tf.scalar(1)), tf.scalar(2));
44-
}
45-
46-
function batchnorm(inputBat, scale, offset) {
47-
const moments = tf.moments(inputBat, [0, 1]);
48-
const varianceEpsilon = 1e-5;
49-
return tf.batchNormalization(inputBat, moments.mean, moments.variance, varianceEpsilon, scale, offset);
50-
}
51-
52-
function conv2d(inputCon, filterCon) {
53-
return tf.conv2d(inputCon, filterCon, [2, 2], 'same');
54-
}
55-
56-
function deconv2d(inputDeconv, filterDeconv, biasDecon) {
57-
const convolved = tf.conv2dTranspose(inputDeconv, filterDeconv, [inputDeconv.shape[0] * 2, inputDeconv.shape[1] * 2, filterDeconv.shape[2]], [2, 2], 'same');
58-
const biased = tf.add(convolved, biasDecon);
59-
return biased;
60-
}
61-
62-
const result = tf.tidy(() => {
63-
const preprocessedInput = preprocess(normalizedInput);
37+
const result = array3DToImage(tf.tidy(() => {
38+
const preprocessedInput = Pix2pix.preprocess(normalizedInput);
6439
const layers = [];
65-
let filter = this.weights['generator/encoder_1/conv2d/kernel'];
66-
let bias = this.weights['generator/encoder_1/conv2d/bias'];
67-
let convolved = conv2d(preprocessedInput, filter, bias);
40+
let filter = this.variables['generator/encoder_1/conv2d/kernel'];
41+
let bias = this.variables['generator/encoder_1/conv2d/bias'];
42+
let convolved = Pix2pix.conv2d(preprocessedInput, filter, bias);
6843
layers.push(convolved);
6944

7045
for (let i = 2; i <= 8; i += 1) {
7146
const scope = `generator/encoder_${i.toString()}`;
72-
filter = this.weights[`${scope}/conv2d/kernel`];
73-
const bias2 = this.weights[`${scope}/conv2d/bias`];
47+
filter = this.variables[`${scope}/conv2d/kernel`];
48+
const bias2 = this.variables[`${scope}/conv2d/bias`];
7449
const layerInput = layers[layers.length - 1];
7550
const rectified = tf.leakyRelu(layerInput, 0.2);
76-
convolved = conv2d(rectified, filter, bias2);
77-
const scale = this.weights[`${scope}/batch_normalization/gamma`];
78-
const offset = this.weights[`${scope}/batch_normalization/beta`];
79-
const normalized = batchnorm(convolved, scale, offset);
51+
convolved = Pix2pix.conv2d(rectified, filter, bias2);
52+
const scale = this.variables[`${scope}/batch_normalization/gamma`];
53+
const offset = this.variables[`${scope}/batch_normalization/beta`];
54+
const normalized = Pix2pix.batchnorm(convolved, scale, offset);
8055
layers.push(normalized);
8156
}
8257

@@ -90,33 +65,60 @@ class Pix2pix {
9065
}
9166
const rectified = tf.relu(layerInput);
9267
const scope = `generator/decoder_${i.toString()}`;
93-
filter = this.weights[`${scope}/conv2d_transpose/kernel`];
94-
bias = this.weights[`${scope}/conv2d_transpose/bias`];
95-
convolved = deconv2d(rectified, filter, bias);
96-
const scale = this.weights[`${scope}/batch_normalization/gamma`];
97-
const offset = this.weights[`${scope}/batch_normalization/beta`];
98-
const normalized = batchnorm(convolved, scale, offset);
68+
filter = this.variables[`${scope}/conv2d_transpose/kernel`];
69+
bias = this.variables[`${scope}/conv2d_transpose/bias`];
70+
convolved = Pix2pix.deconv2d(rectified, filter, bias);
71+
const scale = this.variables[`${scope}/batch_normalization/gamma`];
72+
const offset = this.variables[`${scope}/batch_normalization/beta`];
73+
const normalized = Pix2pix.batchnorm(convolved, scale, offset);
9974
layers.push(normalized);
10075
}
10176

10277
const layerInput = tf.concat([layers[layers.length - 1], layers[0]], 2);
10378
let rectified2 = tf.relu(layerInput);
104-
filter = this.weights['generator/decoder_1/conv2d_transpose/kernel'];
105-
const bias3 = this.weights['generator/decoder_1/conv2d_transpose/bias'];
106-
convolved = deconv2d(rectified2, filter, bias3);
79+
filter = this.variables['generator/decoder_1/conv2d_transpose/kernel'];
80+
const bias3 = this.variables['generator/decoder_1/conv2d_transpose/bias'];
81+
convolved = Pix2pix.deconv2d(rectified2, filter, bias3);
10782
rectified2 = tf.tanh(convolved);
10883
layers.push(rectified2);
10984

11085
const output = layers[layers.length - 1];
111-
const deprocessedOutput = deprocess(output);
86+
const deprocessedOutput = Pix2pix.deprocess(output);
11287
return deprocessedOutput;
113-
});
88+
}));
11489

11590
await tf.nextFrame();
116-
callback(array3DToImage(result));
91+
return result;
92+
}
93+
94+
static preprocess(inputPreproc) {
95+
return tf.sub(tf.mul(inputPreproc, tf.scalar(2)), tf.scalar(1));
96+
}
97+
98+
static deprocess(inputDeproc) {
99+
return tf.div(tf.add(inputDeproc, tf.scalar(1)), tf.scalar(2));
100+
}
101+
102+
static batchnorm(inputBat, scale, offset) {
103+
const moments = tf.moments(inputBat, [0, 1]);
104+
const varianceEpsilon = 1e-5;
105+
return tf.batchNormalization(inputBat, moments.mean, moments.variance, varianceEpsilon, scale, offset);
106+
}
107+
108+
static conv2d(inputCon, filterCon) {
109+
return tf.conv2d(inputCon, filterCon, [2, 2], 'same');
110+
}
111+
112+
static deconv2d(inputDeconv, filterDeconv, biasDecon) {
113+
const convolved = tf.conv2dTranspose(inputDeconv, filterDeconv, [inputDeconv.shape[0] * 2, inputDeconv.shape[1] * 2, filterDeconv.shape[2]], [2, 2], 'same');
114+
const biased = tf.add(convolved, biasDecon);
115+
return biased;
117116
}
118117
}
119118

120-
const pix2pix = (model, callback = () => {}) => new Pix2pix(model, callback);
119+
const pix2pix = (model, callback) => {
120+
const instance = new Pix2pix(model, callback);
121+
return callback ? instance : instance.ready;
122+
};
121123

122124
export default pix2pix;

src/Pix2pix/index_test.js

Whitespace-only changes.

src/utils/checkpointLoaderPix2pix.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ export default class CheckpointLoaderPix2pix {
77
this.urlPath = urlPath;
88
}
99

10-
fetchWeights() {
10+
getAllVariables() {
1111
return new Promise((resolve, reject) => {
1212
const weightsCache = {};
1313
if (this.urlPath in weightsCache) {

0 commit comments

Comments
 (0)