Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions cli/src/args.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ interface BenchmarkArguments {
epochs: number
roundDuration: number
batchSize: number
validationSplit: number
epsilon?: number
delta?: number
dpDefaultClippingRadius?: number
save: boolean
host: URL
}
Expand All @@ -28,6 +32,10 @@ const unsafeArgs = parse<BenchmarkUnsafeArguments>(
epochs: { type: Number, alias: 'e', description: 'Number of epochs', defaultValue: 10 },
roundDuration: { type: Number, alias: 'r', description: 'Round duration (in epochs)', defaultValue: 2 },
batchSize: { type: Number, alias: 'b', description: 'Training batch size', defaultValue: 10 },
validationSplit : { type: Number, alias: 'v', description: 'Validation dataset ratio', defaultValue: 0.2 },
epsilon: { type: Number, alias: 'n', description: 'Privacy budget', optional: true, defaultValue: undefined},
delta: { type: Number, alias: 'd', description: 'Probability of failure, slack parameter', optional: true, defaultValue: undefined},
dpDefaultClippingRadius: {type: Number, alias: 'f', description: 'Default clipping radius for DP', optional: true, defaultValue: undefined},
save: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false },
host: {
type: (raw: string) => new URL(raw),
Expand All @@ -52,6 +60,7 @@ const supportedTasks = Map(
defaultTasks.simpleFace,
defaultTasks.titanic,
defaultTasks.tinderDog,
defaultTasks.mnist,
).map(
async (t) =>
[(await t.getTask()).id, t] as [
Expand All @@ -77,10 +86,29 @@ export const args: BenchmarkArguments = {
task.trainingInformation.batchSize = unsafeArgs.batchSize;
task.trainingInformation.roundDuration = unsafeArgs.roundDuration;
task.trainingInformation.epochs = unsafeArgs.epochs;
task.trainingInformation.validationSplit = unsafeArgs.validationSplit;

// For DP
// TASK.trainingInformation.clippingRadius = 10000000
// TASK.trainingInformation.noiseScale = 0
const {dpDefaultClippingRadius, epsilon, delta} = unsafeArgs;

if (
// dpDefaultClippingRadius !== undefined &&
epsilon !== undefined &&
delta !== undefined
){
if (task.trainingInformation.scheme === "local")
throw new Error("Can't have differential privacy for local training");

const defaultRadius = dpDefaultClippingRadius ? dpDefaultClippingRadius : 1;

// for the case where privacy parameters are not defined in the default tasks
task.trainingInformation.privacy ??= {}
task.trainingInformation.privacy.differentialPrivacy = {
clippingRadius: defaultRadius,
epsilon: epsilon,
delta: delta,
};
}

return task;
},
Expand Down
2 changes: 1 addition & 1 deletion cli/src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async function main<D extends DataType, N extends Network>(
console.log({ args })

const dataSplits = await Promise.all(
Range(0, numberOfUsers).map(async i => getTaskData(task.id, i))
Range(0, numberOfUsers).map(async i => getTaskData(task.id, i, numberOfUsers))
)
const logs = await Promise.all(
dataSplits.map(async data => await runUser(task, args.host, data as Dataset<DataFormat.Raw[D]>))
Expand Down
97 changes: 88 additions & 9 deletions cli/src/data.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import path from "node:path";
import { Dataset, processing } from "@epfml/discojs";
import type {
import { promises as fs } from "fs";
import { Dataset, processing, defaultTasks } from "@epfml/discojs";
import {
DataFormat,
DataType,
Image,
Expand All @@ -9,18 +10,20 @@ import type {
import { loadCSV, loadImage, loadImagesInDir } from "@epfml/discojs-node";
import { Repeat } from "immutable";

async function loadSimpleFaceData(): Promise<Dataset<DataFormat.Raw["image"]>> {
async function loadSimpleFaceData(userIdx: number, totalClient: number): Promise<Dataset<DataFormat.Raw["image"]>> {
const folder = path.join("..", "datasets", "simple_face");

const [adults, childs]: Dataset<[Image, string]>[] = [
(await loadImagesInDir(path.join(folder, "adult"))).zip(Repeat("adult")),
(await loadImagesInDir(path.join(folder, "child"))).zip(Repeat("child")),
];

return adults.chain(childs);
const combinded = adults.chain(childs);

return combinded.filter((_, i) => i % totalClient === userIdx);
}

async function loadLusCovidData(): Promise<Dataset<DataFormat.Raw["image"]>> {
async function loadLusCovidData(userIdx: number, totalClient: number): Promise<Dataset<DataFormat.Raw["image"]>> {
const folder = path.join("..", "datasets", "lus_covid");

const [positive, negative]: Dataset<[Image, string]>[] = [
Expand All @@ -32,7 +35,11 @@ async function loadLusCovidData(): Promise<Dataset<DataFormat.Raw["image"]>> {
),
];

return positive.chain(negative);
const combined: Dataset<[Image, string]> = positive.chain(negative);

const sharded = combined.filter((_, i) => i % totalClient === userIdx);

return sharded;
}

function loadTinderDogData(split: number): Dataset<DataFormat.Raw["image"]> {
Expand All @@ -59,25 +66,97 @@ function loadTinderDogData(split: number): Dataset<DataFormat.Raw["image"]> {
});
}

async function loadExtCifar10(userIdx: number){
const CIFAR10_LABELS = Array.from(await defaultTasks.cifar10.getTask().then(t => t.trainingInformation.LABEL_LIST));
const folder = path.join("..", "datasets", "extended_cifar10");
const clientFolder = path.join(folder, `client_${userIdx}`);

return new Dataset<[Image, string]>(async function*(){
const entries = await fs.readdir(clientFolder, {withFileTypes: true});

const items = entries
.flatMap((e) => {
const m = e.name.match(
/^image_(\d+)_label_(\d+)\.png$/i
);
if (m === null) return [];
const labelIdx = Number.parseInt(m[2], 10);

if(
!Number.isInteger(labelIdx) ||
labelIdx < 0 ||
labelIdx >= CIFAR10_LABELS.length
){
throw new Error('Not a valid label index.');
}

return {
name: e.name,
labelIdx,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as you already checked that it was within range of CIFAR10_LABELS, you can directly return the label

};
})
.filter(
(x): x is {idx: number; name: string; labelIdx: number } => x !== null
)

for (const {name, labelIdx} of items){
const label = CIFAR10_LABELS[labelIdx];
const filePath = path.join(clientFolder, name);
const image = await loadImage(filePath);
yield [image, label] as [Image, string];
}
})
}

function loadMnistData(split: number): Dataset<DataFormat.Raw["image"]>{
const folder = path.join("..", "datasets", "mnist", `${split + 1}`);
return loadCSV(path.join(folder, "labels.csv"))
.map(
(row) =>
[
processing.extractColumn(row, "filename"),
processing.extractColumn(row, "label"),
] as const,
)
.map(async ([filename, label]) => {
try {
const image = await Promise.any(
["png", "jpg", "jpeg"].map((ext) =>
loadImage(path.join(folder, `${filename}.${ext}`)),
),
);
return [image, label];
} catch {
throw Error(`${filename} not found in ${folder}`);
}
});
}

export async function getTaskData<D extends DataType>(
taskID: Task.ID,
userIdx: number,
totalClient: number
): Promise<Dataset<DataFormat.Raw[D]>> {
switch (taskID) {
case "simple_face":
return (await loadSimpleFaceData()) as Dataset<DataFormat.Raw[D]>;
return (await loadSimpleFaceData(userIdx, totalClient)) as Dataset<DataFormat.Raw[D]>;
case "titanic":
return loadCSV(
const titanicData = loadCSV(
path.join("..", "datasets", "titanic_train.csv"),
) as Dataset<DataFormat.Raw[D]>;
return titanicData.filter((_, i) => i % totalClient === userIdx);
case "cifar10":
return (
await loadImagesInDir(path.join("..", "datasets", "CIFAR10"))
).zip(Repeat("cat")) as Dataset<DataFormat.Raw[D]>;
case "lus_covid":
return (await loadLusCovidData()) as Dataset<DataFormat.Raw[D]>;
return (await loadLusCovidData(userIdx, totalClient)) as Dataset<DataFormat.Raw[D]>;
case "tinder_dog":
return loadTinderDogData(userIdx) as Dataset<DataFormat.Raw[D]>;
case "extended_cifar10":
return (await loadExtCifar10(userIdx)) as Dataset<DataFormat.Raw[D]>;
case "mnist":
return loadMnistData(userIdx) as Dataset<DataFormat.Raw[D]>;
default:
throw new Error(`Data loader for ${taskID} not implemented.`);
}
Expand Down
54 changes: 54 additions & 0 deletions discojs/src/dataset/dataset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,60 @@ export class Dataset<T> implements AsyncIterable<T> {
cached(): Dataset<T> {
return new CachingDataset(this.#content);
}

/** Shuffles the Dataset instance within certain window size */
shuffle(windowSize: number){
if (!Number.isInteger(windowSize) || windowSize < 1){
throw new Error("Shuffle window size should be a positive integer");
}

return new Dataset(
async function*(this: Dataset<T>){
const iter = this[Symbol.asyncIterator]();
const buffer: T[] = [];

// 1. Construct the initial buffer
while (buffer.length < windowSize){
const n = await iter.next();
if (n.done) break;
buffer.push(n.value);
}

// 2. Shuffle
while (buffer.length > 0){
const pick = Math.floor(Math.random() * buffer.length);
const chosen = buffer[pick];

const n = await iter.next();

if (n.done){
// move the last element to the pick position
buffer[pick] = buffer.pop() as T;
}else{
buffer[pick] = n.value;
}

yield chosen;
}
}.bind(this)
);
}

/** filter the indices according to the splitting condition */
filter(
condition: (value: T, index: number) => boolean | Promise<boolean>
): Dataset<T>{
return new Dataset<T>(async function* (this: Dataset<T>): AsyncGenerator<T, void, unknown>{
let i = 0;
for await(const v of this){
if (await condition(v, i)){
yield v;
}
i += 1
}
}.bind(this));
}

}

/**
Expand Down
12 changes: 9 additions & 3 deletions discojs/src/default_tasks/cifar10.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ export const cifar10: TaskProvider<"image", "decentralized"> = {
},
},
trainingInformation: {
epochs: 10,
epochs: 20,
roundDuration: 10,
validationSplit: 0.2,
batchSize: 10,
Expand All @@ -36,7 +36,13 @@ export const cifar10: TaskProvider<"image", "decentralized"> = {
LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
scheme: 'decentralized',
aggregationStrategy: 'mean',
privacy: { clippingRadius: 20, noiseScale: 1 },
privacy: {
differentialPrivacy: {
clippingRadius: 1,
epsilon: 50,
delta: 1e-5,
},
},
minNbOfParticipants: 3,
maxShareValue: 100,
tensorBackend: 'tfjs'
Expand All @@ -63,7 +69,7 @@ export const cifar10: TaskProvider<"image", "decentralized"> = {
model.compile({
optimizer: 'sgd',
loss: 'categoricalCrossentropy',
metrics: ['accuracy']
metrics: ['accuracy'],
})

return new models.TFJS('image', model)
Expand Down
Loading
Loading