Skip to content

Commit cb806c0

Browse files
committed
discojs*: rename .unbatch() to .flat()
1 parent c477bb3 commit cb806c0

File tree

6 files changed

+9
-9
lines changed

6 files changed

+9
-9
lines changed

cli/src/benchmark_gpt.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ async function main(args: Required<CLIArguments>): Promise<void> {
7979
task.trainingInformation.maxSequenceLength = contextLength
8080
const dataset = loadText('../datasets/wikitext/wiki.train.tokens')
8181
.map(text => processing.tokenize(tokenizer, text))
82-
.unbatch()
82+
.flat()
8383
.batchWithOverlap(config.blockSize)
8484

8585
const preprocessedDataset = dataset

cli/src/train_gpt.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ async function main(): Promise<void> {
2121

2222
const tokenDataset = new Dataset([data])
2323
.map((text: string) => processing.tokenize(tokenizer, text))
24-
.unbatch()
24+
.flat()
2525
.batchWithOverlap(config.blockSize)
2626
.map((tokens) => [tokens.pop(), tokens.last()] as [List<number>, number])
2727
.repeat()

discojs/src/dataset/dataset.spec.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ describe("dataset", () => {
155155
const blockSize = 4
156156

157157
const parsed = new Dataset([expectedTokens])
158-
.unbatch()
158+
.flat()
159159
.batchWithOverlap(blockSize)
160160

161161
// -1 because the last sequence is dropped as there is no next token label

discojs/src/dataset/dataset.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,8 @@ export class Dataset<T> implements AsyncIterable<T> {
184184
);
185185
}
186186

187-
/** Flatten chunks */
188-
unbatch<U>(this: Dataset<Batched<U>>): Dataset<U> {
187+
/** Flatten batches/arrays of elements */
188+
flat<U>(this: Dataset<Batched<U>>): Dataset<U> {
189189
return new Dataset(
190190
async function* (this: Dataset<Batched<U>>) {
191191
for await (const batch of this) yield* batch;

discojs/src/processing/index.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ export async function preprocess<D extends DataType>(
6060

6161
const tokenizer = await models.getTaskTokenizer(t);
6262
return d.map(text => processing.tokenize(tokenizer, text))
63-
.unbatch()
63+
.flat()
6464
.batchWithOverlap(blockSize)
6565
.map((tokens) => [tokens.pop(), tokens.last()]) as
6666
Dataset<DataFormat.ModelEncoded[D]>;
@@ -101,7 +101,7 @@ export async function preprocessWithoutLabel<D extends DataType>(
101101
const tokenizer = await models.getTaskTokenizer(t);
102102

103103
return d.map(text => processing.tokenize(tokenizer, text))
104-
.unbatch()
104+
.flat()
105105
.batch(blockSize)
106106
}
107107
}

discojs/src/validator.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ export class Validator<D extends DataType> {
2222
.zip(batch.map(([_, outputs]) => outputs))
2323
.map(([inferred, truth]) => inferred === truth),
2424
)
25-
.unbatch();
25+
.flat();
2626

2727
for await (const e of results) yield e;
2828
}
@@ -36,7 +36,7 @@ export class Validator<D extends DataType> {
3636
)
3737
.batch(this.task.trainingInformation.batchSize)
3838
.map((batch) => this.#model.predict(batch))
39-
.unbatch();
39+
.flat();
4040

4141
const predictions = await processing.postprocess(
4242
this.task,

0 commit comments

Comments (0)