Skip to content

Commit c89299f

Browse files
authored
Add CLIPTextModel and CLIPVisionModel (#829)
* Add `CLIPTextModel` and `CLIPVisionModel` * Fix jinja2 version for tests
1 parent 3072008 commit c89299f

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

src/models.js

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3099,6 +3099,18 @@ export class CLIPPreTrainedModel extends PreTrainedModel { }
30993099
*/
31003100
export class CLIPModel extends CLIPPreTrainedModel { }
31013101

3102+
/**
3103+
* The text model from CLIP without any head or projection on top.
3104+
*/
3105+
export class CLIPTextModel extends CLIPPreTrainedModel {
3106+
/** @type {PreTrainedModel.from_pretrained} */
3107+
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3108+
// Update default model file name if not provided
3109+
options.model_file_name ??= 'text_model';
3110+
return super.from_pretrained(pretrained_model_name_or_path, options);
3111+
}
3112+
}
3113+
31023114
/**
31033115
* CLIP Text Model with a projection layer on top (a linear layer on top of the pooled output)
31043116
*
@@ -3126,7 +3138,6 @@ export class CLIPModel extends CLIPPreTrainedModel { }
31263138
* ```
31273139
*/
31283140
export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
3129-
31303141
/** @type {PreTrainedModel.from_pretrained} */
31313142
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
31323143
// Update default model file name if not provided
@@ -3135,6 +3146,18 @@ export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
31353146
}
31363147
}
31373148

3149+
/**
3150+
* The vision model from CLIP without any head or projection on top.
3151+
*/
3152+
export class CLIPVisionModel extends CLIPPreTrainedModel {
3153+
/** @type {PreTrainedModel.from_pretrained} */
3154+
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3155+
// Update default model file name if not provided
3156+
options.model_file_name ??= 'vision_model';
3157+
return super.from_pretrained(pretrained_model_name_or_path, options);
3158+
}
3159+
}
3160+
31383161
/**
31393162
* CLIP Vision Model with a projection layer on top (a linear layer on top of the pooled output)
31403163
*

tests/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ sacremoses==0.0.53
33
sentencepiece==0.1.99
44
protobuf==4.24.3
55
rjieba==0.1.11
6+
jinja2==3.1.0

0 commit comments

Comments
 (0)