@@ -12,10 +12,21 @@ SketchRNN
12
12
import * as ms from '@magenta/sketch' ;
13
13
import callCallback from '../utils/callcallback' ;
14
14
import modelPaths from './models' ;
15
+ import modelLoader from '../utils/modelLoader' ;
15
16
16
- const PATH_START_LARGE = 'https://storage.googleapis.com/quickdraw-models/sketchRNN/large_models/' ;
17
- const PATH_START_SMALL = 'https://storage.googleapis.com/quickdraw-models/sketchRNN/models/' ;
18
- const PATH_END = '.gen.json' ;
17
+ // const PATH_START_LARGE = 'https://storage.googleapis.com/quickdraw-models/sketchRNN/large_models/';
18
+ // const PATH_START_SMALL = 'https://storage.googleapis.com/quickdraw-models/sketchRNN/models/';
19
+ // const PATH_END = '.gen.json';
20
+
21
+
22
+ const DEFAULTS = {
23
+ modelPath : 'https://storage.googleapis.com/quickdraw-models/sketchRNN/large_models/' ,
24
+ modelPath_large : 'https://storage.googleapis.com/quickdraw-models/sketchRNN/models/' ,
25
+ modelPath_small : 'https://storage.googleapis.com/quickdraw-models/sketchRNN/models/' ,
26
+ PATH_END : '.gen.json' ,
27
+ temperature : 0.65 ,
28
+ pixelFactor : 3.0 ,
29
+ }
19
30
20
31
class SketchRNN {
21
32
/**
@@ -27,21 +38,37 @@ class SketchRNN {
27
38
*/
28
39
constructor ( model , callback , large = true ) {
29
40
let checkpointUrl = model ;
30
- if ( modelPaths . has ( checkpointUrl ) ) {
31
- checkpointUrl = ( large ? PATH_START_LARGE : PATH_START_SMALL ) + checkpointUrl + PATH_END ;
32
- }
33
- this . defaults = {
41
+
42
+ this . config = {
34
43
temperature : 0.65 ,
35
44
pixelFactor : 3.0 ,
45
+ modelPath : DEFAULTS . modelPath ,
46
+ modelPath_small : DEFAULTS . modelPath_small ,
47
+ modelPath_large : DEFAULTS . modelPath_large ,
48
+ PATH_END : DEFAULTS . PATH_END ,
36
49
} ;
37
- this . model = new ms . SketchRNN ( checkpointUrl ) ;
50
+
51
+
52
+ if ( modelLoader . isAbsoluteURL ( checkpointUrl ) === true ) {
53
+ const modelPath = modelLoader . getModelPath ( checkpointUrl ) ;
54
+ this . config . modelPath = modelPath ;
55
+
56
+ } else if ( modelPaths . has ( checkpointUrl ) ) {
57
+ checkpointUrl = ( large ? this . config . modelPath : this . config . modelPath_small ) + checkpointUrl + this . config . PATH_END ;
58
+ this . config . modelPath = checkpointUrl ;
59
+ } else {
60
+ console . log ( 'no model found!' ) ;
61
+ return this ;
62
+ }
63
+
64
+ this . model = new ms . SketchRNN ( this . config . modelPath ) ;
38
65
this . penState = this . model . zeroInput ( ) ;
39
66
this . ready = callCallback ( this . model . initialize ( ) , callback ) ;
40
67
}
41
68
42
69
async generateInternal ( options , strokes ) {
43
- const temperature = + options . temperature || this . defaults . temperature ;
44
- const pixelFactor = + options . pixelFactor || this . defaults . pixelFactor ;
70
+ const temperature = + options . temperature || this . config . temperature ;
71
+ const pixelFactor = + options . pixelFactor || this . config . pixelFactor ;
45
72
46
73
await this . ready ;
47
74
if ( ! this . rnnState ) {
0 commit comments