@@ -73,62 +73,60 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
   }

   async trainFedProx (
-    xs: tf.Tensor, ys: tf.Tensor): Promise<[number, number]> {
-    // let logitsTensor: tf.Tensor<tf.Rank>;
-    debug(this.model.loss, this.model.losses, this.model.lossFunctions)
+    xs: tf.Tensor, ys: tf.Tensor,
+  ): Promise<[number, number]> {
+    let logitsTensor: tf.Tensor<tf.Rank>;
     const lossFunction: () => tf.Scalar = () => {
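+      // FedProx regularizes the local objective with a proximal term:
+      // loss(w) + (mu / 2) * ||w - w_prev||^2, pulling the local weights w
+      // back towards the previous round's global weights w_prev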
+      // proximal term (zero on the first round, when no previous weights exist)
+      let proximalTerm = tf.tensor(0)
+      if (this.prevRoundWeights !== undefined) {
+        // squared L2 norm of the difference, summed over all weight tensors
+        const norm = new WeightsContainer(this.model.getWeights())
+          .sub(this.prevRoundWeights)
+          .map(t => t.square().sum())
+          .reduce((t, acc) => tf.add(t, acc)).asScalar()
+        const mu = 1
+        proximalTerm = tf.mul(mu / 2, norm)
+      }
+
       const logits = this.model.apply(xs)
-      if (Array.isArray(logits))
-        throw new Error('model outputs too many tensor')
-      if (logits instanceof tf.SymbolicTensor)
-        throw new Error('model outputs symbolic tensor')
-      // logitsTensor = tf.keep(logits)
-      // return tf.losses.softmaxCrossEntropy(ys, logits)
-      let y: tf.Tensor;
-      y = tf.clipByValue(logits, 0.00001, 1 - 0.00001);
-      y = tf.log(tf.div(y, tf.sub(1, y)));
-      return tf.losses.sigmoidCrossEntropy(ys, y);
-      // return tf.losses.sigmoidCrossEntropy(ys, logits)
+      if (Array.isArray(logits))
+        throw new Error('model outputs too many tensors')
+      if (logits instanceof tf.SymbolicTensor)
+        throw new Error('model outputs a symbolic tensor')
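+      // optimizer.minimize() runs the loss function inside tf.tidy(), so the
+      // logits must be tf.keep()'d to survive for the accuracy computation below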
+      logitsTensor = tf.keep(logits)
+      // binary cross-entropy: the model outputs probabilities, but
+      // tf.losses.sigmoidCrossEntropy expects raw logits, so clip the
+      // probabilities away from 0 and 1 and apply the inverse sigmoid
+      // log(y / (1 - y)) before computing the loss
+      let y: tf.Tensor;
+      y = tf.clipByValue(logits, 0.00001, 1 - 0.00001);
+      y = tf.log(tf.div(y, tf.sub(1, y)));
+      const loss = tf.losses.sigmoidCrossEntropy(ys, y);
+      debug("batch loss: %O, proximal term: %O", loss.dataSync(), proximalTerm.dataSync())
+      return tf.add(loss, proximalTerm)
     }
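+    // with returnCost = true, minimize() returns the loss as a tf.Scalar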
     const lossTensor = this.model.optimizer.minimize(lossFunction, true)
     if (lossTensor === null) throw new Error("loss should not be null")
-    // const lossTensor = tf.tidy(() => {
-    //   const { grads, value: lossTensor } = this.model.optimizer.computeGradients(() => {
-    //     const logits = this.model.apply(xs)
-    //     if (Array.isArray(logits))
-    //       throw new Error('model outputs too many tensor')
-    //     if (logits instanceof tf.SymbolicTensor)
-    //       throw new Error('model outputs symbolic tensor')
-    //     logitsTensor = tf.keep(logits)
-    //     // return tf.losses.softmaxCrossEntropy(ys, logits)
-    //     return this.model.calculateLosses(ys, logits)[0]
-    //   })
-    //   this.model.optimizer.applyGradients(grads)
-    //   return lossTensor
-    // })

-    // // @ts-expect-error Variable 'logitsTensor' is used before being assigned
-    // const accTensor = tf.metrics.categoricalAccuracy(ys, logitsTensor)
-    // const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
-    // const accSumTensor = accTensor.sum()
-    // const accSum = await accSumTensor.array()
-    // if (typeof accSum !== 'number')
-    //   throw new Error('got multiple accuracy sum')
-    // // @ts-expect-error Variable 'logitsTensor' is used before being assigned
-    // tf.dispose([accTensor, accSumTensor, logitsTensor])
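+    // logitsTensor was assigned inside lossFunction, which TypeScript cannot
+    // prove has run by this point, hence the @ts-expect-error annotations below;
+    // accSum / accSize is the mean categorical accuracy over the batch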
+    // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+    const accTensor = tf.metrics.categoricalAccuracy(ys, logitsTensor)
+    const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
+    const accSumTensor = accTensor.sum()
+    const accSum = await accSumTensor.array()
+    if (typeof accSum !== 'number')
+      throw new Error('got multiple accuracy sums')
+    // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+    tf.dispose([accTensor, accSumTensor, logitsTensor])

     const loss = await lossTensor.array()
     tf.dispose([xs, ys, lossTensor])

-    // const memory = tf.memory().numBytes / 1024 / 1024 / 1024
-    // debug("training metrics: %O", {
-    //   loss,
-    //   memory,
-    //   allocated: tf.memory().numTensors,
-    // });
-    return [loss, 0]
-    // return [loss, accSum / accSize]
+    const memory = tf.memory().numBytes / 1024 / 1024 / 1024 // GiB
+    debug("training metrics: %O", {
+      loss,
+      memory,
+      allocated: tf.memory().numTensors,
+    });
+    return [loss, accSum / accSize]
   }

   async #evaluate(
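
For reference, a minimal standalone sketch of the proximal term computed in this diff, using plain tfjs tensor arrays instead of disco's `WeightsContainer`; the `proximalTerm` function name is illustrative, and the default `mu = 1` mirrors the value hard-coded above:

```ts
import * as tf from "@tensorflow/tfjs";

// FedProx proximal term: (mu / 2) * sum_i ||w_i - w_prev_i||^2, where the sum
// runs over all weight tensors of the model.
function proximalTerm(
  weights: tf.Tensor[],
  prevWeights: tf.Tensor[],
  mu = 1,
): tf.Scalar {
  // squared L2 distance between current and previous weights, per tensor,
  // then summed across tensors
  const squaredNorm = tf.addN(
    weights.map((w, i) => w.sub(prevWeights[i]).square().sum()),
  );
  return squaredNorm.mul(mu / 2).asScalar();
}
```

The result is a rank-0 tensor that can be added directly onto the batch loss inside the function passed to `optimizer.minimize()`, exactly as `tf.add(loss, proximalTerm)` does above.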