
Commit 8cbc96e

discojs*,cli*: rename blockSize and maxSequenceLength to contextLength
1 parent cb806c0 commit 8cbc96e

12 files changed (+41 -41 lines changed)

cli/src/benchmark_gpt.ts

Lines changed: 3 additions & 3 deletions
@@ -69,18 +69,18 @@ async function main(args: Required<CLIArguments>): Promise<void> {
   const config: models.GPTConfig = {
     modelType: modelType as models.GPTConfig['modelType'],
     maxIter: iterationsPerEpoch,
-    blockSize: contextLength,
     lr: 0.0001,
+    contextLength,
   }
 
   // Load the dataset after setting the Task batch size and max sequence length
   // to make sure the dataset is batched and tokenized correctly
   task.trainingInformation.batchSize = batchSize
-  task.trainingInformation.maxSequenceLength = contextLength
+  task.trainingInformation.contextLength = contextLength
   const dataset = loadText('../datasets/wikitext/wiki.train.tokens')
     .map(text => processing.tokenize(tokenizer, text))
     .flat()
-    .batchWithOverlap(config.blockSize)
+    .batchWithOverlap(config.contextLength)
 
   const preprocessedDataset = dataset
     .map((tokens) => [tokens.pop(), tokens.last()] as [List<number>, number])

cli/src/train_gpt.ts

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@ async function main(): Promise<void> {
     maxIter: 50,
     evaluateEvery: 50,
     maxEvalBatches: 10,
-    blockSize: 16,
+    contextLength: 16,
     seed
   }
 
@@ -22,7 +22,7 @@ async function main(): Promise<void> {
   const tokenDataset = new Dataset([data])
     .map((text: string) => processing.tokenize(tokenizer, text))
     .flat()
-    .batchWithOverlap(config.blockSize)
+    .batchWithOverlap(config.contextLength)
     .map((tokens) => [tokens.pop(), tokens.last()] as [List<number>, number])
     .repeat()
     .batch(8);

discojs/src/dataset/dataset.spec.ts

Lines changed: 9 additions & 9 deletions
@@ -152,29 +152,29 @@ describe("dataset", () => {
 
   it("batchWithOverlap yields correct batches", async () => {
     const expectedTokens = Range(0, 53).toList()
-    const blockSize = 4
+    const contextLength = 4
 
     const parsed = new Dataset([expectedTokens])
       .flat()
-      .batchWithOverlap(blockSize)
+      .batchWithOverlap(contextLength)
 
     // -1 because the last sequence is dropped as there is no next token label
-    const expectedLength = Math.ceil(expectedTokens.size / blockSize) - 1
+    const expectedLength = Math.ceil(expectedTokens.size / contextLength) - 1
     expect(await parsed.size()).to.equal(expectedLength);
 
     // exclude the last sequence because it has been padded
     let sequences = List(await arrayFromAsync(parsed))
-    // we expect the last sequence to have blockSize + 1 tokens via padding
-    expect(sequences.last()?.size).to.equal(blockSize + 1)
+    // we expect the last sequence to have contextLength + 1 tokens via padding
+    expect(sequences.last()?.size).to.equal(contextLength + 1)
     sequences = sequences.pop()
     let i = 0
     for await (const tokens of sequences) {
-      // each sequence has length blockSize + 1 (for the label)
+      // each sequence has length contextLength + 1 (for the label)
       expect(tokens.toArray()).to.deep.equal(
-        expectedTokens.slice(i, i + blockSize + 1).toArray()
+        expectedTokens.slice(i, i + contextLength + 1).toArray()
       );
-      // but the window should move by blockSize only
-      i += blockSize
+      // but the window should move by contextLength only
+      i += contextLength
     }
   })
 

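As a rough illustration of the arithmetic this test checks (a standalone sketch, not the library's batchWithOverlap implementation, and it ignores the padding/dropping of the final window described in the test comments): each window carries contextLength + 1 tokens so the extra token can act as the next-token label, while consecutive windows advance by only contextLength.

// Hypothetical helper for illustration only.
function windowsWithOverlap(tokens: number[], contextLength: number): number[][] {
  const windows: number[][] = [];
  // each window holds contextLength + 1 tokens; the start index advances by contextLength
  for (let i = 0; i + contextLength < tokens.length; i += contextLength) {
    windows.push(tokens.slice(i, i + contextLength + 1));
  }
  return windows;
}

// 53 tokens with contextLength = 4 give Math.ceil(53 / 4) - 1 = 13 windows,
// matching the expectedLength computed in the test above.
const windows = windowsWithOverlap(Array.from(Array(53).keys()), 4);
console.log(windows.length); // 13
console.log(windows[0]);     // [0, 1, 2, 3, 4]
console.log(windows[1]);     // [4, 5, 6, 7, 8]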
discojs/src/default_tasks/wikitext.ts

Lines changed: 2 additions & 2 deletions
@@ -35,15 +35,15 @@ export const wikitext: TaskProvider<'text'> = {
         roundDuration: 2,
         batchSize: 8, // If set too high firefox raises a WebGL error
         tokenizer: 'Xenova/gpt2',
-        maxSequenceLength: 64,
+        contextLength: 64,
         tensorBackend: 'gpt'
       }
     }
   },
 
   getModel(): Promise<Model<'text'>> {
     return Promise.resolve(new models.GPT({
-      blockSize: this.getTask().trainingInformation.maxSequenceLength,
+      contextLength: this.getTask().trainingInformation.contextLength,
     }))
   }
 }

discojs/src/models/gpt/config.ts

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@ type GPTModelType =
 
 export interface GPTConfig {
   lr: number
-  blockSize: number
+  contextLength: number
   vocabSize?: number
   modelType: GPTModelType
   name?: string,
@@ -39,7 +39,7 @@ export const DefaultGPTConfig: Required<GPTConfig> = {
   evaluate: true,
   maxEvalBatches: 12,
   evaluateEvery: 100,
-  blockSize: 128,
+  contextLength: 128,
   vocabSize: 50257,
   debug: false,
   dropout: 0.2,

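For reference, a minimal usage sketch after the rename (illustrative only; the import path is assumed and the value is arbitrary): the context window is now set through contextLength and falls back to DefaultGPTConfig.contextLength (128) when omitted, as the GPT constructor in discojs/src/models/gpt/index.ts further down resolves it.

// Assumed import path, for illustration only.
import { models } from '@epfml/discojs'

// Partial config: only the renamed field is overridden; the remaining fields
// fall back to DefaultGPTConfig (e.g. vocabSize 50257).
const gpt = new models.GPT({ contextLength: 64 })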
discojs/src/models/gpt/gpt.spec.ts

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ describe("gpt-tfjs", function () {
     maxIter: 10,
     evaluateEvery: 50,
     maxEvalBatches: 10,
-    blockSize: 8,
+    contextLength: 8,
     seed
   });
   for (let i = 0; i < 5; i++)

discojs/src/models/gpt/index.ts

Lines changed: 4 additions & 4 deletions
@@ -27,7 +27,7 @@ export type GPTSerialization = {
 export class GPT extends Model<"text"> {
   private readonly model: GPTModel;
 
-  readonly #blockSize: number;
+  readonly #contextLength: number;
   readonly #maxBatchCount: number;
   readonly #vocabSize: number;
 
@@ -38,7 +38,7 @@ export class GPT extends Model<"text"> {
     model.compile();
     this.model = model;
 
-    this.#blockSize = partialConfig?.blockSize ?? DefaultGPTConfig.blockSize;
+    this.#contextLength = partialConfig?.contextLength ?? DefaultGPTConfig.contextLength;
     this.#maxBatchCount = partialConfig?.maxIter ?? DefaultGPTConfig.maxIter;
     this.#vocabSize = partialConfig?.vocabSize ?? DefaultGPTConfig.vocabSize;
   }
@@ -157,7 +157,7 @@ export class GPT extends Model<"text"> {
    * Generate the next token after the input sequence.
    * In other words, takes an input tensor of shape (prompt length T) and returns a tensor of shape (T+1)
    *
-   * @param token input tokens of shape (T,). T is truncated to the model's block size
+   * @param token input tokens of shape (T,). T is truncated to the model's context length
    * @param config generation config: temperature, doSample, topk
    * @returns the next token predicted by the model
    */
@@ -166,7 +166,7 @@ export class GPT extends Model<"text"> {
     config: GenerationConfig,
   ): Promise<DataFormat.ModelEncoded["text"][1]> {
     // slice input tokens if longer than context length
-    tokens = tokens.slice(-this.#blockSize);
+    tokens = tokens.slice(-this.#contextLength);
 
     const input = tf.tidy(() =>
       tf.tensor1d(tokens.toArray(), "int32").expandDims<tf.Tensor2D>(0),

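The truncation renamed above relies on negative slicing; a small standalone sketch (with an arbitrary contextLength) of what slice(-contextLength) does to a prompt longer than the context window:

import { List } from 'immutable'

// Keep only the most recent `contextLength` tokens of the prompt.
const contextLength = 4
const prompt = List([10, 11, 12, 13, 14, 15])    // 6 prompt tokens
const truncated = prompt.slice(-contextLength)
console.log(truncated.toArray())                 // [12, 13, 14, 15]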
discojs/src/models/gpt/layers.ts

Lines changed: 7 additions & 7 deletions
@@ -67,7 +67,7 @@ tf.serialization.registerClass(LogLayer)
 
 type CausalSelfAttentionConfig =
   ConstructorParameters<typeof tf.layers.Layer>[0]
-  & Record<'blockSize' | 'nHead' | 'nEmbd' | 'dropout' | 'nLayer' | 'seed', number>
+  & Record<'contextLength' | 'nHead' | 'nEmbd' | 'dropout' | 'nLayer' | 'seed', number>
 
 class CausalSelfAttention extends tf.layers.Layer {
   static readonly className = 'CausalSelfAttention'
@@ -97,7 +97,7 @@ class CausalSelfAttention extends tf.layers.Layer {
     // mask is a lower triangular matrix filled with 1
     // calling bandPart zero out the upper triangular part of the all-ones matrix
     // from the doc: tf.linalg.band_part(input, -1, 0) ==> Lower triangular part
-    this.mask = tf.linalg.bandPart(tf.ones([config.blockSize, config.blockSize]), -1, 0)
+    this.mask = tf.linalg.bandPart(tf.ones([config.contextLength, config.contextLength]), -1, 0)
   }
 
   override build (): void {
@@ -266,15 +266,15 @@ class GELU extends tf.layers.Layer {
 tf.serialization.registerClass(GELU)
 
 type MLPConfig = ConstructorParameters<typeof tf.layers.Layer>[0] &
-  Required<ModelSize> & Record<'blockSize' | 'residDrop' | 'nLayer' | 'seed', number>
+  Required<ModelSize> & Record<'contextLength' | 'residDrop' | 'nLayer' | 'seed', number>
 
 function MLP(config: MLPConfig): tf.LayersModel {
   return tf.sequential({ layers: [
     tf.layers.dense({
       name: config.name + `.mlp.c_fc`,
       units: 4 * config.nEmbd,
       inputDim: config.nEmbd,
-      inputShape: [config.blockSize, config.nEmbd],
+      inputShape: [config.contextLength, config.nEmbd],
       kernelInitializer: tf.initializers.randomNormal({
         mean: 0, stddev: 0.02, seed: config.seed
       }),
@@ -284,7 +284,7 @@ function MLP(config: MLPConfig): tf.LayersModel {
       name: config.name + '.mlp.c_proj',
       units: config.nEmbd,
       inputDim: 4 * config.nEmbd,
-      inputShape: [config.blockSize, 4 * config.nEmbd],
+      inputShape: [config.contextLength, 4 * config.nEmbd],
       kernelInitializer: tf.initializers.randomNormal({
         mean: 0, stddev: 0.02 * Math.sqrt(2 * config.nLayer), seed: config.seed
       }),
@@ -306,7 +306,7 @@ type BlockConfig = CausalSelfAttentionConfig & MLPConfig & { debug: boolean }
  */
 function TransformerBlock (conf: BlockConfig): tf.LayersModel {
   const config = Object.assign({ name: '.h' }, conf)
-  const inputs = tf.input({ shape: [config.blockSize, config.nEmbd] })
+  const inputs = tf.input({ shape: [config.contextLength, config.nEmbd] })
   let x1, x2
   // input normalization
   x1 = tf.layers.layerNormalization({
@@ -469,7 +469,7 @@ export function GPTArchitecture(config: Required<GPTConfig>): tf.LayersModel {
   const range = new Range({}).apply(inputs)
   let posEmb = tf.layers.embedding({
     name: config.name + '.wpe',
-    inputDim: config.blockSize,
+    inputDim: config.contextLength,
     outputDim: config.nEmbd,
     embeddingsInitializer: tf.initializers.randomNormal({
       mean: 0, stddev: 0.02, seed: config.seed

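The mask construction renamed above is unchanged in behaviour; a small standalone sketch (the tfjs import path is assumed) of the lower-triangular causal mask it builds, now sized by contextLength:

import * as tf from '@tensorflow/tfjs'

// bandPart(ones, -1, 0) keeps the lower triangle, so position i can only
// attend to positions <= i within the contextLength window.
const contextLength = 4
const mask = tf.linalg.bandPart(tf.ones([contextLength, contextLength]), -1, 0)
mask.print()
// [[1, 0, 0, 0],
//  [1, 1, 0, 0],
//  [1, 1, 1, 0],
//  [1, 1, 1, 1]]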
discojs/src/processing/index.ts

Lines changed: 4 additions & 4 deletions
@@ -56,12 +56,12 @@ export async function preprocess<D extends DataType>(
       // cast as typescript doesn't reduce generic type
       const d = dataset as Dataset<DataFormat.Raw["text"]>;
       const t = task as Task<"text">;
-      const blockSize = task.trainingInformation.maxSequenceLength
+      const contextLength = task.trainingInformation.contextLength
 
       const tokenizer = await models.getTaskTokenizer(t);
       return d.map(text => processing.tokenize(tokenizer, text))
         .flat()
-        .batchWithOverlap(blockSize)
+        .batchWithOverlap(contextLength)
         .map((tokens) => [tokens.pop(), tokens.last()]) as
         Dataset<DataFormat.ModelEncoded[D]>;
     }
@@ -97,12 +97,12 @@ export async function preprocessWithoutLabel<D extends DataType>(
       // cast as typescript doesn't reduce generic type
       const d = dataset as Dataset<DataFormat.Raw["text"]>;
       const t = task as Task<"text">;
-      const blockSize = task.trainingInformation.maxSequenceLength
+      const contextLength = task.trainingInformation.contextLength
       const tokenizer = await models.getTaskTokenizer(t);
 
       return d.map(text => processing.tokenize(tokenizer, text))
         .flat()
-        .batch(blockSize)
+        .batch(contextLength)
     }
   }
 }

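For context on the mapping that follows batchWithOverlap in preprocess above, a standalone sketch: given a window of contextLength + 1 tokens, pop() returns the window without its final element (the model input) and last() returns that final token (the next-token label).

import { List } from 'immutable'

// A window produced with contextLength = 4: four input tokens plus the label.
const window = List([1, 2, 3, 4, 5])
const input = window.pop()      // List [1, 2, 3, 4]
const label = window.last()     // 5
console.log(input.toArray(), label)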
discojs/src/serialization/model.spec.ts

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ describe('serialization', () => {
     maxIter: 10,
     evaluateEvery: 10,
     maxEvalBatches: 10,
-    blockSize: 8,
+    contextLength: 8,
   }
   const model = new models.GPT(config)
 

0 commit comments