Skip to content

Commit b8ead62

Browse files
committed
Add more tests for Byzantine roubust aggregator
1 parent aa89f7b commit b8ead62

File tree

4 files changed

+304
-1
lines changed

4 files changed

+304
-1
lines changed

discojs/package.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"@tensorflow/tfjs-node": "4",
3535
"@types/simple-peer": "9",
3636
"nodemon": "3",
37-
"ts-node": "10"
37+
"ts-node": "10",
38+
"fast-check": "^3"
3839
}
3940
}

discojs/src/aggregator/byzantine.spec.ts

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import fc from "fast-check";
12
import { Set } from "immutable";
23
import { describe, expect, it } from "vitest";
34

@@ -100,4 +101,211 @@ describe("ByzantineRobustAggregator", () => {
100101
const arr2 = await WSIntoArrays(out2);
101102
expect(arr2[0][0]).to.equal(20);
102103
});
104+
105+
it("applies momentum before aggregation", async () => {
106+
const agg = new ByzantineRobustAggregator(0, 2, 'absolute', 1e6, 1, 0.5);
107+
const [a, b] = ["a", "b"];
108+
agg.setNodes(Set.of(a, b));
109+
110+
const p1 = agg.getPromiseForAggregation();
111+
agg.add(a, WeightsContainer.of([0]), 0);
112+
agg.add(b, WeightsContainer.of([10]), 0);
113+
await p1;
114+
115+
const p2 = agg.getPromiseForAggregation();
116+
agg.add(a, WeightsContainer.of([0]), 1);
117+
agg.add(b, WeightsContainer.of([20]), 1);
118+
const out = await p2;
119+
120+
const arr = await WSIntoArrays(out);
121+
expect(arr[0][0]).to.be.closeTo(7.5, 1e-6); // mean of (0, 15)
122+
});
123+
124+
it("beta = 1 freezes aggregation after first round", async () => {
125+
const agg = new ByzantineRobustAggregator(0, 1, 'absolute', 1e6, 1, 1);
126+
const id = "c1";
127+
agg.setNodes(Set.of(id));
128+
129+
const p1 = agg.getPromiseForAggregation();
130+
agg.add(id, WeightsContainer.of([5]), 0);
131+
await p1;
132+
133+
const p2 = agg.getPromiseForAggregation();
134+
agg.add(id, WeightsContainer.of([100]), 1);
135+
const out = await p2;
136+
137+
const arr = await WSIntoArrays(out);
138+
expect(arr[0][0]).to.equal(5);
139+
});
140+
141+
it("remains robust with 30% Byzantine clients", async () => {
142+
const honest = Array(7).fill(1);
143+
const byzantine = Array(3).fill(100);
144+
145+
const agg = new ByzantineRobustAggregator(0, 10, 'absolute', 1.0, 5, 0);
146+
const ids = [...honest, ...byzantine].map((_, i) => `c${i}`);
147+
agg.setNodes(Set(ids));
148+
149+
const p = agg.getPromiseForAggregation();
150+
honest.forEach((v, i) => agg.add(`c${i}`, WeightsContainer.of([v]), 0));
151+
byzantine.forEach((v, i) => agg.add(`c${i + honest.length}`, WeightsContainer.of([v]), 0));
152+
153+
const out = await p;
154+
const arr = await WSIntoArrays(out);
155+
156+
const honestMean = honest.reduce((a, b) => a + b, 0) / honest.length;
157+
const rawMean = [...honest, ...byzantine].reduce((a, b) => a + b, 0) / (honest.length + byzantine.length);
158+
159+
expect(Math.abs(arr[0][0] - honestMean)).to.be.lessThan(Math.abs(rawMean - honestMean));
160+
});
161+
162+
it("stays close to the honest signal under constant input", async () => {
163+
const agg = new ByzantineRobustAggregator(0, 3, 'absolute', 1.0, 3, 0.9);
164+
const ids = ["a", "b", "c"];
165+
agg.setNodes(Set(ids));
166+
167+
for (let r = 0; r < 10; r++) {
168+
const p = agg.getPromiseForAggregation();
169+
ids.forEach(id => agg.add(id, WeightsContainer.of([1]), r));
170+
const out = await p;
171+
const arr = await WSIntoArrays(out);
172+
173+
expect(Math.abs(arr[0][0] - 1)).to.be.lessThan(0.3);
174+
}
175+
});
176+
177+
it("bounds the marginal influence of a single Byzantine client", async () => {
178+
const clipRadius = 1.0;
179+
180+
await fc.assert(
181+
fc.asyncProperty(
182+
fc.array(fc.double({ min: -1, max: 1 }), { minLength: 3, maxLength: 10 }),
183+
async (honest) => {
184+
const n = honest.length + 1;
185+
186+
// aggregation without Byzantine
187+
const aggClean = new ByzantineRobustAggregator(0, honest.length, "absolute", clipRadius, 1, 0);
188+
const honestIds = honest.map((_, i) => `h${i}`);
189+
aggClean.setNodes(Set(honestIds));
190+
191+
const pClean = aggClean.getPromiseForAggregation();
192+
honest.forEach((v, i) => aggClean.add(`h${i}`, WeightsContainer.of([v]), 0));
193+
const cleanOut = await pClean;
194+
const clean = (await cleanOut.weights[0].data())[0];
195+
196+
// aggregation with Byzantine
197+
const aggByz = new ByzantineRobustAggregator(0, n, "absolute", clipRadius, 1, 0);
198+
const ids = honestIds.concat("byz");
199+
aggByz.setNodes(Set(ids));
200+
201+
const pByz = aggByz.getPromiseForAggregation();
202+
honest.forEach((v, i) => aggByz.add(`h${i}`, WeightsContainer.of([v]), 0));
203+
aggByz.add("byz", WeightsContainer.of([1e9]), 0);
204+
const byzOut = await pByz;
205+
const byz = (await byzOut.weights[0].data())[0];
206+
207+
const deviation = Math.abs(byz - clean);
208+
const maxAllowed = 2 * clipRadius / n; // realistic tolerance for extreme inputs
209+
expect(deviation).toBeLessThanOrEqual(maxAllowed);
210+
}
211+
),
212+
{ numRuns: 200 }
213+
);
214+
});
215+
216+
it("is invariant to client ordering", async () => {
217+
const values = [0, 1, 100];
218+
const ids1 = ["a", "b", "c"];
219+
const ids2 = ["c", "a", "b"];
220+
221+
const run = async (ids: string[]) => {
222+
const agg = new ByzantineRobustAggregator(0, 3, "absolute", 1.0, 3, 0);
223+
agg.setNodes(Set(ids));
224+
const p = agg.getPromiseForAggregation();
225+
ids.forEach((id, i) =>
226+
agg.add(id, WeightsContainer.of([values[i]]), 0)
227+
);
228+
return (await (await p).weights[0].data())[0];
229+
};
230+
231+
const out1 = await run(ids1);
232+
const out2 = await run(ids2);
233+
234+
expect(out1).to.be.closeTo(out2, 1e-6);
235+
});
236+
237+
it("is idempotent when all inputs are identical and within clipping radius", async () => {
238+
const agg = new ByzantineRobustAggregator(0, 5, "absolute", 10.0, 5, 0);
239+
const ids = ["a", "b", "c", "d", "e"];
240+
agg.setNodes(Set(ids));
241+
242+
const p = agg.getPromiseForAggregation();
243+
ids.forEach(id => agg.add(id, WeightsContainer.of([3.14]), 0));
244+
const out = await p;
245+
246+
const v = (await out.weights[0].data())[0];
247+
expect(v).to.be.closeTo(3.14, 1e-6);
248+
});
249+
250+
it("limits bias under symmetric Byzantine attacks", async () => {
251+
const agg = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 3, 0);
252+
agg.setNodes(Set(["h1", "h2", "b1", "b2"]));
253+
254+
const p = agg.getPromiseForAggregation();
255+
agg.add("h1", WeightsContainer.of([1]), 0);
256+
agg.add("h2", WeightsContainer.of([1]), 0);
257+
agg.add("b1", WeightsContainer.of([100]), 0);
258+
agg.add("b2", WeightsContainer.of([-100]), 0);
259+
260+
const out = await p;
261+
const v = (await out.weights[0].data())[0];
262+
263+
expect(Math.abs(v - 1)).to.be.lessThan(0.3);
264+
});
265+
266+
it("output lies within the range of clipped inputs", async () => {
267+
const agg = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 3, 0);
268+
agg.setNodes(Set(["a", "b", "c", "d"]));
269+
270+
const p = agg.getPromiseForAggregation();
271+
agg.add("a", WeightsContainer.of([0]), 0);
272+
agg.add("b", WeightsContainer.of([0.5]), 0);
273+
agg.add("c", WeightsContainer.of([1]), 0);
274+
agg.add("d", WeightsContainer.of([100]), 0);
275+
276+
const out = await p;
277+
const v = (await out.weights[0].data())[0];
278+
279+
expect(v).to.be.greaterThanOrEqual(0);
280+
expect(v).to.be.lessThanOrEqual(1);
281+
});
282+
283+
it("single client cannot dominate aggregation", async () => {
284+
const agg = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 3, 0);
285+
agg.setNodes(Set(["h1", "h2", "h3", "b"]));
286+
287+
const p = agg.getPromiseForAggregation();
288+
agg.add("h1", WeightsContainer.of([0]), 0);
289+
agg.add("h2", WeightsContainer.of([0]), 0);
290+
agg.add("h3", WeightsContainer.of([0]), 0);
291+
agg.add("b", WeightsContainer.of([1e9]), 0);
292+
293+
const out = await p;
294+
const v = (await out.weights[0].data())[0];
295+
296+
expect(Math.abs(v)).to.be.lessThan(0.5);
297+
});
298+
299+
it("reset state when starting fresh aggregator", async () => {
300+
const run = async () => {
301+
const agg = new ByzantineRobustAggregator(0, 2, "absolute", 1.0, 3, 0.9);
302+
agg.setNodes(Set(["a", "b"]));
303+
const p = agg.getPromiseForAggregation();
304+
agg.add("a", WeightsContainer.of([1]), 0);
305+
agg.add("b", WeightsContainer.of([1]), 0);
306+
return (await (await p).weights[0].data())[0];
307+
};
308+
309+
expect(await run()).to.be.closeTo(await run(), 1e-6);
310+
});
103311
});

0 commit comments

Comments
 (0)