Skip to content

Commit 76977e7

Browse files
committed
tmp: sketch of fedprox implementation
1 parent abe7996 commit 76977e7

File tree

3 files changed

+54
-20
lines changed

3 files changed

+54
-20
lines changed

discojs/src/models/model.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,16 @@ import type { BatchLogs, EpochLogs } from "./logs.js";
1717
**/
1818
// TODO make it typesafe: same shape of data/input/weights
1919
export abstract class Model<D extends DataType> implements Disposable {
  // Global weights received at the end of the previous federated round,
  // consumed by FedProx's proximal term (see TFJS.trainFedProx).
  // Undefined until the trainer sets it — e.g. on the very first round.
  protected prevRoundWeights: WeightsContainer | undefined;
2021
// TODO don't allow external access but upgrade train to return weights on every epoch
2122
/** Return training state */
2223
abstract get weights(): WeightsContainer;
2324
/** Set training state */
2425
abstract set weights(ws: WeightsContainer);
2526

27+
  /**
   * Stash the previous round's global weights so that FedProx-style
   * training can penalize drift away from them.
   *
   * @param ws weights from the previous round, or `undefined` to disable
   *           the proximal term (e.g. on the first round)
   */
  set previousRoundWeights(ws: WeightsContainer | undefined) {
    this.prevRoundWeights = ws
  }
2630
/**
2731
* Improve predictor
2832
*

discojs/src/models/tfjs.ts

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -76,32 +76,61 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
7676

7777
// First iteration: replace trainOnBatch with custom loss computation
/**
 * Run one FedProx training step on a single batch.
 *
 * Minimizes binary cross-entropy plus the FedProx proximal term
 * `mu/2 * ||w - w_prev||^2`, where `w_prev` are the previous-round global
 * weights (set via `previousRoundWeights`). When no previous-round weights
 * are available the term is 0 and this reduces to a plain local step.
 *
 * @param xs input batch; disposed before returning
 * @param ys targets batch; disposed before returning
 * @param mu proximal coefficient; defaults to 1 (previous hard-coded value)
 * @returns `[loss, accuracy]` for this batch
 */
async trainFedProx(
  xs: tf.Tensor, ys: tf.Tensor, mu = 1,
): Promise<[number, number]> {
  // Assigned inside lossFunction; tf.keep() makes it survive minimize()'s
  // internal tidy scope so accuracy can be computed afterwards.
  let logitsTensor: tf.Tensor | undefined;
  const lossFunction: () => tf.Scalar = () => {
    // FedProx proximal term: mu/2 * squared L2 norm of (w - w_prev).
    // Stays 0 when no previous-round weights were set.
    let proximalTerm: tf.Tensor = tf.scalar(0)
    if (this.prevRoundWeights !== undefined) {
      const squaredNorm = new WeightsContainer(this.model.getWeights())
        .sub(this.prevRoundWeights)
        .map(t => t.square().sum())
        .reduce((t, acc) => tf.add(t, acc)).asScalar()
      proximalTerm = tf.mul(mu / 2, squaredNorm)
    }

    // Single forward pass (the sketch called apply() twice, leaking the
    // discarded first activation on every batch).
    const logits = this.model.apply(xs)
    if (Array.isArray(logits))
      throw new Error('model outputs too many tensor')
    if (logits instanceof tf.SymbolicTensor)
      throw new Error('model outputs symbolic tensor')
    logitsTensor = tf.keep(logits)
    // binaryCrossentropy as implemented by tensorflow.js
    // https://github.com/tensorflow/tfjs/blob/2644bd0d6cea677f80e44ed4a44bea5e04aabeb3/tfjs-layers/src/losses.ts#L193
    let y: tf.Tensor;
    y = tf.clipByValue(logits, 0.00001, 1 - 0.00001);
    y = tf.log(tf.div(y, tf.sub(1, y)));
    const loss = tf.losses.sigmoidCrossEntropy(ys, y);
    return tf.add(loss, proximalTerm)
  }
  // returnCost=true so minimize() hands back the loss scalar.
  const lossTensor = this.model.optimizer.minimize(lossFunction, true)
  if (lossTensor === null) throw new Error("loss should not be null")
  // Runtime guard replaces the previous @ts-expect-error suppressions:
  // minimize() invokes lossFunction synchronously, so this should be set.
  if (logitsTensor === undefined)
    throw new Error("logits were not computed during minimize")

  const accTensor = tf.metrics.categoricalAccuracy(ys, logitsTensor)
  const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
  const accSumTensor = accTensor.sum()
  const accSum = await accSumTensor.array()
  if (typeof accSum !== 'number')
    throw new Error('got multiple accuracy sum')
  tf.dispose([accTensor, accSumTensor, logitsTensor])

  const loss = await lossTensor.array()
  tf.dispose([xs, ys, lossTensor])

  // Memory stats reported via debug (the sketch's console.log of
  // dataSync() forced an extra synchronous GPU readback per batch).
  const memory = tf.memory().numBytes / 1024 / 1024 / 1024
  debug("training metrics: %O", {
    loss,
    memory,
    allocated: tf.memory().numTensors,
  });
  return [loss, accSum / accSize]
}
106135

107136
async #evaluate(

discojs/src/training/trainer.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ export class Trainer<D extends DataType> {
9090
let previousRoundWeights: WeightsContainer | undefined;
9191
for (let round = 0; round < totalRound; round++) {
9292
await this.#client.onRoundBeginCommunication();
93-
93+
94+
this.model.previousRoundWeights = previousRoundWeights
9495
yield this.#runRound(dataset, validationDataset);
9596

9697
let localWeights = this.model.weights;

0 commit comments

Comments
 (0)