File tree Expand file tree Collapse file tree 6 files changed +9
-9
lines changed
Expand file tree Collapse file tree 6 files changed +9
-9
lines changed Original file line number Diff line number Diff line change @@ -79,7 +79,7 @@ async function main(args: Required<CLIArguments>): Promise<void> {
7979 task . trainingInformation . maxSequenceLength = contextLength
8080 const dataset = loadText ( '../datasets/wikitext/wiki.train.tokens' )
8181 . map ( text => processing . tokenize ( tokenizer , text ) )
82- . unbatch ( )
82+ . flat ( )
8383 . batchWithOverlap ( config . blockSize )
8484
8585 const preprocessedDataset = dataset
Original file line number Diff line number Diff line change @@ -21,7 +21,7 @@ async function main(): Promise<void> {
2121
2222 const tokenDataset = new Dataset ( [ data ] )
2323 . map ( ( text : string ) => processing . tokenize ( tokenizer , text ) )
24- . unbatch ( )
24+ . flat ( )
2525 . batchWithOverlap ( config . blockSize )
2626 . map ( ( tokens ) => [ tokens . pop ( ) , tokens . last ( ) ] as [ List < number > , number ] )
2727 . repeat ( )
Original file line number Diff line number Diff line change @@ -155,7 +155,7 @@ describe("dataset", () => {
155155 const blockSize = 4
156156
157157 const parsed = new Dataset ( [ expectedTokens ] )
158- . unbatch ( )
158+ . flat ( )
159159 . batchWithOverlap ( blockSize )
160160
161161 // -1 because the last sequence is dropped as there is no next token label
Original file line number Diff line number Diff line change @@ -184,8 +184,8 @@ export class Dataset<T> implements AsyncIterable<T> {
184184 ) ;
185185 }
186186
187- /** Flatten chunks */
188- unbatch < U > ( this : Dataset < Batched < U > > ) : Dataset < U > {
187+ /** Flatten batches/arrays of elements */
188+ flat < U > ( this : Dataset < Batched < U > > ) : Dataset < U > {
189189 return new Dataset (
190190 async function * ( this : Dataset < Batched < U > > ) {
191191 for await ( const batch of this ) yield * batch ;
Original file line number Diff line number Diff line change @@ -60,7 +60,7 @@ export async function preprocess<D extends DataType>(
6060
6161 const tokenizer = await models . getTaskTokenizer ( t ) ;
6262 return d . map ( text => processing . tokenize ( tokenizer , text ) )
63- . unbatch ( )
63+ . flat ( )
6464 . batchWithOverlap ( blockSize )
6565 . map ( ( tokens ) => [ tokens . pop ( ) , tokens . last ( ) ] ) as
6666 Dataset < DataFormat . ModelEncoded [ D ] > ;
@@ -101,7 +101,7 @@ export async function preprocessWithoutLabel<D extends DataType>(
101101 const tokenizer = await models . getTaskTokenizer ( t ) ;
102102
103103 return d . map ( text => processing . tokenize ( tokenizer , text ) )
104- . unbatch ( )
104+ . flat ( )
105105 . batch ( blockSize )
106106 }
107107 }
Original file line number Diff line number Diff line change @@ -22,7 +22,7 @@ export class Validator<D extends DataType> {
2222 . zip ( batch . map ( ( [ _ , outputs ] ) => outputs ) )
2323 . map ( ( [ inferred , truth ] ) => inferred === truth ) ,
2424 )
25- . unbatch ( ) ;
25+ . flat ( ) ;
2626
2727 for await ( const e of results ) yield e ;
2828 }
@@ -36,7 +36,7 @@ export class Validator<D extends DataType> {
3636 )
3737 . batch ( this . task . trainingInformation . batchSize )
3838 . map ( ( batch ) => this . #model. predict ( batch ) )
39- . unbatch ( ) ;
39+ . flat ( ) ;
4040
4141 const predictions = await processing . postprocess (
4242 this . task ,
You can’t perform that action at this time.
0 commit comments