11import path from "node:path" ;
2- import { Dataset , processing } from "@epfml/discojs" ;
3- import type {
2+ import { promises as fs } from "fs" ;
3+ import { Dataset , processing , defaultTasks } from "@epfml/discojs" ;
4+ import {
45 DataFormat ,
56 DataType ,
67 Image ,
@@ -9,18 +10,20 @@ import type {
910import { loadCSV , loadImage , loadImagesInDir } from "@epfml/discojs-node" ;
1011import { Repeat } from "immutable" ;
1112
/**
 * Loads this client's share of the simple_face dataset.
 *
 * Images under `adult/` are labelled "adult" and images under `child/` are
 * labelled "child"; the adult entries come first, followed by the child ones.
 * Only entries whose position `i` in that combined sequence satisfies
 * `i % totalClient === userIdx` are kept.
 *
 * @param userIdx index of the client the data is loaded for
 * @param totalClient number of clients the dataset is split between
 * @returns the labelled images assigned to `userIdx`
 */
async function loadSimpleFaceData(
  userIdx: number,
  totalClient: number,
): Promise<Dataset<DataFormat.Raw["image"]>> {
  const root = path.join("..", "datasets", "simple_face");

  const adultImages = await loadImagesInDir(path.join(root, "adult"));
  const adults: Dataset<[Image, string]> = adultImages.zip(Repeat("adult"));

  const childImages = await loadImagesInDir(path.join(root, "child"));
  const children: Dataset<[Image, string]> = childImages.zip(Repeat("child"));

  const everyone = adults.chain(children);

  // round-robin split: client k gets entries k, k + totalClient, k + 2 * totalClient, ...
  return everyone.filter((_, position) => position % totalClient === userIdx);
}
2225
23- async function loadLusCovidData ( ) : Promise < Dataset < DataFormat . Raw [ "image" ] > > {
26+ async function loadLusCovidData ( userIdx : number , totalClient : number ) : Promise < Dataset < DataFormat . Raw [ "image" ] > > {
2427 const folder = path . join ( ".." , "datasets" , "lus_covid" ) ;
2528
2629 const [ positive , negative ] : Dataset < [ Image , string ] > [ ] = [
@@ -32,7 +35,11 @@ async function loadLusCovidData(): Promise<Dataset<DataFormat.Raw["image"]>> {
3235 ) ,
3336 ] ;
3437
35- return positive . chain ( negative ) ;
38+ const combined : Dataset < [ Image , string ] > = positive . chain ( negative ) ;
39+
40+ const sharded = combined . filter ( ( _ , i ) => i % totalClient === userIdx ) ;
41+
42+ return sharded ;
3643}
3744
3845function loadTinderDogData ( split : number ) : Dataset < DataFormat . Raw [ "image" ] > {
@@ -59,25 +66,89 @@ function loadTinderDogData(split: number): Dataset<DataFormat.Raw["image"]> {
5966 } ) ;
6067}
6168
/**
 * Loads the extended CIFAR-10 images stored for one client.
 *
 * Images are read from `../datasets/extended_cifar10/client_<userIdx>`.
 * Only entries named `image_<n>_label_<k>.png` (case-insensitive) are used;
 * anything else in the folder is ignored. `<k>` is an index into the label
 * list of the default cifar10 task.
 *
 * The folder is listed lazily, each time the returned dataset is iterated.
 * Entries are yielded in ascending `<n>` order (ties broken by file name) so
 * that the result does not depend on the order `readdir` happens to return.
 *
 * @param userIdx index of the client, selects the `client_<userIdx>` folder
 * @returns a dataset of `[image, label]` pairs
 * @throws Error while iterating, if a file name carries a label index that
 *   has no entry in the cifar10 label list
 */
async function loadExtCifar10(userIdx: number): Promise<Dataset<[Image, string]>> {
  const cifar10Task = await defaultTasks.cifar10.getTask();
  const labels = Array.from(cifar10Task.trainingInformation.LABEL_LIST);

  const clientFolder = path.join(
    "..",
    "datasets",
    "extended_cifar10",
    `client_${userIdx}`,
  );
  // group 1: image number, group 2: label index
  const fileNamePattern = /^image_(\d+)_label_(\d+)\.png$/i;

  return new Dataset(async function* () {
    const entries = await fs.readdir(clientFolder, { withFileTypes: true });

    const items = entries.flatMap((entry) => {
      const match = fileNamePattern.exec(entry.name);
      if (match === null) return []; // not a dataset image, skip it

      const imageIdx = Number.parseInt(match[1], 10);
      const labelIdx = Number.parseInt(match[2], 10);

      // guard the lookup itself rather than only the upper bound
      const label = labels[labelIdx];
      if (label === undefined)
        throw new Error(`${entry.name}: too big label index`);

      return [{ name: entry.name, imageIdx, label }];
    });

    // fs.readdir makes no ordering promise; sort to keep iteration reproducible
    items.sort(
      (a, b) =>
        a.imageIdx - b.imageIdx ||
        (a.name < b.name ? -1 : a.name > b.name ? 1 : 0),
    );

    for (const { name, label } of items) {
      const image = await loadImage(path.join(clientFolder, name));
      yield [image, label] as const;
    }
  });
}
102+
/**
 * Loads one split of the MNIST dataset.
 *
 * Split `split` lives in `../datasets/mnist/<split + 1>`. That folder's
 * `labels.csv` provides a `filename` and a `label` column; for each row the
 * image is loaded from `<filename>.png`, `<filename>.jpg` or
 * `<filename>.jpeg`, whichever load succeeds first.
 *
 * @param split zero-based split index; the folder on disk is named `split + 1`
 * @returns a dataset of image and label pairs
 * @throws Error while iterating, if none of the candidate files can be loaded
 */
function loadMnistData(split: number): Dataset<DataFormat.Raw["image"]> {
  const folder = path.join("..", "datasets", "mnist", String(split + 1));
  const extensions = ["png", "jpg", "jpeg"];

  const rows = loadCSV(path.join(folder, "labels.csv")).map(
    (row) =>
      [
        processing.extractColumn(row, "filename"),
        processing.extractColumn(row, "label"),
      ] as const,
  );

  return rows.map(async ([filename, label]) => {
    try {
      // the CSV does not carry the extension: try them all, keep the first success
      const attempts = extensions.map((ext) =>
        loadImage(path.join(folder, `${filename}.${ext}`)),
      );
      const image = await Promise.any(attempts);
      return [image, label];
    } catch {
      throw Error(`${filename} not found in ${folder}`);
    }
  });
}
126+
/**
 * Returns the raw dataset a given client should train on for a task.
 *
 * How `userIdx` and `totalClient` are used depends on the task:
 * - `simple_face`, `lus_covid`: both are forwarded to the task's loader.
 * - `titanic`: rows of `titanic_train.csv` whose position `i` satisfies
 *   `i % totalClient === userIdx` are kept.
 * - `tinder_dog`, `mnist`: `userIdx` is passed to the loader as the split to read.
 * - `extended_cifar10`: `userIdx` selects the client whose images are loaded.
 * - `cifar10`: every image in the CIFAR10 folder is returned, labelled "cat";
 *   neither `userIdx` nor `totalClient` is used.
 *
 * @param taskID identifier of the task to load data for
 * @param userIdx index of the client requesting data
 * @param totalClient number of clients taking part
 * @returns the dataset for that client
 * @throws Error if no loader exists for `taskID`
 */
export async function getTaskData<D extends DataType>(
  taskID: Task.ID,
  userIdx: number,
  totalClient: number,
): Promise<Dataset<DataFormat.Raw[D]>> {
  switch (taskID) {
    case "simple_face":
      return (await loadSimpleFaceData(userIdx, totalClient)) as Dataset<
        DataFormat.Raw[D]
      >;
    case "titanic": {
      // braces keep this declaration local to the case instead of the whole switch
      const titanicData = loadCSV(
        path.join("..", "datasets", "titanic_train.csv"),
      ) as Dataset<DataFormat.Raw[D]>;
      return titanicData.filter((_, i) => i % totalClient === userIdx);
    }
    case "cifar10":
      return (
        await loadImagesInDir(path.join("..", "datasets", "CIFAR10"))
      ).zip(Repeat("cat")) as Dataset<DataFormat.Raw[D]>;
    case "lus_covid":
      return (await loadLusCovidData(userIdx, totalClient)) as Dataset<
        DataFormat.Raw[D]
      >;
    case "tinder_dog":
      return loadTinderDogData(userIdx) as Dataset<DataFormat.Raw[D]>;
    case "extended_cifar10":
      return (await loadExtCifar10(userIdx)) as Dataset<DataFormat.Raw[D]>;
    case "mnist":
      return loadMnistData(userIdx) as Dataset<DataFormat.Raw[D]>;
    default:
      throw new Error(`Data loader for ${taskID} not implemented.`);
  }
}
0 commit comments