1+ import createDebug from "debug" ;
12import { List , Map , Range } from "immutable" ;
23import * as tf from '@tensorflow/tfjs'
34
@@ -13,6 +14,8 @@ import { BatchLogs } from './index.js'
1314import { Model } from './index.js'
1415import { EpochLogs } from './logs.js'
1516
17+ const debug = createDebug ( "discojs:models:tfjs" ) ;
18+
1619type Serialized < D extends DataType > = [ D , tf . io . ModelArtifacts ] ;
1720
1821/** TensorFlow JavaScript model with standard training */
@@ -63,11 +66,71 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
6366 batch : Batched < DataFormat . ModelEncoded [ D ] > ,
6467 ) : Promise < BatchLogs > {
6568 const { xs, ys } = this . #batchToTF( batch ) ;
66- const logs = await this . model . trainOnBatch ( xs , ys ) ;
69+ const logs = await this . trainFedProx ( xs , ys ) ;
70+ // const logs = await this.model.trainOnBatch(xs, ys);
6771 tf . dispose ( [ xs , ys ] )
6872 return this . getBatchLogs ( logs )
6973 }
7074
75+ async trainFedProx (
76+ xs : tf . Tensor , ys : tf . Tensor ) : Promise < [ number , number ] > {
77+ // let logitsTensor: tf.Tensor<tf.Rank>;
78+ debug ( this . model . loss , this . model . losses , this . model . lossFunctions )
79+ const lossFunction : ( ) => tf . Scalar = ( ) => {
80+ this . model . apply ( xs )
81+ const logits = this . model . apply ( xs )
82+ if ( Array . isArray ( logits ) )
83+ throw new Error ( 'model outputs too many tensor' )
84+ if ( logits instanceof tf . SymbolicTensor )
85+ throw new Error ( 'model outputs symbolic tensor' )
86+ // logitsTensor = tf.keep(logits)
87+ // return tf.losses.softmaxCrossEntropy(ys, logits)
88+ let y : tf . Tensor ;
89+ y = tf . clipByValue ( logits , 0.00001 , 1 - 0.00001 ) ;
90+ y = tf . log ( tf . div ( y , tf . sub ( 1 , y ) ) ) ;
91+ return tf . losses . sigmoidCrossEntropy ( ys , y ) ;
92+ // return tf.losses.sigmoidCrossEntropy(ys, logits)
93+ }
94+ const lossTensor = this . model . optimizer . minimize ( lossFunction , true )
95+ if ( lossTensor === null ) throw new Error ( "loss should not be null" )
96+ // const lossTensor = tf.tidy(() => {
97+ // const { grads, value: lossTensor } = this.model.optimizer.computeGradients(() => {
98+ // const logits = this.model.apply(xs)
99+ // if (Array.isArray(logits))
100+ // throw new Error('model outputs too many tensor')
101+ // if (logits instanceof tf.SymbolicTensor)
102+ // throw new Error('model outputs symbolic tensor')
103+ // logitsTensor = tf.keep(logits)
104+ // // return tf.losses.softmaxCrossEntropy(ys, logits)
105+ // return this.model.calculateLosses(ys, logits)[0]
106+ // })
107+ // this.model.optimizer.applyGradients(grads)
108+ // return lossTensor
109+ // })
110+
111+ // // @ts -expect-error Variable 'logitsTensor' is used before being assigned
112+ // const accTensor = tf.metrics.categoricalAccuracy(ys, logitsTensor)
113+ // const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
114+ // const accSumTensor = accTensor.sum()
115+ // const accSum = await accSumTensor.array()
116+ // if (typeof accSum !== 'number')
117+ // throw new Error('got multiple accuracy sum')
118+ // // @ts -expect-error Variable 'logitsTensor' is used before being assigned
119+ // tf.dispose([accTensor, accSumTensor, logitsTensor])
120+
121+ const loss = await lossTensor . array ( )
122+ tf . dispose ( [ xs , ys , lossTensor ] )
123+
124+ // const memory = tf.memory().numBytes / 1024 / 1024 / 1024
125+ // debug("training metrics: %O", {
126+ // loss,
127+ // memory,
128+ // allocated: tf.memory().numTensors,
129+ // });
130+ return [ loss , 0 ]
131+ // return [loss, accSum / accSize]
132+ }
133+
71134 async #evaluate(
72135 dataset : Dataset < Batched < DataFormat . ModelEncoded [ D ] > > ,
73136 ) : Promise < Record < "accuracy" | "loss" , number > > {
@@ -160,7 +223,10 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
160223 return new this (
161224 datatype ,
162225 await tf . loadLayersModel ( {
163- load : ( ) => Promise . resolve ( artifacts ) ,
226+ load : ( ) => {
227+ console . log ( "deserialize called" )
228+ return Promise . resolve ( artifacts )
229+ } ,
164230 } ) ,
165231 ) ;
166232 }
@@ -187,7 +253,7 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
187253 return [ this . datatype , await ret ]
188254 }
189255
190- [ Symbol . dispose ] ( ) : void {
256+ [ Symbol . dispose ] ( ) : void {
191257 this . model . dispose ( )
192258 }
193259
0 commit comments