Skip to content

Commit 6d18b81

Browse files
reiinakanocvalenzuela
authored andcommitted
SketchRNN (#189)
* init sketchrnn * strokes * add models * stylefixes * some more bugfixes * comments * update dependencies
1 parent a354dda commit 6d18b81

File tree

6 files changed

+337
-61
lines changed

6 files changed

+337
-61
lines changed

package-lock.json

Lines changed: 59 additions & 26 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,10 @@
8686
]
8787
},
8888
"dependencies": {
89-
"@tensorflow-models/mobilenet": "0.1.0",
90-
"@tensorflow-models/posenet": "0.1.3",
91-
"@tensorflow/tfjs": "0.11.4",
89+
"@magenta/sketch": "^0.1.2",
90+
"@tensorflow-models/mobilenet": "0.2.2",
91+
"@tensorflow-models/posenet": "0.2.2",
92+
"@tensorflow/tfjs": "0.13.0",
9293
"events": "^3.0.0"
9394
}
9495
}

src/SketchRNN/index.js

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Copyright (c) 2018 ml5
2+
//
3+
// This software is released under the MIT License.
4+
// https://opensource.org/licenses/MIT
5+
6+
/* eslint prefer-destructuring: ['error', {AssignmentExpression: {array: false}}] */
7+
/* eslint no-await-in-loop: 'off' */
8+
/*
9+
SketchRNN
10+
*/
11+
12+
import * as ms from '@magenta/sketch';
13+
import callCallback from '../utils/callcallback';
14+
import modelPaths from './models';
15+
16+
const PATH_START_LARGE = 'https://storage.googleapis.com/quickdraw-models/sketchRNN/large_models/';
17+
const PATH_START_SMALL = 'https://storage.googleapis.com/quickdraw-models/sketchRNN/models/';
18+
const PATH_END = '.gen.json';
19+
20+
class SketchRNN {
21+
constructor(model, callback, large = true) {
22+
let checkpointUrl = model;
23+
if (modelPaths.has(checkpointUrl)) {
24+
checkpointUrl = (large ? PATH_START_LARGE : PATH_START_SMALL) + checkpointUrl + PATH_END;
25+
}
26+
this.defaults = {
27+
temperature: 0.65,
28+
pixelFactor: 3.0,
29+
};
30+
this.model = new ms.SketchRNN(checkpointUrl);
31+
this.penState = this.model.zeroInput();
32+
this.ready = callCallback(this.model.initialize(), callback);
33+
}
34+
35+
async generateInternal(options, strokes) {
36+
const temperature = +options.temperature || this.defaults.temperature;
37+
const pixelFactor = +options.pixelFactor || this.defaults.pixelFactor;
38+
39+
await this.ready;
40+
if (!this.rnnState) {
41+
this.rnnState = this.model.zeroState();
42+
this.model.setPixelFactor(pixelFactor);
43+
}
44+
45+
if (Array.isArray(strokes) && strokes.length) {
46+
this.rnnState = this.model.updateStrokes(strokes, this.rnnState);
47+
}
48+
this.rnnState = this.model.update(this.penState, this.rnnState);
49+
const pdf = this.model.getPDF(this.rnnState, temperature);
50+
this.penState = this.model.sample(pdf);
51+
return this.penState;
52+
}
53+
54+
async generate(options, strokes, callback) {
55+
return callCallback(this.generateInternal(options, strokes), callback);
56+
}
57+
58+
reset() {
59+
this.penState = this.model.zeroInput();
60+
if (this.rnnState) {
61+
this.rnnState = this.model.zeroState();
62+
}
63+
}
64+
}
65+
66+
const SketchRNNGenerator = (model, callback, large = true) => new SketchRNN(model, callback, large);
67+
68+
export default SketchRNNGenerator;

src/SketchRNN/models.js

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
const models = [
2+
'alarm_clock',
3+
'ambulance',
4+
'angel',
5+
'ant',
6+
'antyoga',
7+
'backpack',
8+
'barn',
9+
'basket',
10+
'bear',
11+
'bee',
12+
'beeflower',
13+
'bicycle',
14+
'bird',
15+
'book',
16+
'brain',
17+
'bridge',
18+
'bulldozer',
19+
'bus',
20+
'butterfly',
21+
'cactus',
22+
'calendar',
23+
'castle',
24+
'cat',
25+
'catbus',
26+
'catpig',
27+
'chair',
28+
'couch',
29+
'crab',
30+
'crabchair',
31+
'crabrabbitfacepig',
32+
'cruise_ship',
33+
'diving_board',
34+
'dog',
35+
'dogbunny',
36+
'dolphin',
37+
'duck',
38+
'elephant',
39+
'elephantpig',
40+
'eye',
41+
'face',
42+
'fan',
43+
'fire_hydrant',
44+
'firetruck',
45+
'flamingo',
46+
'flower',
47+
'floweryoga',
48+
'frog',
49+
'frogsofa',
50+
'garden',
51+
'hand',
52+
'hedgeberry',
53+
'hedgehog',
54+
'helicopter',
55+
'kangaroo',
56+
'key',
57+
'lantern',
58+
'lighthouse',
59+
'lion',
60+
'lionsheep',
61+
'lobster',
62+
'map',
63+
'mermaid',
64+
'monapassport',
65+
'monkey',
66+
'mosquito',
67+
'octopus',
68+
'owl',
69+
'paintbrush',
70+
'palm_tree',
71+
'parrot',
72+
'passport',
73+
'peas',
74+
'penguin',
75+
'pig',
76+
'pigsheep',
77+
'pineapple',
78+
'pool',
79+
'postcard',
80+
'power_outlet',
81+
'rabbit',
82+
'rabbitturtle',
83+
'radio',
84+
'radioface',
85+
'rain',
86+
'rhinoceros',
87+
'rifle',
88+
'roller_coaster',
89+
'sandwich',
90+
'scorpion',
91+
'sea_turtle',
92+
'sheep',
93+
'skull',
94+
'snail',
95+
'snowflake',
96+
'speedboat',
97+
'spider',
98+
'squirrel',
99+
'steak',
100+
'stove',
101+
'strawberry',
102+
'swan',
103+
'swing_set',
104+
'the_mona_lisa',
105+
'tiger',
106+
'toothbrush',
107+
'toothpaste',
108+
'tractor',
109+
'trombone',
110+
'truck',
111+
'whale',
112+
'windmill',
113+
'yoga',
114+
'yogabicycle',
115+
'everything',
116+
];
117+
118+
const modelPaths = new Set(models);
119+
120+
export default modelPaths;

src/index.js

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import * as imageUtils from './utils/imageUtilities';
1414
import styleTransfer from './StyleTransfer/';
1515
import LSTMGenerator from './LSTM/';
1616
import pix2pix from './Pix2pix/';
17+
import SketchRNN from './SketchRNN';
1718

1819
module.exports = {
1920
imageClassifier,
@@ -25,6 +26,7 @@ module.exports = {
2526
poseNet,
2627
LSTMGenerator,
2728
pix2pix,
29+
SketchRNN,
2830
...imageUtils,
2931
tf,
3032
};

0 commit comments

Comments
 (0)