|
| 1 | +import fc from "fast-check"; |
1 | 2 | import { Set } from "immutable"; |
2 | 3 | import { describe, expect, it } from "vitest"; |
3 | 4 |
|
@@ -100,4 +101,211 @@ describe("ByzantineRobustAggregator", () => { |
100 | 101 | const arr2 = await WSIntoArrays(out2); |
101 | 102 | expect(arr2[0][0]).to.equal(20); |
102 | 103 | }); |
| 104 | + |
| 105 | + it("applies momentum before aggregation", async () => { |
| 106 | + const agg = new ByzantineRobustAggregator(0, 2, 'absolute', 1e6, 1, 0.5); |
| 107 | + const [a, b] = ["a", "b"]; |
| 108 | + agg.setNodes(Set.of(a, b)); |
| 109 | + |
| 110 | + const p1 = agg.getPromiseForAggregation(); |
| 111 | + agg.add(a, WeightsContainer.of([0]), 0); |
| 112 | + agg.add(b, WeightsContainer.of([10]), 0); |
| 113 | + await p1; |
| 114 | + |
| 115 | + const p2 = agg.getPromiseForAggregation(); |
| 116 | + agg.add(a, WeightsContainer.of([0]), 1); |
| 117 | + agg.add(b, WeightsContainer.of([20]), 1); |
| 118 | + const out = await p2; |
| 119 | + |
| 120 | + const arr = await WSIntoArrays(out); |
| 121 | + expect(arr[0][0]).to.be.closeTo(7.5, 1e-6); // mean of (0, 15) |
| 122 | + }); |
| 123 | + |
| 124 | + it("beta = 1 freezes aggregation after first round", async () => { |
| 125 | + const agg = new ByzantineRobustAggregator(0, 1, 'absolute', 1e6, 1, 1); |
| 126 | + const id = "c1"; |
| 127 | + agg.setNodes(Set.of(id)); |
| 128 | + |
| 129 | + const p1 = agg.getPromiseForAggregation(); |
| 130 | + agg.add(id, WeightsContainer.of([5]), 0); |
| 131 | + await p1; |
| 132 | + |
| 133 | + const p2 = agg.getPromiseForAggregation(); |
| 134 | + agg.add(id, WeightsContainer.of([100]), 1); |
| 135 | + const out = await p2; |
| 136 | + |
| 137 | + const arr = await WSIntoArrays(out); |
| 138 | + expect(arr[0][0]).to.equal(5); |
| 139 | + }); |
| 140 | + |
| 141 | + it("remains robust with 30% Byzantine clients", async () => { |
| 142 | + const honest = Array(7).fill(1); |
| 143 | + const byzantine = Array(3).fill(100); |
| 144 | + |
| 145 | + const agg = new ByzantineRobustAggregator(0, 10, 'absolute', 1.0, 5, 0); |
| 146 | + const ids = [...honest, ...byzantine].map((_, i) => `c${i}`); |
| 147 | + agg.setNodes(Set(ids)); |
| 148 | + |
| 149 | + const p = agg.getPromiseForAggregation(); |
| 150 | + honest.forEach((v, i) => agg.add(`c${i}`, WeightsContainer.of([v]), 0)); |
| 151 | + byzantine.forEach((v, i) => agg.add(`c${i + honest.length}`, WeightsContainer.of([v]), 0)); |
| 152 | + |
| 153 | + const out = await p; |
| 154 | + const arr = await WSIntoArrays(out); |
| 155 | + |
| 156 | + const honestMean = honest.reduce((a, b) => a + b, 0) / honest.length; |
| 157 | + const rawMean = [...honest, ...byzantine].reduce((a, b) => a + b, 0) / (honest.length + byzantine.length); |
| 158 | + |
| 159 | + expect(Math.abs(arr[0][0] - honestMean)).to.be.lessThan(Math.abs(rawMean - honestMean)); |
| 160 | + }); |
| 161 | + |
| 162 | + it("stays close to the honest signal under constant input", async () => { |
| 163 | + const agg = new ByzantineRobustAggregator(0, 3, 'absolute', 1.0, 3, 0.9); |
| 164 | + const ids = ["a", "b", "c"]; |
| 165 | + agg.setNodes(Set(ids)); |
| 166 | + |
| 167 | + for (let r = 0; r < 10; r++) { |
| 168 | + const p = agg.getPromiseForAggregation(); |
| 169 | + ids.forEach(id => agg.add(id, WeightsContainer.of([1]), r)); |
| 170 | + const out = await p; |
| 171 | + const arr = await WSIntoArrays(out); |
| 172 | + |
| 173 | + expect(Math.abs(arr[0][0] - 1)).to.be.lessThan(0.3); |
| 174 | + } |
| 175 | + }); |
| 176 | + |
| 177 | + it("bounds the marginal influence of a single Byzantine client", async () => { |
| 178 | + const clipRadius = 1.0; |
| 179 | + |
| 180 | + await fc.assert( |
| 181 | + fc.asyncProperty( |
| 182 | + fc.array(fc.double({ min: -1, max: 1 }), { minLength: 3, maxLength: 10 }), |
| 183 | + async (honest) => { |
| 184 | + const n = honest.length + 1; |
| 185 | + |
| 186 | + // aggregation without Byzantine |
| 187 | + const aggClean = new ByzantineRobustAggregator(0, honest.length, "absolute", clipRadius, 1, 0); |
| 188 | + const honestIds = honest.map((_, i) => `h${i}`); |
| 189 | + aggClean.setNodes(Set(honestIds)); |
| 190 | + |
| 191 | + const pClean = aggClean.getPromiseForAggregation(); |
| 192 | + honest.forEach((v, i) => aggClean.add(`h${i}`, WeightsContainer.of([v]), 0)); |
| 193 | + const cleanOut = await pClean; |
| 194 | + const clean = (await cleanOut.weights[0].data())[0]; |
| 195 | + |
| 196 | + // aggregation with Byzantine |
| 197 | + const aggByz = new ByzantineRobustAggregator(0, n, "absolute", clipRadius, 1, 0); |
| 198 | + const ids = honestIds.concat("byz"); |
| 199 | + aggByz.setNodes(Set(ids)); |
| 200 | + |
| 201 | + const pByz = aggByz.getPromiseForAggregation(); |
| 202 | + honest.forEach((v, i) => aggByz.add(`h${i}`, WeightsContainer.of([v]), 0)); |
| 203 | + aggByz.add("byz", WeightsContainer.of([1e9]), 0); |
| 204 | + const byzOut = await pByz; |
| 205 | + const byz = (await byzOut.weights[0].data())[0]; |
| 206 | + |
| 207 | + const deviation = Math.abs(byz - clean); |
| 208 | + const maxAllowed = 2 * clipRadius / n; // realistic tolerance for extreme inputs |
| 209 | + expect(deviation).toBeLessThanOrEqual(maxAllowed); |
| 210 | + } |
| 211 | + ), |
| 212 | + { numRuns: 200 } |
| 213 | + ); |
| 214 | + }); |
| 215 | + |
| 216 | + it("is invariant to client ordering", async () => { |
| 217 | + const values = [0, 1, 100]; |
| 218 | + const ids1 = ["a", "b", "c"]; |
| 219 | + const ids2 = ["c", "a", "b"]; |
| 220 | + |
| 221 | + const run = async (ids: string[]) => { |
| 222 | + const agg = new ByzantineRobustAggregator(0, 3, "absolute", 1.0, 3, 0); |
| 223 | + agg.setNodes(Set(ids)); |
| 224 | + const p = agg.getPromiseForAggregation(); |
| 225 | + ids.forEach((id, i) => |
| 226 | + agg.add(id, WeightsContainer.of([values[i]]), 0) |
| 227 | + ); |
| 228 | + return (await (await p).weights[0].data())[0]; |
| 229 | + }; |
| 230 | + |
| 231 | + const out1 = await run(ids1); |
| 232 | + const out2 = await run(ids2); |
| 233 | + |
| 234 | + expect(out1).to.be.closeTo(out2, 1e-6); |
| 235 | + }); |
| 236 | + |
| 237 | + it("is idempotent when all inputs are identical and within clipping radius", async () => { |
| 238 | + const agg = new ByzantineRobustAggregator(0, 5, "absolute", 10.0, 5, 0); |
| 239 | + const ids = ["a", "b", "c", "d", "e"]; |
| 240 | + agg.setNodes(Set(ids)); |
| 241 | + |
| 242 | + const p = agg.getPromiseForAggregation(); |
| 243 | + ids.forEach(id => agg.add(id, WeightsContainer.of([3.14]), 0)); |
| 244 | + const out = await p; |
| 245 | + |
| 246 | + const v = (await out.weights[0].data())[0]; |
| 247 | + expect(v).to.be.closeTo(3.14, 1e-6); |
| 248 | + }); |
| 249 | + |
| 250 | + it("limits bias under symmetric Byzantine attacks", async () => { |
| 251 | + const agg = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 3, 0); |
| 252 | + agg.setNodes(Set(["h1", "h2", "b1", "b2"])); |
| 253 | + |
| 254 | + const p = agg.getPromiseForAggregation(); |
| 255 | + agg.add("h1", WeightsContainer.of([1]), 0); |
| 256 | + agg.add("h2", WeightsContainer.of([1]), 0); |
| 257 | + agg.add("b1", WeightsContainer.of([100]), 0); |
| 258 | + agg.add("b2", WeightsContainer.of([-100]), 0); |
| 259 | + |
| 260 | + const out = await p; |
| 261 | + const v = (await out.weights[0].data())[0]; |
| 262 | + |
| 263 | + expect(Math.abs(v - 1)).to.be.lessThan(0.3); |
| 264 | + }); |
| 265 | + |
| 266 | + it("output lies within the range of clipped inputs", async () => { |
| 267 | + const agg = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 3, 0); |
| 268 | + agg.setNodes(Set(["a", "b", "c", "d"])); |
| 269 | + |
| 270 | + const p = agg.getPromiseForAggregation(); |
| 271 | + agg.add("a", WeightsContainer.of([0]), 0); |
| 272 | + agg.add("b", WeightsContainer.of([0.5]), 0); |
| 273 | + agg.add("c", WeightsContainer.of([1]), 0); |
| 274 | + agg.add("d", WeightsContainer.of([100]), 0); |
| 275 | + |
| 276 | + const out = await p; |
| 277 | + const v = (await out.weights[0].data())[0]; |
| 278 | + |
| 279 | + expect(v).to.be.greaterThanOrEqual(0); |
| 280 | + expect(v).to.be.lessThanOrEqual(1); |
| 281 | + }); |
| 282 | + |
| 283 | + it("single client cannot dominate aggregation", async () => { |
| 284 | + const agg = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 3, 0); |
| 285 | + agg.setNodes(Set(["h1", "h2", "h3", "b"])); |
| 286 | + |
| 287 | + const p = agg.getPromiseForAggregation(); |
| 288 | + agg.add("h1", WeightsContainer.of([0]), 0); |
| 289 | + agg.add("h2", WeightsContainer.of([0]), 0); |
| 290 | + agg.add("h3", WeightsContainer.of([0]), 0); |
| 291 | + agg.add("b", WeightsContainer.of([1e9]), 0); |
| 292 | + |
| 293 | + const out = await p; |
| 294 | + const v = (await out.weights[0].data())[0]; |
| 295 | + |
| 296 | + expect(Math.abs(v)).to.be.lessThan(0.5); |
| 297 | + }); |
| 298 | + |
| 299 | + it("reset state when starting fresh aggregator", async () => { |
| 300 | + const run = async () => { |
| 301 | + const agg = new ByzantineRobustAggregator(0, 2, "absolute", 1.0, 3, 0.9); |
| 302 | + agg.setNodes(Set(["a", "b"])); |
| 303 | + const p = agg.getPromiseForAggregation(); |
| 304 | + agg.add("a", WeightsContainer.of([1]), 0); |
| 305 | + agg.add("b", WeightsContainer.of([1]), 0); |
| 306 | + return (await (await p).weights[0].data())[0]; |
| 307 | + }; |
| 308 | + |
| 309 | + expect(await run()).to.be.closeTo(await run(), 1e-6); |
| 310 | + }); |
103 | 311 | }); |
0 commit comments