Skip to content

Commit 1bcd89d

Browse files
committed
Secure aggr with momentums
1 parent c4b5c70 commit 1bcd89d

File tree

2 files changed

+184
-0
lines changed

2 files changed

+184
-0
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import { List, Set, Range, Map } from "immutable";
2+
import { assert, expect } from "chai";
3+
4+
import {
5+
aggregator as aggregators,
6+
aggregation,
7+
WeightsContainer,
8+
} from "../index.js";
9+
10+
import { SecureHistoryAggregator } from "./secure_history.js";
11+
import { SecureAggregator } from "./secure.js";
12+
13+
import { wsIntoArrays, communicate, setupNetwork } from "../aggregator.spec.js";
14+
15+
describe("secure history aggregator", function () {
16+
const epsilon = 1e-4;
17+
18+
const expected = WeightsContainer.of([2, 2, 5, 1], [-10, 10]);
19+
const secrets = List.of(
20+
WeightsContainer.of([1, 2, 3, -1], [-5, 6]),
21+
WeightsContainer.of([2, 3, 7, 1], [-10, 5]),
22+
WeightsContainer.of([3, 1, 5, 3], [-15, 19]),
23+
);
24+
25+
function buildShares(): List<List<WeightsContainer>> {
26+
const nodes = Set(secrets.keys()).map(String);
27+
return secrets.map((secret) => {
28+
const aggregator = new SecureHistoryAggregator();
29+
aggregator.setNodes(nodes);
30+
return aggregator.generateAllShares(secret);
31+
});
32+
}
33+
34+
function buildPartialSums(
35+
allShares: List<List<WeightsContainer>>,
36+
): List<WeightsContainer> {
37+
return Range(0, secrets.size)
38+
.map((idx) => allShares.map((shares) => shares.get(idx)))
39+
.map((shares) => aggregation.sum(shares as List<WeightsContainer>))
40+
.toList();
41+
}
42+
43+
it("recovers secrets from shares", () => {
44+
const recovered = buildShares().map((shares) => aggregation.sum(shares));
45+
assert.isTrue(
46+
(
47+
recovered.zip(secrets) as List<[WeightsContainer, WeightsContainer]>
48+
).every(([actual, expected]) => actual.equals(expected, epsilon)),
49+
);
50+
});
51+
52+
it("aggregates partial sums with momentum smoothing", () => {
53+
const aggregator = new SecureHistoryAggregator(100, 0.8);
54+
const nodes = Set(secrets.keys()).map(String);
55+
aggregator.setNodes(nodes);
56+
57+
// simulate first communication round contributions (shares)
58+
const sharesRound0 = buildShares();
59+
sharesRound0.forEach((shares, idx) => {
60+
shares.forEach((share, nodeIdx) => {
61+
aggregator.add(nodeIdx.toString(), share, 0);
62+
});
63+
});
64+
65+
// aggregate round 0 sums
66+
const sumRound0 = aggregator.aggregate();
67+
expect(sumRound0.equals(aggregation.sum(sharesRound0.get(0)!), epsilon)).to.be.true;
68+
69+
// // simulate second communication round partial sums
70+
// const partialSums = buildPartialSums(sharesRound0);
71+
// partialSums.forEach((partialSum, nodeIdx) => {
72+
// aggregator.add(nodeIdx.toString(), partialSum, 1);
73+
// });
74+
75+
// // First aggregation with momentum - no previous momentum, so just average
76+
// let agg1 = aggregator.aggregate();
77+
// const avgPartialSum = aggregation.avg(partialSums);
78+
// expect(agg1.equals(avgPartialSum, epsilon)).to.be.true;
79+
80+
// // Add another set of partial sums with slight modification
81+
// const partialSums2 = partialSums.map(ws =>
82+
// ws.map(t => t.mul(1.1))
83+
// );
84+
85+
// partialSums2.forEach((partialSum, nodeIdx) => {
86+
// aggregator.add(nodeIdx.toString(), partialSum, 1);
87+
// });
88+
89+
// // Now momentum should smooth the updated average and previous aggregate
90+
// const agg2 = aggregator.aggregate();
91+
92+
// // agg2 should be between avgPartialSum and new partial sums average weighted by beta
93+
// const avgPartialSum2 = aggregation.avg(partialSums2);
94+
// // expected = beta * agg1 + (1 - beta) * avgPartialSum2
95+
// const expectedAgg2 = agg1.mapWith(avgPartialSum2, (a, b) =>
96+
// a.mul(aggregator['beta']).add(b.mul(1 - aggregator['beta']))
97+
// );
98+
99+
// // Compare agg2 and expectedAgg2 elementwise
100+
// expect(agg2.equals(expectedAgg2, epsilon)).to.be.true;
101+
});
102+
103+
it("behaves similar to SecureAggregator without momentum (beta=0)", async () => {
104+
class TestSecureHistoryAggregator extends SecureHistoryAggregator {
105+
constructor() {
106+
super(0, 0); // beta=0 disables momentum smoothing
107+
}
108+
}
109+
const secureHistoryNetwork = setupNetwork(TestSecureHistoryAggregator); // beta=0 disables momentum smoothing
110+
const secureNetwork = setupNetwork(SecureAggregator);
111+
112+
const secureHistoryResults = await communicate(
113+
Map(
114+
secureHistoryNetwork
115+
.entrySeq()
116+
.zip(Range(0, 3))
117+
.map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]),
118+
),
119+
0,
120+
);
121+
const secureResults = await communicate(
122+
Map(
123+
secureNetwork
124+
.entrySeq()
125+
.zip(Range(0, 3))
126+
.map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]),
127+
),
128+
0,
129+
);
130+
131+
List(await Promise.all(secureHistoryResults.sort().valueSeq().map(wsIntoArrays)))
132+
.flatMap((x) => x)
133+
.flatMap((x) => x)
134+
.zipAll(
135+
List(await Promise.all(secureResults.sort().valueSeq().map(wsIntoArrays)))
136+
.flatMap((x) => x)
137+
.flatMap((x) => x),
138+
)
139+
.forEach(([secureHistory, secure]) => expect(secureHistory).to.be.closeTo(secure, 0.001));
140+
});
141+
});
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import type { WeightsContainer, client } from "../index.js";
2+
import { SecureAggregator } from "./secure.js";
3+
import * as tf from "@tensorflow/tfjs";
4+
import { aggregation } from "../index.js";
5+
6+
export class SecureHistoryAggregator extends SecureAggregator {
7+
private prevAggregate: WeightsContainer | null = null;
8+
private readonly beta: number;
9+
10+
constructor(maxShareValue = 100, beta = 0.9) {
11+
super(maxShareValue);
12+
this.beta = beta;
13+
this.prevAggregate = null;
14+
}
15+
16+
override aggregate(): WeightsContainer {
17+
// Call the base class aggregate for rounds other than 1
18+
if (this.communicationRound !== 1) {
19+
return super.aggregate();
20+
}
21+
22+
// For communication round 1, do average + momentum smoothing
23+
const currentContributions = this.contributions.get(1);
24+
if (!currentContributions) throw new Error("aggregating without any contribution");
25+
26+
const avg = aggregation.avg(currentContributions.values());
27+
28+
if (this.prevAggregate === null) {
29+
this.prevAggregate = avg;
30+
return avg;
31+
}
32+
33+
const updatedMomentum = this.prevAggregate.mapWith(avg, (prevT, currT) =>
34+
prevT.mul(this.beta).add(currT.mul(1 - this.beta))
35+
);
36+
37+
// Dispose old tensors to avoid memory leaks
38+
this.prevAggregate.weights.forEach(t => t.dispose());
39+
this.prevAggregate = updatedMomentum;
40+
41+
return updatedMomentum;
42+
}
43+
}

0 commit comments

Comments
 (0)