
Commit 7eaba26

ahzero7d1 authored and tharvik committed
Optimized differential privacy implementation
1 parent 434bc3d · commit 7eaba26

File tree

18 files changed: +728 −273 lines


cli/src/args.ts

Lines changed: 30 additions & 2 deletions
@@ -10,6 +10,10 @@ interface BenchmarkArguments {
   epochs: number
   roundDuration: number
   batchSize: number
+  validationSplit: number
+  epsilon?: number
+  delta?: number
+  dpDefaultClippingRadius?: number
   save: boolean
   host: URL
 }
@@ -28,6 +32,10 @@ const unsafeArgs = parse<BenchmarkUnsafeArguments>(
   epochs: { type: Number, alias: 'e', description: 'Number of epochs', defaultValue: 10 },
   roundDuration: { type: Number, alias: 'r', description: 'Round duration (in epochs)', defaultValue: 2 },
   batchSize: { type: Number, alias: 'b', description: 'Training batch size', defaultValue: 10 },
+  validationSplit: { type: Number, alias: 'v', description: 'Validation dataset ratio', defaultValue: 0.2 },
+  epsilon: { type: Number, alias: 'n', description: 'Privacy budget', optional: true, defaultValue: undefined },
+  delta: { type: Number, alias: 'd', description: 'Probability of failure, slack parameter', optional: true, defaultValue: undefined },
+  dpDefaultClippingRadius: { type: Number, alias: 'f', description: 'Default clipping radius for DP', optional: true, defaultValue: undefined },
   save: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false },
   host: {
     type: (raw: string) => new URL(raw),
@@ -52,6 +60,7 @@ const supportedTasks = Map(
   defaultTasks.simpleFace,
   defaultTasks.titanic,
   defaultTasks.tinderDog,
+  defaultTasks.mnist,
 ).map(
   async (t) =>
     [(await t.getTask()).id, t] as [
@@ -77,10 +86,29 @@ export const args: BenchmarkArguments = {
     task.trainingInformation.batchSize = unsafeArgs.batchSize;
     task.trainingInformation.roundDuration = unsafeArgs.roundDuration;
     task.trainingInformation.epochs = unsafeArgs.epochs;
+    task.trainingInformation.validationSplit = unsafeArgs.validationSplit;

     // For DP
-    // TASK.trainingInformation.clippingRadius = 10000000
-    // TASK.trainingInformation.noiseScale = 0
+    const { dpDefaultClippingRadius, epsilon, delta } = unsafeArgs;
+
+    if (
+      // dpDefaultClippingRadius !== undefined &&
+      epsilon !== undefined &&
+      delta !== undefined
+    ) {
+      if (task.trainingInformation.scheme === "local")
+        throw new Error("Can't have differential privacy for local training");
+
+      const defaultRadius = dpDefaultClippingRadius ? dpDefaultClippingRadius : 1;
+
+      // for the case where privacy parameters are not defined in the default tasks
+      task.trainingInformation.privacy ??= {}
+      task.trainingInformation.privacy.differentialPrivacy = {
+        clippingRadius: defaultRadius,
+        epsilon: epsilon,
+        delta: delta,
+      };
+    }

     return task;
   },
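
Note (illustrative sketch, not part of the commit): passing a privacy budget and slack parameter on the command line, for example --epsilon 5 --delta 1e-5 assuming the usual long-option form of the flags above, has the same effect as setting the following on the selected task. The values below are hypothetical; the clipping radius falls back to 1 when --dpDefaultClippingRadius (-f) is omitted.

import { defaultTasks } from "@epfml/discojs";

const task = await defaultTasks.cifar10.getTask();

// Equivalent effect of the new block in args.ts for the flags above.
task.trainingInformation.privacy ??= {};
task.trainingInformation.privacy.differentialPrivacy = {
  clippingRadius: 1,
  epsilon: 5,
  delta: 1e-5,
};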

cli/src/cli.ts

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ async function main<D extends DataType, N extends Network>(
   console.log({ args })

   const dataSplits = await Promise.all(
-    Range(0, numberOfUsers).map(async i => getTaskData(task.id, i))
+    Range(0, numberOfUsers).map(async i => getTaskData(task.id, i, numberOfUsers))
   )
   const logs = await Promise.all(
     dataSplits.map(async data => await runUser(task, args.host, data as Dataset<DataFormat.Raw[D]>))

cli/src/data.ts

Lines changed: 80 additions & 9 deletions
@@ -1,6 +1,7 @@
 import path from "node:path";
-import { Dataset, processing } from "@epfml/discojs";
-import type {
+import { promises as fs } from "fs";
+import { Dataset, processing, defaultTasks } from "@epfml/discojs";
+import {
   DataFormat,
   DataType,
   Image,
@@ -9,18 +10,20 @@ import type {
 import { loadCSV, loadImage, loadImagesInDir } from "@epfml/discojs-node";
 import { Repeat } from "immutable";

-async function loadSimpleFaceData(): Promise<Dataset<DataFormat.Raw["image"]>> {
+async function loadSimpleFaceData(userIdx: number, totalClient: number): Promise<Dataset<DataFormat.Raw["image"]>> {
   const folder = path.join("..", "datasets", "simple_face");

   const [adults, childs]: Dataset<[Image, string]>[] = [
     (await loadImagesInDir(path.join(folder, "adult"))).zip(Repeat("adult")),
     (await loadImagesInDir(path.join(folder, "child"))).zip(Repeat("child")),
   ];

-  return adults.chain(childs);
+  const combined = adults.chain(childs);
+
+  return combined.filter((_, i) => i % totalClient === userIdx);
 }

-async function loadLusCovidData(): Promise<Dataset<DataFormat.Raw["image"]>> {
+async function loadLusCovidData(userIdx: number, totalClient: number): Promise<Dataset<DataFormat.Raw["image"]>> {
   const folder = path.join("..", "datasets", "lus_covid");

   const [positive, negative]: Dataset<[Image, string]>[] = [
@@ -32,7 +35,11 @@ async function loadLusCovidData(): Promise<Dataset<DataFormat.Raw["image"]>> {
     ),
   ];

-  return positive.chain(negative);
+  const combined: Dataset<[Image, string]> = positive.chain(negative);
+
+  const sharded = combined.filter((_, i) => i % totalClient === userIdx);
+
+  return sharded;
 }

 function loadTinderDogData(split: number): Dataset<DataFormat.Raw["image"]> {
@@ -59,25 +66,89 @@ function loadTinderDogData(split: number): Dataset<DataFormat.Raw["image"]> {
   });
 }

+async function loadExtCifar10(userIdx: number): Promise<Dataset<[Image, string]>> {
+  const CIFAR10_LABELS = Array.from(await defaultTasks.cifar10.getTask().then(t => t.trainingInformation.LABEL_LIST));
+  const folder = path.join("..", "datasets", "extended_cifar10");
+  const clientFolder = path.join(folder, `client_${userIdx}`);
+
+  return new Dataset(async function* () {
+    const entries = await fs.readdir(clientFolder, { withFileTypes: true });
+
+    const items = entries
+      .flatMap((e) => {
+        const m = e.name.match(/^image_(\d+)_label_(\d+)\.png$/i);
+        if (m === null) return [];
+        const labelIdx = Number.parseInt(m[2], 10);
+
+        if (labelIdx >= CIFAR10_LABELS.length)
+          throw new Error(`${e.name}: label index out of range`);
+
+        return {
+          name: e.name,
+          label: CIFAR10_LABELS[labelIdx],
+        };
+      })
+      .filter((x) => x !== null)

+    for (const { name, label } of items) {
+      const filePath = path.join(clientFolder, name);
+      const image = await loadImage(filePath);
+      yield [image, label] as const;
+    }
+  })
+}
+
+function loadMnistData(split: number): Dataset<DataFormat.Raw["image"]> {
+  const folder = path.join("..", "datasets", "mnist", `${split + 1}`);
+  return loadCSV(path.join(folder, "labels.csv"))
+    .map(
+      (row) =>
+        [
+          processing.extractColumn(row, "filename"),
+          processing.extractColumn(row, "label"),
+        ] as const,
+    )
+    .map(async ([filename, label]) => {
+      try {
+        const image = await Promise.any(
+          ["png", "jpg", "jpeg"].map((ext) =>
+            loadImage(path.join(folder, `${filename}.${ext}`)),
+          ),
+        );
+        return [image, label];
+      } catch {
+        throw Error(`${filename} not found in ${folder}`);
+      }
+    });
+}
+
 export async function getTaskData<D extends DataType>(
   taskID: Task.ID,
   userIdx: number,
+  totalClient: number
 ): Promise<Dataset<DataFormat.Raw[D]>> {
   switch (taskID) {
     case "simple_face":
-      return (await loadSimpleFaceData()) as Dataset<DataFormat.Raw[D]>;
+      return (await loadSimpleFaceData(userIdx, totalClient)) as Dataset<DataFormat.Raw[D]>;
     case "titanic":
-      return loadCSV(
+      const titanicData = loadCSV(
         path.join("..", "datasets", "titanic_train.csv"),
       ) as Dataset<DataFormat.Raw[D]>;
+      return titanicData.filter((_, i) => i % totalClient === userIdx);
     case "cifar10":
       return (
         await loadImagesInDir(path.join("..", "datasets", "CIFAR10"))
       ).zip(Repeat("cat")) as Dataset<DataFormat.Raw[D]>;
     case "lus_covid":
-      return (await loadLusCovidData()) as Dataset<DataFormat.Raw[D]>;
+      return (await loadLusCovidData(userIdx, totalClient)) as Dataset<DataFormat.Raw[D]>;
     case "tinder_dog":
       return loadTinderDogData(userIdx) as Dataset<DataFormat.Raw[D]>;
+    case "extended_cifar10":
+      return (await loadExtCifar10(userIdx)) as Dataset<DataFormat.Raw[D]>;
+    case "mnist":
+      return loadMnistData(userIdx) as Dataset<DataFormat.Raw[D]>;
     default:
       throw new Error(`Data loader for ${taskID} not implemented.`);
   }
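
Note (sketch, not from the diff): the recurring filter((_, i) => i % totalClient === userIdx) pattern is plain round-robin sharding, so the per-client shards are disjoint and together cover the whole dataset. On a plain array it looks like this (names and sample values are made up):

// Round-robin sharding on an array.
// With totalClient = 3, indices 0..8 split as:
//   client 0 -> 0, 3, 6   client 1 -> 1, 4, 7   client 2 -> 2, 5, 8
function shard<T>(items: T[], userIdx: number, totalClient: number): T[] {
  return items.filter((_, i) => i % totalClient === userIdx);
}

// e.g. shard(["a", "b", "c", "d", "e"], 1, 2) yields ["b", "d"]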

discojs/src/aggregator/get.ts

Lines changed: 15 additions & 12 deletions
@@ -48,19 +48,22 @@ export function getAggregator(
   };

   switch (task.trainingInformation.aggregationStrategy) {
-    case 'byzantine': {
-      const {byzantineClippingRadius = 1.0, maxIterations = 1, beta = 0.9,
-      } = task.trainingInformation;
+    case "byzantine": {
+      const {
+        clippingRadius = 1.0,
+        maxIterations = 1,
+        beta = 0.9,
+      } = task.trainingInformation.privacy.byzantineFaultTolerance;

-      return new ByzantineRobustAggregator(
-        networkOptions.roundCutOff,
-        networkOptions.threshold,
-        networkOptions.thresholdType,
-        byzantineClippingRadius,
-        maxIterations,
-        beta
-      );
-    }
+      return new ByzantineRobustAggregator(
+        networkOptions.roundCutOff,
+        networkOptions.threshold,
+        networkOptions.thresholdType,
+        clippingRadius,
+        maxIterations,
+        beta,
+      );
+    }
     case 'mean':
       return new aggregator.MeanAggregator(
         networkOptions.roundCutOff,
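
Note (sketch, not from the diff): the 'byzantine' branch now reads its settings from a nested privacy.byzantineFaultTolerance block instead of flat fields on trainingInformation. The assumed shape, with values mirroring the mnist default task updated later in this commit:

// Excerpt of a task's trainingInformation as the aggregator now expects it.
const trainingInformationExcerpt = {
  aggregationStrategy: "byzantine" as const,
  privacy: {
    byzantineFaultTolerance: {
      clippingRadius: 10,
      maxIterations: 1,
      beta: 0.9,
    },
  },
};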

discojs/src/dataset/dataset.ts

Lines changed: 54 additions & 0 deletions
@@ -237,6 +237,60 @@ export class Dataset<T> implements AsyncIterable<T> {
   cached(): Dataset<T> {
     return new CachingDataset(this.#content);
   }
+
+  /** Shuffles the Dataset within a given window size */
+  shuffle(windowSize: number) {
+    if (!Number.isInteger(windowSize) || windowSize < 1) {
+      throw new Error("Shuffle window size should be a positive integer");
+    }
+
+    return new Dataset(
+      async function* (this: Dataset<T>) {
+        const iter = this[Symbol.asyncIterator]();
+        const buffer: T[] = [];
+
+        // 1. Fill the initial buffer
+        while (buffer.length < windowSize) {
+          const n = await iter.next();
+          if (n.done) break;
+          buffer.push(n.value);
+        }
+
+        // 2. Repeatedly yield a random buffered element, replacing it with the next incoming one
+        while (buffer.length > 0) {
+          const pick = Math.floor(Math.random() * buffer.length);
+          const chosen = buffer[pick];
+
+          const n = await iter.next();
+
+          if (n.done) {
+            // source exhausted: remove the chosen slot by swapping the last element into it
+            buffer[pick] = buffer[buffer.length - 1]; buffer.pop();
+          } else {
+            buffer[pick] = n.value;
+          }
+
+          yield chosen;
+        }
+      }.bind(this)
+    );
+  }
+
+  /** Filters elements according to the given condition */
+  filter(
+    condition: (value: T, index: number) => boolean | Promise<boolean>
+  ): Dataset<T> {
+    return new Dataset<T>(async function* (this: Dataset<T>): AsyncGenerator<T, void, unknown> {
+      let i = 0;
+      for await (const v of this) {
+        if (await condition(v, i)) {
+          yield v;
+        }
+        i += 1;
+      }
+    }.bind(this));
+  }
+
 }

 /**
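
Note (hypothetical usage, not part of the commit): both new helpers are lazy and return a new Dataset, so they compose with the existing chaining style, for example sharding a dataset and then shuffling the shard within a window:

import { Dataset } from "@epfml/discojs";

// Toy dataset built the same way the data loaders above build theirs.
const toyDataset = new Dataset(async function* () {
  for (let i = 0; i < 10; i++) yield [i, `label_${i}`] as const;
});

const userIdx = 0;
const totalClient = 4;

// Keep client 0's shard out of 4 clients, then shuffle within a 100-element window.
const shardForUser = toyDataset
  .filter((_, i) => i % totalClient === userIdx)
  .shuffle(100);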

discojs/src/default_tasks/cifar10.ts

Lines changed: 10 additions & 5 deletions
@@ -27,19 +27,24 @@ export const cifar10: TaskProvider<"image", "decentralized"> = {
     },
   },
   trainingInformation: {
-    epochs: 10,
+    epochs: 20,
     roundDuration: 10,
     validationSplit: 0.2,
     batchSize: 10,
     IMAGE_H: 224,
     IMAGE_W: 224,
     LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
     scheme: 'decentralized',
-    aggregationStrategy: 'byzantine',
-    byzantineClippingRadius: 10.0,
     maxIterations: 1,
     beta: 0.9,
-    privacy: { clippingRadius: 20, noiseScale: 1 },
+    aggregationStrategy: 'mean',
+    privacy: {
+      differentialPrivacy: {
+        clippingRadius: 1,
+        epsilon: 50,
+        delta: 1e-5,
+      },
+    },
     minNbOfParticipants: 3,
     maxShareValue: 100,
     tensorBackend: 'tfjs'
@@ -66,7 +71,7 @@ export const cifar10: TaskProvider<"image", "decentralized"> = {
     model.compile({
       optimizer: 'sgd',
       loss: 'categoricalCrossentropy',
-      metrics: ['accuracy']
+      metrics: ['accuracy'],
    })

     return new models.TFJS('image', model)
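
Note (quick arithmetic on the updated defaults, assuming roundDuration is counted in epochs as the CLI flag description above states):

const epochs = 20;
const roundDuration = 10; // epochs per federated round
const federatedRounds = epochs / roundDuration; // = 2 aggregation rounds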

discojs/src/default_tasks/mnist.ts

Lines changed: 8 additions & 1 deletion
@@ -33,7 +33,14 @@ export const mnist: TaskProvider<"image", "decentralized"> = {
     IMAGE_W: 28,
     LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
     scheme: 'decentralized',
-    aggregationStrategy: 'mean',
+    aggregationStrategy: "byzantine",
+    privacy: {
+      byzantineFaultTolerance: {
+        clippingRadius: 10,
+        maxIterations: 1,
+        beta: 0.9,
+      },
+    },
     minNbOfParticipants: 3,
     maxShareValue: 100,
     tensorBackend: 'tfjs'
