88// for arbitrary data though. It's worth a look :)
99import { IMAGE_H , IMAGE_W , MnistData } from './datas.js' ;
1010
11- // This is a helper class for drawing loss graphs and MNIST images to the
12- // window. For the purposes of understanding the machine learning bits, you can
13- // largely ignore it
1411import * as ui from './ui.js' ;
1512
1613
17- function createConvModel ( n_layers , n_units , hidden ) {
18-
14+ function createConvModel ( n_layers , n_units , hidden ) { //resnet-densenet-batchnorm
1915 this . latent_dim = Number ( hidden ) ; //final dimension of hidden layer
2016 this . n_layers = Number ( n_layers ) ; //how many hidden layers in encoder and decoder
2117 this . n_units = Number ( n_units ) ; //output dimension of each layer
2218 this . img_shape = [ 28 , 28 ] ;
2319 this . img_units = this . img_shape [ 0 ] * this . img_shape [ 1 ] ;
2420 // build the encoder
21+
2522 var i = tf . input ( { shape : this . img_shape } ) ;
2623 var h = tf . layers . flatten ( ) . apply ( i ) ;
27-
28- for ( var j = 0 ; j < this . n_layers ; j ++ ) {
24+ h = tf . layers . batchNormalization ( - 1 ) . apply ( h ) ;
25+ h = tf . layers . dense ( { units : this . n_units , activation :'relu' } ) . apply ( h ) ;
26+ for ( var j = 0 ; j < this . n_layers - 1 ; j ++ ) {
27+ var tm = h ;
28+ const addLayer = tf . layers . add ( ) ;
2929 var h = tf . layers . dense ( { units : this . n_units , activation :'relu' } ) . apply ( h ) ; //n hidden
30+ h = addLayer . apply ( [ tm , h ] ) ;
31+ h = tf . layers . batchNormalization ( 0 ) . apply ( h ) ;
3032 }
3133
32- var o = tf . layers . dense ( { units : this . latent_dim } ) . apply ( h ) ; //1 final
34+ var o = tf . layers . dense ( { units : this . latent_dim } ) . apply ( h ) ;
35+ //1 final
3336 this . encoder = tf . model ( { inputs : i , outputs : o } ) ;
3437
3538 // build the decoder
3639 var i = h = tf . input ( { shape : this . latent_dim } ) ;
37- for ( var j = 0 ; j < this . n_layers ; j ++ ) { //n hidden
40+ h = tf . layers . dense ( { units : this . n_units , activation :'relu' } ) . apply ( h ) ;
41+ for ( var j = 0 ; j < this . n_layers - 1 ; j ++ ) {
42+ var tm = h ;
43+ const addLayer = tf . layers . add ( ) ; //n hidden
3844 var h = tf . layers . dense ( { units : this . n_units , activation :'relu' } ) . apply ( h ) ;
45+ h = addLayer . apply ( [ tm , h ] ) ;
3946 }
40- var o = tf . layers . dense ( { units : this . img_units } ) . apply ( h ) ; //1 final
47+
48+ var o = tf . layers . dense ( { units : this . img_units } ) . apply ( h ) ; //1 final
4149 var o = tf . layers . reshape ( { targetShape : this . img_shape } ) . apply ( o ) ;
4250 this . decoder = tf . model ( { inputs : i , outputs : o } ) ;
4351
4452 // stack the autoencoder
4553 var i = tf . input ( { shape : this . img_shape } ) ;
4654 var z = this . encoder . apply ( i ) ; //z is hidden code
47-
4855 var o = this . decoder . apply ( z ) ;
4956 this . auto = tf . model ( { inputs : i , outputs : o } ) ;
5057
5158}
59+
60+
5261let epochs = 0 , trainEpochs , batch ;
5362var trainData ;
5463var testData ;
5564var b ; var model ;
5665
66+
67+
5768async function train ( model ) {
5869
5970 const e = document . getElementById ( 'batchsize' ) ;
@@ -84,8 +95,6 @@ await showPredictions(model,epochs); //Triv
8495
8596}
8697
87-
88-
8998async function showPredictions ( model , epochs ) { //Trivial Samples of autoencoder
9099 const testExamples = 10 ;
91100 const examples = data . getTestData ( testExamples ) ;
@@ -106,14 +115,15 @@ async function run(){
106115 testData = data . getTestData ( ) ;
107116}
108117
118+ document . getElementById ( 'vis' ) . oninput = function ( ) { vis = Number ( document . getElementById ( 'vis' ) . value ) ; console . log ( vis ) ; } ;
109119
110120async function load ( ) {
111121 var ele = document . getElementById ( 'barc' ) ;
112122 ele . style . display = "none" ;
113123 const n_units = document . getElementById ( 'n_units' ) . value ;
114124 const n_layers = document . getElementById ( 'n_layers' ) . value ;
115125 const hidden = document . getElementById ( 'hidden' ) . value ;
116- model = new createConvModel ( n_layers , n_units , hidden ) ;
126+ model = new createConvModel ( n_layers , n_units , hidden ) ; //load model
117127 const elem = document . getElementById ( 'new' )
118128 elem . innerHTML = "Model Created!!!"
119129 epochs = 0 ;
@@ -122,13 +132,15 @@ async function load() {
122132
123133load ( ) ;
124134
135+
136+
125137async function runtrain ( ) {
126138 var ele = document . getElementById ( 'barc' ) ;
127139 ele . style . display = "block" ;
128140 var elem = document . getElementById ( 'new' ) ;
129141 elem . innerHTML = "" ;
130142 b = 0 ;
131- await train ( model ) ;
143+ await train ( model ) ; //start training
132144 vis = Number ( document . getElementById ( 'vis' ) . value ) ;
133145}
134146
@@ -151,7 +163,7 @@ function normaltensor(prediction){
151163 prediction = prediction . sub ( inputMin ) . div ( inputMax . sub ( inputMin ) ) ;
152164 return prediction ; }
153165function normal ( prediction ) {
154- const inputMax = prediction . max ( ) ;
166+ const inputMax = prediction . max ( ) ; //normailization
155167 const inputMin = prediction . min ( ) ;
156168 prediction = prediction . sub ( inputMin ) . div ( inputMax . sub ( inputMin ) ) ;
157169 return prediction ;
@@ -163,22 +175,27 @@ const canvas=document.getElementById('celeba-scene');
163175const mot = document . getElementById ( 'mot' ) ;
164176var cont = mot . getContext ( '2d' ) ;
165177
178+
179+
180+
181+
182+
183+
184+
185+
186+
166187function sample ( obj ) { //plotting
167188 obj . x = ( obj . x ) * vis ;
168189 obj . y = ( obj . y ) * vis ;
169190 // convert 10, 50 into a vector
170191 var y = tf . tensor2d ( [ [ obj . x , obj . y ] ] ) ;
171- // sample from region 10, 50 in latent space
172192
173193 var prediction = model . decoder . predict ( y ) . dataSync ( ) ;
174-
175- //scaling
194+ //scaling
176195 prediction = normaltensor ( prediction ) ;
177196 prediction = prediction . reshape ( [ 28 , 28 ] ) ;
178197
179- prediction = prediction . mul ( 255 ) . toInt ( ) ;
180-
181-
198+ prediction = prediction . mul ( 255 ) . toInt ( ) ; //for2dplot
182199 // log the prediction to the browser console
183200 tf . browser . toPixels ( prediction , canvas ) ;
184201}
@@ -190,7 +207,7 @@ cont.fillRect(0,0,mot.width,mot.height);
190207mot . addEventListener ( 'mousemove' , function ( e ) {
191208 mouse . x = ( e . pageX - this . offsetLeft ) * 3.43 ;
192209 mouse . y = ( e . pageY - this . offsetTop ) * 1.9 ;
193- } , false ) ;
210+ } , false ) ; //mouse movement for 2dplot
194211
195212mot . addEventListener ( 'mousedown' , function ( e ) {
196213 mot . addEventListener ( 'mousemove' , on , false ) ;
@@ -209,11 +226,6 @@ var on= function() {
209226} ;
210227
211228
212-
213-
214-
215-
216-
217229function plot2d ( ) {
218230 load ( ) ;
219231 const decision = Number ( document . getElementById ( "hidden" ) . value ) ;
@@ -241,6 +253,12 @@ document.addEventListener('DOMContentLoaded',plot2d);
241253
242254
243255
256+
257+
258+
259+
260+
261+
244262const canv = document . getElementById ( 'canv' ) ;
245263const outcanv = document . getElementById ( 'outcanv' ) ;
246264var ct = outcanv . getContext ( '2d' ) ;
@@ -250,7 +268,7 @@ var ctx = canv.getContext('2d');
250268function clear ( ) {
251269 ctx . clearRect ( 0 , 0 , canv . width , canv . height ) ;
252270 ctx . fillStyle = "black" ;
253- ctx . fillRect ( 0 , 0 , canv . width , canv . height ) ;
271+ ctx . fillRect ( 0 , 0 , canv . width , canv . height ) ; //for canvas autoencoding
254272 ct . clearRect ( 0 , 0 , outcanv . width , outcanv . height ) ;
255273 ct . fillStyle = "#DDDDDD" ;
256274 ct . fillRect ( 0 , 0 , outcanv . width , outcanv . height ) ;
0 commit comments