
Commit 12a5c1d

added pix2pix class, pix2pix.transfer method, and a checkpoint loader for the pix2pix model
1 parent 0a4a433 commit 12a5c1d

File tree

package-lock.json
src/Pix2pix/index.js
src/index.js
src/utils/checkpointLoaderPix2pix.js

4 files changed: +193 -1 lines changed

package-lock.json

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.

src/Pix2pix/index.js

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
// Copyright (c) 2018 ml5
//
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT

/* eslint max-len: "off" */
/*
Pix2pix
*/

import * as tf from '@tensorflow/tfjs';
import CheckpointLoaderPix2pix from '../utils/checkpointLoaderPix2pix';
import { array3DToImage } from '../utils/imageUtilities';

class Pix2pix {
  constructor(model, callback) {
    this.ready = false;

    this.loadCheckpoints(model).then(() => {
      this.ready = true;
      if (callback) {
        callback();
      }
    });
  }

  async loadCheckpoints(path) {
    const checkpointLoader = new CheckpointLoaderPix2pix(path);
    this.weights = await checkpointLoader.fetchWeights();
  }

  async transfer(inputElement, callback = () => {}) {
    const input = tf.fromPixels(inputElement);
    const inputData = input.dataSync();
    const floatInput = tf.tensor3d(inputData, input.shape);
    const normalizedInput = tf.div(floatInput, tf.scalar(255));

    function preprocess(inputPreproc) {
      return tf.sub(tf.mul(inputPreproc, tf.scalar(2)), tf.scalar(1));
    }

    function deprocess(inputDeproc) {
      return tf.div(tf.add(inputDeproc, tf.scalar(1)), tf.scalar(2));
    }

    function batchnorm(inputBat, scale, offset) {
      const moments = tf.moments(inputBat, [0, 1]);
      const varianceEpsilon = 1e-5;
      return tf.batchNormalization(inputBat, moments.mean, moments.variance, varianceEpsilon, scale, offset);
    }

    function conv2d(inputCon, filterCon) {
      return tf.conv2d(inputCon, filterCon, [2, 2], 'same');
    }

    function deconv2d(inputDeconv, filterDeconv, biasDecon) {
      const convolved = tf.conv2dTranspose(inputDeconv, filterDeconv, [inputDeconv.shape[0] * 2, inputDeconv.shape[1] * 2, filterDeconv.shape[2]], [2, 2], 'same');
      const biased = tf.add(convolved, biasDecon);
      return biased;
    }

    const result = tf.tidy(() => {
      const preprocessedInput = preprocess(normalizedInput);
      const layers = [];
      let filter = this.weights['generator/encoder_1/conv2d/kernel'];
      let bias = this.weights['generator/encoder_1/conv2d/bias'];
      let convolved = conv2d(preprocessedInput, filter, bias);
      layers.push(convolved);

      for (let i = 2; i <= 8; i += 1) {
        const scope = `generator/encoder_${i.toString()}`;
        filter = this.weights[`${scope}/conv2d/kernel`];
        const bias2 = this.weights[`${scope}/conv2d/bias`];
        const layerInput = layers[layers.length - 1];
        const rectified = tf.leakyRelu(layerInput, 0.2);
        convolved = conv2d(rectified, filter, bias2);
        const scale = this.weights[`${scope}/batch_normalization/gamma`];
        const offset = this.weights[`${scope}/batch_normalization/beta`];
        const normalized = batchnorm(convolved, scale, offset);
        layers.push(normalized);
      }

      for (let i = 8; i >= 2; i -= 1) {
        let layerInput;
        if (i === 8) {
          layerInput = layers[layers.length - 1];
        } else {
          const skipLayer = i - 1;
          layerInput = tf.concat([layers[layers.length - 1], layers[skipLayer]], 2);
        }
        const rectified = tf.relu(layerInput);
        const scope = `generator/decoder_${i.toString()}`;
        filter = this.weights[`${scope}/conv2d_transpose/kernel`];
        bias = this.weights[`${scope}/conv2d_transpose/bias`];
        convolved = deconv2d(rectified, filter, bias);
        const scale = this.weights[`${scope}/batch_normalization/gamma`];
        const offset = this.weights[`${scope}/batch_normalization/beta`];
        const normalized = batchnorm(convolved, scale, offset);
        layers.push(normalized);
      }

      const layerInput = tf.concat([layers[layers.length - 1], layers[0]], 2);
      let rectified2 = tf.relu(layerInput);
      filter = this.weights['generator/decoder_1/conv2d_transpose/kernel'];
      const bias3 = this.weights['generator/decoder_1/conv2d_transpose/bias'];
      convolved = deconv2d(rectified2, filter, bias3);
      rectified2 = tf.tanh(convolved);
      layers.push(rectified2);

      const output = layers[layers.length - 1];
      const deprocessedOutput = deprocess(output);
      return deprocessedOutput;
    });

    await tf.nextFrame();
    callback(array3DToImage(result));
  }
}

const pix2pix = (model, callback = () => {}) => new Pix2pix(model, callback);

export default pix2pix;
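
Taken together, the module exports a factory: pix2pix(model, callback) builds a Pix2pix instance and loads the checkpoint weights asynchronously, and transfer(inputElement, callback) runs the encoder/decoder generator on a canvas or image element. A minimal usage sketch under stated assumptions: the checkpoint path and element id below are placeholders, and array3DToImage is assumed to yield a DOM image element.

import pix2pix from './Pix2pix';

// 'models/edges2cats_AtoB.pict' and 'input-canvas' are placeholders,
// not assets shipped with this commit.
const pix2pixModel = pix2pix('models/edges2cats_AtoB.pict', () => {
  const inputCanvas = document.getElementById('input-canvas');
  // transfer() hands the callback the generator output converted by
  // array3DToImage (assumed here to be a DOM image element).
  pix2pixModel.transfer(inputCanvas, (outputImage) => {
    document.body.appendChild(outputImage);
  });
});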

src/index.js

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@ import poseNet from './PoseNet';
 import * as imageUtils from './utils/imageUtilities';
 import styleTransfer from './StyleTransfer/';
 import LSTMGenerator from './LSTM/';
+import pix2pix from './Pix2pix/';

 module.exports = {
   imageClassifier,
@@ -23,6 +24,7 @@ module.exports = {
   styleTransfer,
   poseNet,
   LSTMGenerator,
+  pix2pix,
   ...imageUtils,
   tf,
 };
src/utils/checkpointLoaderPix2pix.js

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
/* eslint max-len: "off" */

import * as tf from '@tensorflow/tfjs';

export default class CheckpointLoaderPix2pix {
  constructor(urlPath) {
    this.urlPath = urlPath;
  }

  fetchWeights() {
    return new Promise((resolve, reject) => {
      const weightsCache = {};
      if (this.urlPath in weightsCache) {
        resolve(weightsCache[this.urlPath]);
        return;
      }

      const xhr = new XMLHttpRequest();
      xhr.open('GET', this.urlPath, true);
      xhr.responseType = 'arraybuffer';
      xhr.onload = () => {
        if (xhr.status !== 200) {
          reject(new Error('missing model'));
          return;
        }
        const buf = xhr.response;
        if (!buf) {
          reject(new Error('invalid arraybuffer'));
          return;
        }

        const parts = [];
        let offset = 0;
        while (offset < buf.byteLength) {
          const b = new Uint8Array(buf.slice(offset, offset + 4));
          offset += 4;
          const len = (b[0] << 24) + (b[1] << 16) + (b[2] << 8) + b[3]; // eslint-disable-line no-bitwise
          parts.push(buf.slice(offset, offset + len));
          offset += len;
        }

        const shapes = JSON.parse((new TextDecoder('utf8')).decode(parts[0]));
        const index = new Float32Array(parts[1]);
        const encoded = new Uint8Array(parts[2]);

        // decode using index
        const arr = new Float32Array(encoded.length);
        for (let i = 0; i < arr.length; i += 1) {
          arr[i] = index[encoded[i]];
        }

        const weights = {};
        offset = 0;
        for (let i = 0; i < shapes.length; i += 1) {
          const { shape } = shapes[i];
          const size = shape.reduce((total, num) => total * num);
          const values = arr.slice(offset, offset + size);
          const tfarr = tf.tensor1d(values, 'float32');
          weights[shapes[i].name] = tfarr.reshape(shape);
          offset += size;
        }
        weightsCache[this.urlPath] = weights;
        resolve(weights);
      };
      xhr.send(null);
    });
  }
}
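
As the parsing code above shows, fetchWeights expects the checkpoint to be one binary blob of length-prefixed parts (a 4-byte big-endian length followed by that many bytes): part 0 is UTF-8 JSON listing each variable's name and shape, part 1 is a Float32Array lookup table, and part 2 holds one uint8 code per weight value, dequantized through the table and reshaped into named tensors. A standalone usage sketch, assuming a hypothetical checkpoint URL:

import CheckpointLoaderPix2pix from './utils/checkpointLoaderPix2pix';

// 'models/edges2cats_AtoB.pict' is a placeholder URL for a checkpoint
// exported in the [len][shapes JSON][len][float32 table][len][uint8 codes] layout.
const loader = new CheckpointLoaderPix2pix('models/edges2cats_AtoB.pict');
loader.fetchWeights().then((weights) => {
  // weights maps variable names such as 'generator/encoder_1/conv2d/kernel'
  // to tf.Tensors restored to their original shapes.
  console.log(`loaded ${Object.keys(weights).length} tensors`);
});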
