@@ -76,32 +76,61 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
 
   // First iteration: replace trainOnBatch with custom loss computation
   async trainFedProx (
-    xs: tf.Tensor, ys: tf.Tensor): Promise<[number, number]> {
-
-    debug(this.model.loss, this.model.losses, this.model.lossFunctions)
+    xs: tf.Tensor, ys: tf.Tensor,
+  ): Promise<[number, number]> {
+    let logitsTensor: tf.Tensor<tf.Rank>;
     const lossFunction: () => tf.Scalar = () => {
+      // Proximal term
+      let proximalTerm = tf.tensor(0)
+      if (this.prevRoundWeights !== undefined) {
+        // squared norm
+        const norm = new WeightsContainer(this.model.getWeights())
+          .sub(this.prevRoundWeights)
+          .map(t => t.square().sum())
+          .reduce((t, acc) => tf.add(t, acc)).asScalar()
+        const mu = 1
+        proximalTerm = tf.mul(mu / 2, norm)
+      }
+
       this.model.apply(xs)
       const logits = this.model.apply(xs)
-      if (Array.isArray(logits))
-        throw new Error('model outputs too many tensor')
-      if (logits instanceof tf.SymbolicTensor)
-        throw new Error('model outputs symbolic tensor')
-
-      // binaryCrossEntropyLoss as implemented by tensorflow.js
-      // https://github.com/tensorflow/tfjs/blob/2644bd0d6cea677f80e44ed4a44bea5e04aabeb3/tfjs-layers/src/losses.ts#L193
-      let y: tf.Tensor;
-      y = tf.clipByValue(logits, 0.00001, 1 - 0.00001);
-      y = tf.log(tf.div(y, tf.sub(1, y)));
-      return tf.losses.sigmoidCrossEntropy(ys, y);
+      if (Array.isArray(logits))
+        throw new Error('model outputs too many tensor')
+      if (logits instanceof tf.SymbolicTensor)
+        throw new Error('model outputs symbolic tensor')
+      logitsTensor = tf.keep(logits)
+      // binaryCrossentropy as implemented by tensorflow.js
+      // https://github.com/tensorflow/tfjs/blob/2644bd0d6cea677f80e44ed4a44bea5e04aabeb3/tfjs-layers/src/losses.ts#L193
+      let y: tf.Tensor;
+      y = tf.clipByValue(logits, 0.00001, 1 - 0.00001);
+      y = tf.log(tf.div(y, tf.sub(1, y)));
+      const loss = tf.losses.sigmoidCrossEntropy(ys, y);
+      console.log(loss.dataSync(), proximalTerm.dataSync())
+      return tf.add(loss, proximalTerm)
     }
     const lossTensor = this.model.optimizer.minimize(lossFunction, true)
     if (lossTensor === null) throw new Error("loss should not be null")
-
-    const loss = await lossTensor.array()
-    tf.dispose([xs, ys, lossTensor])
+
+    // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+    const accTensor = tf.metrics.categoricalAccuracy(ys, logitsTensor)
+    const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
+    const accSumTensor = accTensor.sum()
+    const accSum = await accSumTensor.array()
+    if (typeof accSum !== 'number')
+      throw new Error('got multiple accuracy sum')
+    // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+    tf.dispose([accTensor, accSumTensor, logitsTensor])
+
+    const loss = await lossTensor.array()
+    tf.dispose([xs, ys, lossTensor])
 
-    // dummy accuracy for now
-    return [loss, 0]
+    const memory = tf.memory().numBytes / 1024 / 1024 / 1024
+    debug("training metrics: %O", {
+      loss,
+      memory,
+      allocated: tf.memory().numTensors,
+    });
+    return [loss, accSum / accSize]
   }
 
   async #evaluate(
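
The proximal term added inside `lossFunction` is the core of FedProx: the local objective becomes `loss(w) + (mu / 2) * ||w - w_prev||^2`, penalizing drift from the weights received at the start of the round. Below is a minimal standalone sketch of that term using raw tfjs tensors instead of the project's `WeightsContainer`; the function name and the flat `tf.Tensor[]` interface are assumptions for illustration, and `mu = 1` mirrors the hard-coded value in the commit.

```ts
import * as tf from "@tensorflow/tfjs";

// Sketch: (mu / 2) * ||w - w_prev||^2, summed over every weight tensor.
// `proximalTerm` and its signature are hypothetical, not the project's API.
function proximalTerm(
  weights: tf.Tensor[],      // current local weights, e.g. model.getWeights()
  prevWeights: tf.Tensor[],  // global weights received at round start
  mu = 1,                    // proximal coefficient, hard-coded to 1 above
): tf.Scalar {
  const squaredNorm = weights
    .map((w, i) => w.sub(prevWeights[i]).square().sum())
    .reduce<tf.Tensor>((acc, t) => acc.add(t), tf.scalar(0));
  return squaredNorm.mul(mu / 2).asScalar();
}
```

Because the term is computed inside the closure handed to `optimizer.minimize()`, its gradient, `mu * (w - w_prev)`, is applied together with the task-loss gradient on every batch.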
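The clip-and-log lines reproduce Keras' `binaryCrossentropy` (see the linked tfjs-layers source): `tf.losses.sigmoidCrossEntropy` expects raw logits, while the model outputs sigmoid probabilities, so the code inverts the sigmoid with `logit(p) = log(p / (1 - p))` after clipping `p` away from 0 and 1 for numerical stability. A self-contained sketch of the same trick; the function name is hypothetical:

```ts
import * as tf from "@tensorflow/tfjs";

// Binary cross-entropy from probabilities, matching the snippet above:
// sigmoidCrossEntropy(ys, logit(p)) equals binaryCrossentropy(ys, p).
function binaryCrossentropyFromProbs(ys: tf.Tensor, probs: tf.Tensor): tf.Tensor {
  const eps = 0.00001;                             // same clip value as the commit
  const p = tf.clipByValue(probs, eps, 1 - eps);   // keep p strictly inside (0, 1)
  const logits = tf.log(tf.div(p, tf.sub(1, p)));  // inverse of the sigmoid
  return tf.losses.sigmoidCrossEntropy(ys, logits); // scalar under default reduction
}
```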