Skip to content

Commit c55bc68

Browse files
authored
Merge pull request #162 from ml5js/pix2pix
Added pix2pix with Edges2Pikachu model
2 parents c8c58ff + b5d008e commit c55bc68

File tree

5 files changed

+195
-1
lines changed

5 files changed

+195
-1
lines changed

package-lock.json

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/Pix2pix/index.js

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
// Copyright (c) 2018 ml5
2+
//
3+
// This software is released under the MIT License.
4+
// https://opensource.org/licenses/MIT
5+
6+
/* eslint max-len: "off" */
7+
/*
8+
Pix2pix
9+
*/
10+
11+
import * as tf from '@tensorflow/tfjs';
12+
import CheckpointLoaderPix2pix from '../utils/checkpointLoaderPix2pix';
13+
import { array3DToImage } from '../utils/imageUtilities';
14+
import callCallback from '../utils/callcallback';
15+
16+
/**
 * Pix2pix generator: an 8-stage convolutional encoder followed by an 8-stage
 * transposed-convolution decoder with skip connections, evaluated with
 * TensorFlow.js using variables loaded from a checkpoint file.
 */
class Pix2pix {
  /**
   * @param {string} model - Path/URL of the checkpoint, passed to CheckpointLoaderPix2pix.
   * @param {function} [callback] - Optional callback, forwarded to callCallback
   *   together with the checkpoint-loading promise.
   */
  constructor(model, callback) {
    this.ready = callCallback(this.loadCheckpoints(model), callback);
  }

  /**
   * Loads every variable of the checkpoint into `this.variables`.
   * @param {string} path - Path/URL of the checkpoint.
   * @returns {Promise<Pix2pix>} Resolves with this instance once the variables are loaded.
   */
  async loadCheckpoints(path) {
    const checkpointLoader = new CheckpointLoaderPix2pix(path);
    this.variables = await checkpointLoader.getAllVariables();
    return this;
  }

  /**
   * Runs the generator on an input element.
   * @param {*} inputElement - Anything accepted by tf.fromPixels.
   * @param {function} [cb] - Optional callback, forwarded to callCallback.
   * @returns {Promise} Whatever callCallback returns for the transfer promise.
   */
  async transfer(inputElement, cb) {
    return callCallback(this.transferInternal(inputElement), cb);
  }

  /**
   * Performs the actual forward pass and converts the output tensor with array3DToImage.
   * @param {*} inputElement - Anything accepted by tf.fromPixels.
   * @returns {Promise} Resolves with the value produced by array3DToImage.
   */
  async transferInternal(inputElement) {
    const { variables } = this;

    // Everything, including the input conversion, happens inside tf.tidy so that
    // every intermediate tensor is released; only the returned tensor survives.
    const outputTensor = tf.tidy(() => {
      const input = tf.fromPixels(inputElement);
      const inputData = input.dataSync();
      const floatInput = tf.tensor3d(inputData, input.shape);
      const normalizedInput = tf.div(floatInput, tf.scalar(255));
      const preprocessedInput = Pix2pix.preprocess(normalizedInput);

      // layers[0..7] hold the encoder outputs, later entries the decoder outputs.
      const layers = [];
      layers.push(Pix2pix.conv2d(
        preprocessedInput,
        variables['generator/encoder_1/conv2d/kernel'],
        variables['generator/encoder_1/conv2d/bias'],
      ));

      for (let i = 2; i <= 8; i += 1) {
        const scope = `generator/encoder_${i.toString()}`;
        const rectified = tf.leakyRelu(layers[layers.length - 1], 0.2);
        const convolved = Pix2pix.conv2d(
          rectified,
          variables[`${scope}/conv2d/kernel`],
          variables[`${scope}/conv2d/bias`],
        );
        layers.push(Pix2pix.batchnorm(
          convolved,
          variables[`${scope}/batch_normalization/gamma`],
          variables[`${scope}/batch_normalization/beta`],
        ));
      }

      for (let i = 8; i >= 2; i -= 1) {
        const previous = layers[layers.length - 1];
        // The innermost decoder stage has no skip connection; every other stage
        // is concatenated (along axis 2) with the encoder output of the same index.
        const layerInput = i === 8 ? previous : tf.concat([previous, layers[i - 1]], 2);
        const rectified = tf.relu(layerInput);
        const scope = `generator/decoder_${i.toString()}`;
        const convolved = Pix2pix.deconv2d(
          rectified,
          variables[`${scope}/conv2d_transpose/kernel`],
          variables[`${scope}/conv2d_transpose/bias`],
        );
        layers.push(Pix2pix.batchnorm(
          convolved,
          variables[`${scope}/batch_normalization/gamma`],
          variables[`${scope}/batch_normalization/beta`],
        ));
      }

      // decoder_1: skip connection to encoder_1, no batch norm, tanh activation.
      const lastInput = tf.concat([layers[layers.length - 1], layers[0]], 2);
      const lastConvolved = Pix2pix.deconv2d(
        tf.relu(lastInput),
        variables['generator/decoder_1/conv2d_transpose/kernel'],
        variables['generator/decoder_1/conv2d_transpose/bias'],
      );
      return Pix2pix.deprocess(tf.tanh(lastConvolved));
    });

    let result;
    try {
      // `await` is harmless if array3DToImage is synchronous and guarantees the
      // tensor is still alive if it happens to return a promise.
      result = await array3DToImage(outputTensor);
    } finally {
      outputTensor.dispose();
    }

    await tf.nextFrame();
    return result;
  }

  /** Computes `input * 2 - 1`. */
  static preprocess(inputPreproc) {
    return tf.sub(tf.mul(inputPreproc, tf.scalar(2)), tf.scalar(1));
  }

  /** Computes `(input + 1) / 2`, the inverse of preprocess. */
  static deprocess(inputDeproc) {
    return tf.div(tf.add(inputDeproc, tf.scalar(1)), tf.scalar(2));
  }

  /**
   * Batch-normalizes `inputBat` with moments computed over axes 0 and 1 of the
   * input itself, using epsilon 1e-5 and the given scale and offset.
   */
  static batchnorm(inputBat, scale, offset) {
    const moments = tf.moments(inputBat, [0, 1]);
    const varianceEpsilon = 1e-5;
    return tf.batchNormalization(inputBat, moments.mean, moments.variance, varianceEpsilon, scale, offset);
  }

  /**
   * Stride-2, 'same'-padded convolution.
   * @param inputCon - Input tensor.
   * @param filterCon - Convolution kernel.
   * @param [biasCon] - Optional bias; added to the convolution result when given.
   *   Omitting it keeps the previous two-argument behaviour.
   */
  static conv2d(inputCon, filterCon, biasCon) {
    const convolved = tf.conv2d(inputCon, filterCon, [2, 2], 'same');
    if (biasCon === undefined || biasCon === null) {
      return convolved;
    }
    return tf.add(convolved, biasCon);
  }

  /**
   * Stride-2, 'same'-padded transposed convolution followed by a bias add.
   * The output doubles the first two dimensions of the input and takes its
   * channel count from `filterDeconv.shape[2]`.
   */
  static deconv2d(inputDeconv, filterDeconv, biasDecon) {
    const outputShape = [
      inputDeconv.shape[0] * 2,
      inputDeconv.shape[1] * 2,
      filterDeconv.shape[2],
    ];
    const convolved = tf.conv2dTranspose(inputDeconv, filterDeconv, outputShape, [2, 2], 'same');
    return tf.add(convolved, biasDecon);
  }
}
118+
119+
/**
 * Factory for Pix2pix.
 * With a callback the instance itself is returned (the callback is notified
 * through the constructor); without one, the instance's `ready` value is returned.
 */
const pix2pix = (model, callback) => {
  const generator = new Pix2pix(model, callback);
  if (callback) {
    return generator;
  }
  return generator.ready;
};
123+
124+
export default pix2pix;

src/Pix2pix/index_test.js

Whitespace-only changes.

src/index.js

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import poseNet from './PoseNet';
1313
import * as imageUtils from './utils/imageUtilities';
1414
import styleTransfer from './StyleTransfer/';
1515
import LSTMGenerator from './LSTM/';
16+
import pix2pix from './Pix2pix/';
1617

1718
module.exports = {
1819
imageClassifier,
@@ -23,6 +24,7 @@ module.exports = {
2324
styleTransfer,
2425
poseNet,
2526
LSTMGenerator,
27+
pix2pix,
2628
...imageUtils,
2729
tf,
2830
};
src/utils/checkpointLoaderPix2pix.js

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/* eslint max-len: "off" */
2+
3+
import * as tf from '@tensorflow/tfjs';
4+
5+
/**
 * Downloads a pix2pix checkpoint file and decodes it into named tensors.
 *
 * File layout, as read by this loader: a sequence of parts, each prefixed by a
 * 4-byte big-endian length. Part 0 is a UTF-8 JSON array of `{ name, shape }`
 * entries, part 1 is a Float32 lookup table and part 2 holds one byte per
 * weight, each byte being an index into that lookup table.
 */
export default class CheckpointLoaderPix2pix {
  /**
   * @param {string} urlPath - URL the checkpoint is fetched from.
   */
  constructor(urlPath) {
    this.urlPath = urlPath;
  }

  /**
   * Fetches and decodes the checkpoint. Decoded weights are cached per URL on
   * the class, so later calls for the same URL (from any instance) resolve
   * without another download.
   * @returns {Promise<Object>} Resolves with a map of variable name to tensor.
   *   Rejects on a non-200 status, an empty response, a network error or a
   *   file that cannot be decoded.
   */
  getAllVariables() {
    return new Promise((resolve, reject) => {
      // The cache must outlive a single call, so it is stored on the class
      // (created lazily) rather than in a local variable.
      if (!CheckpointLoaderPix2pix.weightsCache) {
        CheckpointLoaderPix2pix.weightsCache = new Map();
      }
      const cache = CheckpointLoaderPix2pix.weightsCache;
      if (cache.has(this.urlPath)) {
        resolve(cache.get(this.urlPath));
        return;
      }

      const xhr = new XMLHttpRequest();
      xhr.open('GET', this.urlPath, true);
      xhr.responseType = 'arraybuffer';
      // Without this handler a network failure would leave the promise pending forever.
      xhr.onerror = () => {
        reject(new Error('network error while loading model'));
      };
      xhr.onload = () => {
        if (xhr.status !== 200) {
          reject(new Error('missing model'));
          return;
        }
        const buf = xhr.response;
        if (!buf) {
          reject(new Error('invalid arraybuffer'));
          return;
        }

        try {
          const weights = CheckpointLoaderPix2pix.decodeWeights(buf);
          cache.set(this.urlPath, weights);
          resolve(weights);
        } catch (err) {
          // Surface decoding failures through the promise instead of letting
          // them escape from the event handler.
          reject(err);
        }
      };
      xhr.send(null);
    });
  }

  /**
   * Decodes the raw checkpoint bytes into a map of variable name to tensor.
   * @param {ArrayBuffer} buf - Complete contents of the checkpoint file.
   * @returns {Object} Map of variable name to tensor reshaped to its declared shape.
   */
  static decodeWeights(buf) {
    const view = new DataView(buf);
    const parts = [];
    let offset = 0;
    while (offset < buf.byteLength) {
      // DataView reads big-endian by default and, unlike `<<`, cannot go negative.
      const len = view.getUint32(offset);
      offset += 4;
      parts.push(buf.slice(offset, offset + len));
      offset += len;
    }
    if (parts.length < 3) {
      throw new Error('invalid model file');
    }

    const shapes = JSON.parse((new TextDecoder('utf8')).decode(parts[0]));
    const index = new Float32Array(parts[1]);
    const encoded = new Uint8Array(parts[2]);

    // decode using index
    const arr = new Float32Array(encoded.length);
    for (let i = 0; i < arr.length; i += 1) {
      arr[i] = index[encoded[i]];
    }

    const weights = {};
    let position = 0;
    for (let i = 0; i < shapes.length; i += 1) {
      const { name, shape } = shapes[i];
      // Initial value 1 makes a scalar (empty shape) occupy exactly one value.
      const size = shape.reduce((total, num) => total * num, 1);
      const values = arr.slice(position, position + size);
      weights[name] = tf.tensor1d(values, 'float32').reshape(shape);
      position += size;
    }
    return weights;
  }
}

0 commit comments

Comments
 (0)