Skip to content

Commit 2d89de5

Browse files
committed
[WIP] Add support for Glm
1 parent 2c92943 commit 2d89de5

File tree

3 files changed: +73 −2 lines changed

src/configs.js

Lines changed: 1 addition & 0 deletions
@@ -117,6 +117,7 @@ function getNormalizedConfig(config) {
             break;
         case 'gemma':
         case 'gemma2':
+        case 'glm':
             mapping['num_heads'] = 'num_key_value_heads';
             mapping['num_layers'] = 'num_hidden_layers';
             mapping['dim_kv'] = 'head_dim';

src/models.js

Lines changed: 19 additions & 0 deletions
@@ -4037,6 +4037,23 @@ export class Gemma2Model extends Gemma2PreTrainedModel { }
 export class Gemma2ForCausalLM extends Gemma2PreTrainedModel { }
 //////////////////////////////////////////////////
 
+
+//////////////////////////////////////////////////
+// Glm models
+
+/**
+ * The bare Glm Model outputting raw hidden-states without any specific head on top.
+ */
+export class GlmPreTrainedModel extends PreTrainedModel { }
+/**
+ * The bare Glm Model outputting raw hidden-states without any specific head on top.
+ */
+export class GlmModel extends GlmPreTrainedModel { }
+
+export class GlmForCausalLM extends GlmPreTrainedModel { }
+//////////////////////////////////////////////////
+
+
 //////////////////////////////////////////////////
 export class OpenELMPreTrainedModel extends PreTrainedModel { }
 export class OpenELMModel extends OpenELMPreTrainedModel { }
@@ -6765,6 +6782,7 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
     ['cohere', ['CohereModel', CohereModel]],
     ['gemma', ['GemmaModel', GemmaModel]],
     ['gemma2', ['Gemma2Model', Gemma2Model]],
+    ['glm', ['GlmModel', GlmModel]],
     ['openelm', ['OpenELMModel', OpenELMModel]],
     ['qwen2', ['Qwen2Model', Qwen2Model]],
     ['phi', ['PhiModel', PhiModel]],
@@ -6856,6 +6874,7 @@ const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([
     ['cohere', ['CohereForCausalLM', CohereForCausalLM]],
     ['gemma', ['GemmaForCausalLM', GemmaForCausalLM]],
     ['gemma2', ['Gemma2ForCausalLM', Gemma2ForCausalLM]],
+    ['glm', ['GlmForCausalLM', GlmForCausalLM]],
     ['openelm', ['OpenELMForCausalLM', OpenELMForCausalLM]],
     ['qwen2', ['Qwen2ForCausalLM', Qwen2ForCausalLM]],
     ['phi', ['PhiForCausalLM', PhiForCausalLM]],

tests/tiny_random.test.js

Lines changed: 53 additions & 2 deletions
@@ -10,7 +10,6 @@ import {
   BertTokenizer,
   T5Tokenizer,
   WhisperTokenizer,
-  BartTokenizer,
   MarianTokenizer,
   PreTrainedTokenizer,
   AutoTokenizer,
@@ -29,6 +28,7 @@ import {
   CohereForCausalLM,
   GemmaForCausalLM,
   Gemma2ForCausalLM,
+  GlmForCausalLM,
   OPTForCausalLM,
   GPTNeoXForCausalLM,
   GPTJForCausalLM,
@@ -1366,7 +1366,7 @@ describe("Tiny random models", () => {
     });
   });
 
-  describe("gemma", () => {
+  describe("gemma2", () => {
     describe("Gemma2ForCausalLM", () => {
       const model_id = "hf-internal-testing/tiny-random-Gemma2ForCausalLM";
       /** @type {Gemma2ForCausalLM} */
@@ -1417,6 +1417,57 @@ describe("Tiny random models", () => {
     });
   });
 
+  describe("glm", () => {
+    describe("GlmForCausalLM", () => {
+      const model_id = "hf-internal-testing/tiny-random-GlmForCausalLM";
+      /** @type {GlmForCausalLM} */
+      let model;
+      /** @type {PreTrainedTokenizer} */
+      let tokenizer;
+      beforeAll(async () => {
+        model = await GlmForCausalLM.from_pretrained(model_id, {
+          // TODO move to config
+          ...DEFAULT_MODEL_OPTIONS,
+        });
+        tokenizer = await PreTrainedTokenizer.from_pretrained(model_id);
+        // tokenizer.padding_side = "left";
+      }, MAX_MODEL_LOAD_TIME);
+
+      it(
+        "batch_size=1",
+        async () => {
+          const inputs = tokenizer("hello");
+          const outputs = await model.generate({
+            ...inputs,
+            max_length: 10,
+          });
+          expect(outputs.tolist()).toEqual([[23582n, 5797n, 38238n, 24486n, 36539n, 34489n, 6948n, 34489n, 6948n, 34489n]]);
+        },
+        MAX_TEST_EXECUTION_TIME,
+      );
+
+      it(
+        "batch_size>1",
+        async () => {
+          const inputs = tokenizer(["hello", "hello world"], { padding: true });
+          const outputs = await model.generate({
+            ...inputs,
+            max_length: 10,
+          });
+          expect(outputs.tolist()).toEqual([
+            [59246n, 23582n, 5797n, 38238n, 24486n, 36539n, 34489n, 6948n, 34489n, 6948n],
+            [23582n, 2901n, 39936n, 25036n, 55411n, 10337n, 3424n, 39183n, 30430n, 37285n],
+          ]);
+        },
+        MAX_TEST_EXECUTION_TIME,
+      );
+
+      afterAll(async () => {
+        await model?.dispose();
+      }, MAX_MODEL_DISPOSE_TIME);
+    });
+  });
+
   describe("gpt_neo", () => {
     describe("GPTNeoForCausalLM", () => {
       const model_id = "hf-internal-testing/tiny-random-GPTNeoForCausalLM";

0 commit comments

Comments
 (0)