
Commit 42f2e3b

Secure-history aggr, test fix
1 parent 1bcd89d commit 42f2e3b

2 files changed (+148, -127 lines)

discojs/src/aggregator/aggregator.ts

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsCon
       throw new Error("Tried adding an invalid contribution. Handle this case before calling add.")

     // call the abstract method _add, implemented by subclasses
-    this._add(nodeId, contribution, communicationRound)
+    this._add(nodeId, contribution, communicationRound ?? this.communicationRound)
     // If the aggregator has enough contributions then aggregate the weights
     // and emit the 'aggregation' event
     if (this.isFull()) {
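
The one-line change above makes the round argument effectively optional: when a caller passes undefined (or null), _add now receives the aggregator's current round instead of a missing value. A minimal sketch of that nullish-coalescing behaviour, with illustrative names rather than the real Aggregator internals:

// Hypothetical stand-in for the fallback logic; `RoundTracker` and its fields
// are illustrative, not part of discojs.
class RoundTracker {
  communicationRound = 2; // the tracker's own notion of the current round

  add(nodeId: string, communicationRound?: number): void {
    // `??` only falls back on null/undefined, so an explicit round 0 is kept
    // (unlike `||`, which would treat 0 as missing).
    const round = communicationRound ?? this.communicationRound;
    console.log(`node ${nodeId} contributes in round ${round}`);
  }
}

const tracker = new RoundTracker();
tracker.add("a");    // round 2 (falls back to the tracker's current round)
tracker.add("b", 0); // round 0 (explicit value preserved)
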
Lines changed: 147 additions & 126 deletions
@@ -1,141 +1,162 @@
 import { List, Set, Range, Map } from "immutable";
 import { assert, expect } from "chai";
+import * as tf from "@tensorflow/tfjs";

 import {
-  aggregator as aggregators,
-  aggregation,
-  WeightsContainer,
+  aggregator as aggregators,
+  aggregation,
+  WeightsContainer,
 } from "../index.js";

-import { SecureHistoryAggregator } from "./secure_history.js";
+import { SecureHistoryAggregator } from "./secure_history.js";
 import { SecureAggregator } from "./secure.js";

 import { wsIntoArrays, communicate, setupNetwork } from "../aggregator.spec.js";

-describe("secure history aggregator", function () {
-  const epsilon = 1e-4;
-
-  const expected = WeightsContainer.of([2, 2, 5, 1], [-10, 10]);
-  const secrets = List.of(
-    WeightsContainer.of([1, 2, 3, -1], [-5, 6]),
-    WeightsContainer.of([2, 3, 7, 1], [-10, 5]),
-    WeightsContainer.of([3, 1, 5, 3], [-15, 19]),
-  );
-
-  function buildShares(): List<List<WeightsContainer>> {
-    const nodes = Set(secrets.keys()).map(String);
-    return secrets.map((secret) => {
-      const aggregator = new SecureHistoryAggregator();
-      aggregator.setNodes(nodes);
-      return aggregator.generateAllShares(secret);
-    });
-  }
-
-  function buildPartialSums(
-    allShares: List<List<WeightsContainer>>,
-  ): List<WeightsContainer> {
-    return Range(0, secrets.size)
-      .map((idx) => allShares.map((shares) => shares.get(idx)))
-      .map((shares) => aggregation.sum(shares as List<WeightsContainer>))
-      .toList();
-  }
-
-  it("recovers secrets from shares", () => {
-    const recovered = buildShares().map((shares) => aggregation.sum(shares));
-    assert.isTrue(
-      (
-        recovered.zip(secrets) as List<[WeightsContainer, WeightsContainer]>
-      ).every(([actual, expected]) => actual.equals(expected, epsilon)),
+describe("Secure history aggregator", function () {
+  const epsilon = 1e-4;
+
+  const expected = WeightsContainer.of([2, 2, 5, 1], [-10, 10]);
+  const secrets = List.of(
+    WeightsContainer.of([1, 2, 3, -1], [-5, 6]),
+    WeightsContainer.of([2, 3, 7, 1], [-10, 5]),
+    WeightsContainer.of([3, 1, 5, 3], [-15, 19]),
   );
-  });
-
-  it("aggregates partial sums with momentum smoothing", () => {
-    const aggregator = new SecureHistoryAggregator(100, 0.8);
-    const nodes = Set(secrets.keys()).map(String);
-    aggregator.setNodes(nodes);
-
-    // simulate first communication round contributions (shares)
-    const sharesRound0 = buildShares();
-    sharesRound0.forEach((shares, idx) => {
-      shares.forEach((share, nodeIdx) => {
-        aggregator.add(nodeIdx.toString(), share, 0);
-      });
-    });

-    // aggregate round 0 sums
-    const sumRound0 = aggregator.aggregate();
-    expect(sumRound0.equals(aggregation.sum(sharesRound0.get(0)!), epsilon)).to.be.true;
-
-    // // simulate second communication round partial sums
-    // const partialSums = buildPartialSums(sharesRound0);
-    // partialSums.forEach((partialSum, nodeIdx) => {
-    //   aggregator.add(nodeIdx.toString(), partialSum, 1);
-    // });
-
-    // // First aggregation with momentum - no previous momentum, so just average
-    // let agg1 = aggregator.aggregate();
-    // const avgPartialSum = aggregation.avg(partialSums);
-    // expect(agg1.equals(avgPartialSum, epsilon)).to.be.true;
-
-    // // Add another set of partial sums with slight modification
-    // const partialSums2 = partialSums.map(ws =>
-    //   ws.map(t => t.mul(1.1))
-    // );
-
-    // partialSums2.forEach((partialSum, nodeIdx) => {
-    //   aggregator.add(nodeIdx.toString(), partialSum, 1);
-    // });
-
-    // // Now momentum should smooth the updated average and previous aggregate
-    // const agg2 = aggregator.aggregate();
-
-    // // agg2 should be between avgPartialSum and new partial sums average weighted by beta
-    // const avgPartialSum2 = aggregation.avg(partialSums2);
-    // // expected = beta * agg1 + (1 - beta) * avgPartialSum2
-    // const expectedAgg2 = agg1.mapWith(avgPartialSum2, (a, b) =>
-    //   a.mul(aggregator['beta']).add(b.mul(1 - aggregator['beta']))
-    // );
-
-    // // Compare agg2 and expectedAgg2 elementwise
-    // expect(agg2.equals(expectedAgg2, epsilon)).to.be.true;
-  });
-
-  it("behaves similar to SecureAggregator without momentum (beta=0)", async () => {
-    class TestSecureHistoryAggregator extends SecureHistoryAggregator {
-      constructor() {
-        super(0, 0); // beta=0 disables momentum smoothing
-      }
+  function buildShares(): List<List<WeightsContainer>> {
+    const nodes = Set(secrets.keys()).map(String);
+    return secrets.map((secret) => {
+      const aggregator = new SecureHistoryAggregator();
+      aggregator.setNodes(nodes);
+      return aggregator.generateAllShares(secret);
+    });
   }
-    const secureHistoryNetwork = setupNetwork(TestSecureHistoryAggregator); // beta=0 disables momentum smoothing
-    const secureNetwork = setupNetwork(SecureAggregator);
-
-    const secureHistoryResults = await communicate(
-      Map(
-        secureHistoryNetwork
-          .entrySeq()
-          .zip(Range(0, 3))
-          .map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]),
-      ),
-      0,
-    );
-    const secureResults = await communicate(
-      Map(
-        secureNetwork
-          .entrySeq()
-          .zip(Range(0, 3))
-          .map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]),
-      ),
-      0,
-    );

-    List(await Promise.all(secureHistoryResults.sort().valueSeq().map(wsIntoArrays)))
-      .flatMap((x) => x)
-      .flatMap((x) => x)
-      .zipAll(
-        List(await Promise.all(secureResults.sort().valueSeq().map(wsIntoArrays)))
-          .flatMap((x) => x)
-          .flatMap((x) => x),
-      )
-      .forEach(([secureHistory, secure]) => expect(secureHistory).to.be.closeTo(secure, 0.001));
-  });
+  function buildPartialSums(
+    allShares: List<List<WeightsContainer>>,
+  ): List<WeightsContainer> {
+    return Range(0, secrets.size)
+      .map((idx) => allShares.map((shares) => shares.get(idx)))
+      .map((shares) => aggregation.sum(shares as List<WeightsContainer>))
+      .toList();
+  }
+
+  it("recovers secrets from shares", () => {
+    const recovered = buildShares().map((shares) => aggregation.sum(shares));
+    assert.isTrue(
+      (
+        recovered.zip(secrets) as List<[WeightsContainer, WeightsContainer]>
+      ).every(([actual, expected]) => actual.equals(expected, epsilon)),
+    );
+  });
+
+  it("aggregates partial sums with momentum smoothing", async () => {
+    const aggregator = new SecureHistoryAggregator(100, 0.8);
+    const nodes = Set(secrets.keys()).map(String);
+    aggregator.setNodes(nodes);
+
+    // Prepare to capture aggregation result
+    const aggregationPromise = aggregator.getPromiseForAggregation();
+
+    const sharesRound0 = buildShares();
+
+    let partialSums = Range(0, nodes.size).map((receiverIdx) => {
+      const receivedShares = sharesRound0.map(shares => shares.get(receiverIdx)!);
+      return aggregation.sum(receivedShares as List<WeightsContainer>);
+    }).toList();
+
+    // Add one total contribution per node
+    partialSums.forEach((partialSum, idx) => {
+      const nodeId = idx.toString();
+      aggregator.add(nodeId, partialSum, 0);
+    });
+
+    const sumRound0 = await aggregationPromise;
+
+    const expectedSum = aggregation.sum(
+      sharesRound0.flatMap(x => x) // flatten to List<WeightsContainer>
+    );
+    expect(sumRound0.equals(expectedSum, epsilon)).to.be.true;
+
+
+    // simulate second communication round partial sums
+    const aggregationPromise2 = aggregator.getPromiseForAggregation();
+
+    partialSums.forEach((partialSum, idx) => {
+      const nodeId = idx.toString();
+      aggregator.add(nodeId, partialSum, 0);
+    });
+    const sumRound1 = await aggregationPromise2;
+
+    // First aggregation with momentum - no previous momentum, so just average
+    const avgPartialSum = aggregation.avg(partialSums);
+    expect(sumRound1.equals(avgPartialSum, epsilon)).to.be.true;
+
+    const dummyPromise = aggregator.getPromiseForAggregation();
+    partialSums.forEach((partialSum, idx) => {
+      const nodeId = idx.toString();
+      aggregator.add(nodeId, partialSum, 1); // round 0 of next aggregation round
+    });
+    await dummyPromise;
+
+    const aggregationPromise3 = aggregator.getPromiseForAggregation();
+    // Add another set of partial sums with slight modification
+    const partialSums2 = partialSums.map(ws =>
+      ws.map((tensor) => tf.mul(tensor, 1.1))
+    );
+
+    // Step 3: Add new partial sums to aggregator
+    partialSums2.forEach((partialSum, idx) => {
+      const nodeId = idx.toString();
+      aggregator.add(nodeId, partialSum, 1);
+    });
+    const sumRound2 = await aggregationPromise3;
+
+    const avgPartialSum2 = aggregation.avg(partialSums2);
+    const expectedSumRound2 = avgPartialSum.mapWith(avgPartialSum2, (prev, curr) =>
+      prev.mul(0.8).add(curr.mul(0.2)) // 0.8 = beta, 0.2 = (1 - beta)
+    );
+
+    // Compare the actual result to the expected smoothed result
+    expect(sumRound2.equals(expectedSumRound2, 1e-3)).to.be.true;
+  });
+
+  it("behaves similar to SecureAggregator without momentum (beta=0)", async () => {
+    class TestSecureHistoryAggregator extends SecureHistoryAggregator {
+      constructor() {
+        super(0, 0); // beta=0 disables momentum smoothing
+      }
+    }
+    const secureHistoryNetwork = setupNetwork(TestSecureHistoryAggregator); // beta=0 disables momentum smoothing
+    const secureNetwork = setupNetwork(SecureAggregator);
+
+    const secureHistoryResults = await communicate(
+      Map(
+        secureHistoryNetwork
+          .entrySeq()
+          .zip(Range(0, 3))
+          .map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]),
+      ),
+      0,
+    );
+    const secureResults = await communicate(
+      Map(
+        secureNetwork
+          .entrySeq()
+          .zip(Range(0, 3))
+          .map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]),
+      ),
+      0,
+    );
+
+    List(await Promise.all(secureHistoryResults.sort().valueSeq().map(wsIntoArrays)))
+      .flatMap((x) => x)
+      .flatMap((x) => x)
+      .zipAll(
+        List(await Promise.all(secureResults.sort().valueSeq().map(wsIntoArrays)))
+          .flatMap((x) => x)
+          .flatMap((x) => x),
+      )
+      .forEach(([secureHistory, secure]) => expect(secureHistory).to.be.closeTo(secure, 0.001));
+  });
 });
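
The new spec exercises two pieces of arithmetic: additive secret sharing, where summing every node's shares recovers the sum of the secrets, and the momentum update asserted at the end of the smoothing test, expected = beta * previousAggregate + (1 - beta) * currentAverage. A minimal numeric sketch of both, using plain arrays instead of WeightsContainer tensors (function names here are illustrative, not part of discojs):

// Additive sharing: a secret is split into random-looking shares that sum back
// to the secret, so each receiver only ever sees per-node partial sums.
function makeShares(secret: number, parts: number): number[] {
  const shares = Array.from({ length: parts - 1 }, () => Math.random() * 200 - 100);
  const last = secret - shares.reduce((a, b) => a + b, 0);
  return [...shares, last];
}

// Momentum smoothing as checked in the test:
// smoothed = beta * previousAggregate + (1 - beta) * currentAverage
function momentumSmooth(previous: number[], current: number[], beta: number): number[] {
  return previous.map((prev, i) => beta * prev + (1 - beta) * current[i]);
}

const shares = makeShares(42, 3);
console.log(shares.reduce((a, b) => a + b, 0)); // ~42: summing the shares recovers the secret

const previous = [1, 2, 3];
const current = [2, 4, 6];
console.log(momentumSmooth(previous, current, 0.8)); // [1.2, 2.4, 3.6]
console.log(momentumSmooth(previous, current, 0));   // [2, 4, 6] — plain average, the beta=0 baseline compared against SecureAggregator
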
