
Commit c477bb3

*: replace line-by-line text loaders with chunk-by-chunk text loaders
discojs/src/dataset: implement and test repeat and batchWithOverlap
1 parent c6beac9 commit c477bb3
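
In short, the loaders now yield raw text chunks instead of lines, and the new Dataset helpers turn those chunks into overlapping training windows. Below is a minimal sketch of the resulting pipeline, mirroring the updated cli/src/benchmark_gpt.ts and cli/src/train_gpt.ts; the loadText import path and the dataset path are assumptions for illustration only.

import { AutoTokenizer } from "@xenova/transformers";
import { processing } from "@epfml/discojs";
import { List } from "immutable";
// assumption: the node package re-exports the chunked text loader as `loadText`
import { loadText } from "@epfml/discojs-node";

async function sketch(): Promise<void> {
  const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
  const blockSize = 16;

  const batches = loadText("../datasets/wikitext/wiki.train.tokens") // text chunks, not lines
    .map((chunk) => processing.tokenize(tokenizer, chunk)) // one token list per chunk
    .unbatch() // flatten into a single token stream
    .batchWithOverlap(blockSize) // windows of blockSize + 1 tokens, overlapping by one
    .map((tokens) => [tokens.pop(), tokens.last()] as [List<number>, number]) // (inputs, next-token label)
    .repeat() // cycle through the file indefinitely
    .batch(8);

  for await (const batch of batches) {
    console.log(batch.size); // each element of the batch is an (inputs, label) pair
    break;
  }
}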

15 files changed: +313 -172 lines

cli/src/benchmark_gpt.ts

Lines changed: 6 additions & 7 deletions
@@ -1,3 +1,4 @@
+import '@tensorflow/tfjs-node';
 import { List } from "immutable";
 import { parse } from "ts-command-line-args";
 import { AutoTokenizer } from "@xenova/transformers";
@@ -41,7 +42,7 @@ const args = { ...defaultArgs, ...parsedArgs }
  * Benchmark results are reported in https://github.com/epfml/disco/pull/659
  */

-async function main(args: Required<CLIArguments>): Promise<void> {
+async function main(args: Required<CLIArguments>): Promise<void> {
   const { inference: benchmarkInference, modelType,
     contextLength, batchSize, modelPath } = args

@@ -77,10 +78,11 @@ async function main(args: Required<CLIArguments>): Promise<void> {
   task.trainingInformation.batchSize = batchSize
   task.trainingInformation.maxSequenceLength = contextLength
   const dataset = loadText('../datasets/wikitext/wiki.train.tokens')
+    .map(text => processing.tokenize(tokenizer, text))
+    .unbatch()
+    .batchWithOverlap(config.blockSize)

-  const maxLength = task.trainingInformation.maxSequenceLength ?? (tokenizer.model_max_length as number) + 1
   const preprocessedDataset = dataset
-    .map((line) => processing.tokenizeAndLeftPad(line, tokenizer, maxLength))
     .map((tokens) => [tokens.pop(), tokens.last()] as [List<number>, number])
     .batch(batchSize);

@@ -111,10 +113,7 @@ async function main(args: Required<CLIArguments>): Promise<void> {
   const iterations = 10
   console.log("Generating", maxNewTokens, "new tokens")

-  let tokens = List(
-    (tokenizer(prompt, { return_tensor: false }) as { input_ids: number[] })
-      .input_ids,
-  );
+  let tokens = processing.tokenize(tokenizer, prompt);

   let inferenceTime = 0
   for (let i = 0; i < iterations; i++) {

cli/src/train_gpt.ts

Lines changed: 23 additions & 14 deletions
@@ -1,11 +1,11 @@
-import * as tf from "@tensorflow/tfjs-node"
+import "@tensorflow/tfjs-node"
 import { AutoTokenizer } from "@xenova/transformers";
-import { models, processing } from "@epfml/discojs";
+import { models, processing, Dataset } from "@epfml/discojs";
+import { List } from "immutable";

 async function main(): Promise<void> {
   const data = "Lorem ipsum dolor sit amet, consectetur adipis"
-  const datasetSource = new tf.data.FileDataSource(Buffer.from(data))
-  const textDataset = new tf.data.TextLineDataset(datasetSource)
+  const seed = 42

   const config: models.GPTConfig = {
     modelType: 'gpt-nano',
@@ -14,25 +14,34 @@ async function main(): Promise<void> {
     evaluateEvery:50,
     maxEvalBatches: 10,
     blockSize: 16,
-    vocabSize: 50257,
-    debug: false
+    seed
   }

   const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2')
-  const tokenDataset = textDataset.map((text: string) => {
-    const tokens = processing.tokenizeAndLeftPad(text, tokenizer, config.blockSize + 1)
-    const ys = tf.oneHot(tokens.slice(1), tokenizer.model.vocab.length)
-    const xs = tf.tensor(tokens.slice(0, config.blockSize), undefined, 'int32')
-    return {xs, ys}
-  }).repeat().batch(16) as tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }>
+
+  const tokenDataset = new Dataset([data])
+    .map((text: string) => processing.tokenize(tokenizer, text))
+    .unbatch()
+    .batchWithOverlap(config.blockSize)
+    .map((tokens) => [tokens.pop(), tokens.last()] as [List<number>, number])
+    .repeat()
+    .batch(8);

   const model = new models.GPT(config)
-
   for await (const logs of model.train(tokenDataset, undefined)) {
     console.log(logs)
   }

-  const generation = await model.generate("Lorem", tokenizer, { maxNewTokens: 10, doSample: false, topk: 5, temperature:0.1 })
+  let tokens = processing.tokenize(tokenizer, "Lorem");
+
+  const maxNewTokens = 14
+  for (let n = 0; n < maxNewTokens; n++) {
+    const next: number = (await model.predict(
+      List.of(tokens), { seed })
+    ).first();
+    tokens = tokens.push(next)
+  }
+  const generation = tokenizer.decode(tokens.toArray(), { skip_special_tokens: true })
   console.log(generation)
 }
discojs-node/src/loaders.spec.ts

Lines changed: 6 additions & 5 deletions
@@ -51,12 +51,13 @@ describe("image directory parser", () => {

 describe("text parser", () => {
   it("parses basic file", async () => {
+    const text = ["a", "b", "c"].join("\n")
     await withFile(async ({ path }) => {
-      await fs.writeFile(path, ["a", "b", "c"].join("\n"));
-
-      const parsed = loadText(path);
-
-      expect(await parsed.size()).to.equal(3);
+      await fs.writeFile(path, text);
+
+      const sequences = await arrayFromAsync(loadText(path))
+      expect(sequences.length).to.equal(1);
+      expect(sequences[0]).to.equal(text);
     });
   });
 });

discojs-node/src/loaders/text.ts

Lines changed: 28 additions & 8 deletions
@@ -1,14 +1,34 @@
-import * as fs from "node:fs/promises";
-import * as readline from "node:readline/promises";
-
+import createDebug from "debug";
+import { createReadStream } from 'node:fs';
 import { Dataset, Text } from "@epfml/discojs";

-export function load(path: string): Dataset<Text> {
+const debug = createDebug("discojs-node:loaders:text");
+
+/**
+ * Returns chunks of text. Use `minChunkSize` to ensure that
+ * each chunk is bigger than the expected sequence length.
+ *
+ * @param path path to the text file to read
+ * @param minChunkSize the minimum size of each chunk in bytes, defaults to 16KiB
+ * @returns a dataset of text chunks
+ */
+export function load(path: string, minChunkSize = 16384): Dataset<Text> {
   return new Dataset(async function* () {
-    const input = (await fs.open(path)).createReadStream({ encoding: "utf8" });
+    if (minChunkSize < 1 || !Number.isInteger(minChunkSize))
+      throw new Error("minChunkSize must be a positive integer");
+
+    debug("Setting the chunk size to %o bytes", minChunkSize)
+    // Create a stream to read the text file chunk by chunk
+    const stream = createReadStream(path, {
+      encoding: "utf8",
+      highWaterMark: minChunkSize
+    });
+    for await (const chunk of stream) {
+      if (typeof chunk !== 'string')
+        throw new Error('Expected file stream to yield string')

-    // `readline` is a bit overkill but seems standard
-    // https://nodejs.org/api/readline.html#example-read-file-stream-line-by-line
-    yield* readline.createInterface({ input, crlfDelay: Infinity });
+      debug("yield chunk of length: %o", chunk.length);
+      yield chunk
+    }
   });
 }
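
The chunk size is the knob to turn when sequences are long. A quick sketch of reading with a larger minimum chunk, assuming the package re-exports this function as `loadText`; the path and sizes are illustrative:

import { loadText } from "@epfml/discojs-node";

async function preview(path: string): Promise<void> {
  // ask for chunks of at least 64 KiB instead of the 16 KiB default,
  // so that each chunk comfortably covers the expected sequence length
  for await (const chunk of loadText(path, 64 * 1024)) {
    console.log("chunk of", chunk.length, "characters");
    break;
  }
}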

discojs-web/src/loaders.spec.ts

Lines changed: 8 additions & 15 deletions
@@ -1,5 +1,4 @@
 import { describe, it, expect } from "vitest";
-
 import { loadCSV, loadText } from "./loaders/index.js";

 async function arrayFromAsync<T>(iter: AsyncIterable<T>): Promise<T[]> {
@@ -22,22 +21,16 @@ describe("csv parser", () => {
 });

 describe("text parser", () => {
-  it("loads", async () => {
+  it("loads a simple sequence", async () => {
+    const text = ["first", "second", "third"].join("\n")
+
     // jsdom doesn't implement .text on File/Blob
     // trick from https://github.com/jsdom/jsdom/issues/2555
-    const text = await (
-      await fetch(
-        // data URL content need to be url-encoded
-        ["data:,first", "second", "third"].join("%0A"),
-      )
+    const file = await (
+      await fetch("data:," + encodeURIComponent(text))
     ).blob();
-
-    const parsed = loadText(text);
-
-    expect(await arrayFromAsync(parsed)).to.have.ordered.members([
-      "first",
-      "second",
-      "third",
-    ]);
+    const parsed = loadText(file)
+    expect(await parsed.size()).to.equal(1);
+    expect((await arrayFromAsync(parsed))[0]).to.equal(text);
   });
 });

discojs-web/src/loaders/text.ts

Lines changed: 0 additions & 25 deletions
@@ -1,35 +1,10 @@
 import { Dataset, Text } from "@epfml/discojs";

-class LineStream extends TransformStream<string, string> {
-  constructor() {
-    let current_line = "";
-
-    super({
-      transform: (chunk, controller) => {
-        const [head, ...lines] = chunk.split(/\r\n|\r|\n/);
-        const first_line = current_line + head;
-
-        if (lines.length === 0) {
-          current_line = first_line;
-          return;
-        }
-
-        controller.enqueue(first_line);
-        for (const line of lines.slice(0, -1)) controller.enqueue(line);
-
-        current_line = lines[lines.length - 1];
-      },
-      flush: (controller) => controller.enqueue(current_line),
-    });
-  }
-}
-
 export function load(file: Blob): Dataset<Text> {
   return new Dataset(async function* () {
     const reader = file
       .stream()
       .pipeThrough(new TextDecoderStream())
-      .pipeThrough(new LineStream())
       .getReader();

     while (true) {
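
With LineStream gone, the web loader simply forwards whatever chunks TextDecoderStream produces, so a dataset's size is the number of decoded chunks rather than the number of lines. A short usage sketch, assuming the package re-exports this function as `loadText` (the spec above imports it from ./loaders/index.js):

import { loadText } from "@epfml/discojs-web";

// e.g. with a File picked from an <input type="file">
async function countChunks(file: File): Promise<number> {
  return await loadText(file).size(); // number of decoded chunks, no longer the number of lines
}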

discojs/src/dataset/dataset.spec.ts

Lines changed: 55 additions & 1 deletion
@@ -1,6 +1,6 @@
 import { expect } from "chai";
 import { Dataset } from "./dataset.js";
-import { Range } from "immutable";
+import { List, Range } from "immutable";

 // Array.fromAsync not yet widely used (2024)
 async function arrayFromAsync<T>(iter: AsyncIterable<T>): Promise<T[]> {
@@ -139,4 +139,58 @@ describe("dataset", () => {
       [3, 2],
     ]);
   });
+
+  it("batches with overlap", async () => {
+    const dataset = new Dataset([1, 2, 3]);
+
+    const batched = dataset.batchWithOverlap(1);
+
+    expect(
+      (await arrayFromAsync(batched)).map((l) => l.toArray()),
+    ).to.have.deep.ordered.members([[1, 2], [2, 3]]);
+  });
+
+  it("batchWithOverlap yields correct batches", async () => {
+    const expectedTokens = Range(0, 53).toList()
+    const blockSize = 4
+
+    const parsed = new Dataset([expectedTokens])
+      .unbatch()
+      .batchWithOverlap(blockSize)
+
+    // -1 because the last sequence is dropped as there is no next token label
+    const expectedLength = Math.ceil(expectedTokens.size / blockSize) - 1
+    expect(await parsed.size()).to.equal(expectedLength);
+
+    // check the last sequence separately before iterating over the rest
+    let sequences = List(await arrayFromAsync(parsed))
+    // the last sequence also has blockSize + 1 tokens thanks to the overlap
+    expect(sequences.last()?.size).to.equal(blockSize + 1)
+    sequences = sequences.pop()
+    let i = 0
+    for await (const tokens of sequences) {
+      // each sequence has length blockSize + 1 (for the label)
+      expect(tokens.toArray()).to.deep.equal(
+        expectedTokens.slice(i, i + blockSize + 1).toArray()
+      );
+      // but the window should move by blockSize only
+      i += blockSize
+    }
+  })
+
+  it("repeats content infinitely", async () => {
+    const dataset = new Dataset([0, 1, 2]).repeat();
+    const iter = dataset[Symbol.asyncIterator]()
+
+    for (const i of Range(0, 10)) {
+      const e = await iter.next()
+      expect(e.done).to.be.false
+      expect(e.value).to.equal(i % 3)
+    }
+  });
+
+  it("repeats content a fixed number of times", async () => {
+    const dataset = new Dataset([0, 1]).repeat(3);
+    expect([0, 1, 0, 1, 0, 1]).to.deep.equal(await arrayFromAsync(dataset))
+  });
 });

discojs/src/dataset/dataset.ts

Lines changed: 57 additions & 0 deletions
@@ -144,6 +144,46 @@ export class Dataset<T> implements AsyncIterable<T> {
     );
   }

+  /**
+   * Create batches of size `size + 1` which overlap on one element:
+   * the last element of one batch is the same as the first element of the next.
+   * Notes:
+   * - The resulting dataset has a batch size of `size + 1`.
+   * - The last batch is dropped as there is no next element to add.
+   *
+   * This method is tailored to create text sequences where each token's label is the following token.
+   * In order to have a label for the last token of the input sequence, we include the first token
+   * of the next sequence.
+   *
+   * @param size batch size excluding the overlapping element, at least 1
+   * @returns a dataset with batch size `size + 1`
+   */
+  batchWithOverlap(size: number): Dataset<Batched<T>> {
+    if (size <= 0 || !Number.isInteger(size)) throw new Error("invalid size");
+
+    return new Dataset(
+      async function* (this: Dataset<T>) {
+        const iter = this.batch(size)[Symbol.asyncIterator]();
+        // get the first batch
+        const firstRes = await iter.next()
+        if (firstRes.done) return
+        let currentBatch = firstRes.value
+        for (;;) {
+          // get the next batch
+          const res = await iter.next()
+          if (res.done) break;
+          const nextBatch = res.value
+          // get the first element of the next batch
+          const nextFirstElement = nextBatch.first()
+          if (nextFirstElement === undefined) break
+          // yield the current batch with the first element of the next batch
+          yield currentBatch.concat(nextFirstElement);
          currentBatch = nextBatch
+        }
+      }.bind(this),
+    );
+  }
+
   /** Flatten chunks */
   unbatch<U>(this: Dataset<Batched<U>>): Dataset<U> {
     return new Dataset(
@@ -176,6 +216,23 @@ export class Dataset<T> implements AsyncIterable<T> {
     );
   }

+  /**
+   * Repeat the dataset `times` times.
+   * @param times number of times to repeat the dataset; if undefined, the dataset is repeated indefinitely
+   * @returns a dataset repeated `times` times
+   */
+  repeat(times?: number): Dataset<T> {
+    return new Dataset(
+      async function* (this: Dataset<T>) {
+        let loop = 0;
+        do {
+          for await (const e of this) yield e;
+          loop++
+        } while (times === undefined || loop < times)
+      }.bind(this),
+    );
+  }
+
   /** Compute size
    *
    * This is a costly operation as we need to go through the whole Dataset.
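
For reference, a small sketch of how the two new combinators behave, with the values worked out from the implementation above; the import is the one used elsewhere in this commit.

import { Dataset } from "@epfml/discojs";

async function demo(): Promise<void> {
  // batch(3) over 0..7 gives [0,1,2], [3,4,5], [6,7]; batchWithOverlap(3) appends the
  // first element of the following batch to each batch and drops the final one,
  // yielding [0,1,2,3] then [3,4,5,6]
  const windows = new Dataset([0, 1, 2, 3, 4, 5, 6, 7]).batchWithOverlap(3);
  for await (const window of windows) console.log(window.toArray());

  // repeat(2) simply replays the dataset: 0, 1, 0, 1
  for await (const e of new Dataset([0, 1]).repeat(2)) console.log(e);
}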

discojs/src/dataset/types.ts

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@ export type Batched<T> = List<T>;
 export { Image };
 export type Tabular = Partial<Record<string, string>>;
 export type Text = string;
+export type TokenizedText = List<number>;
