
Commit d040e81

[WIP] Add support for deepseek-ai/Janus-1.3B
1 parent 03f6662 commit d040e81

6 files changed: +324 −6 lines changed

src/configs.js

Lines changed: 4 additions & 3 deletions

@@ -69,6 +69,9 @@ function getNormalizedConfig(config) {
         case 'musicgen':
             init_normalized_config = getNormalizedConfig(config.decoder);
             break;
+        case 'multi_modality':
+            init_normalized_config = getNormalizedConfig(config.language_config);
+            break;
 
         // Decoder-only models
         case 'gpt2':
@@ -216,14 +219,12 @@ function getNormalizedConfig(config) {
  */
 export function getKeyValueShapes(config, {
     prefix = 'past_key_values',
+    batch_size=1,
 } = {}) {
     /** @type {Record<string, number[]>} */
     const decoderFeeds = {};
     const normalized_config = config.normalized_config;
 
-    // TODO support batches (i.e., batch_size > 1)
-    const batch_size = 1;
-
     if (normalized_config.is_encoder_decoder && (
         'num_encoder_heads' in normalized_config && 'num_decoder_heads' in normalized_config
     )) {
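
With the new `batch_size` option, callers can pre-allocate empty past-key-value tensors for batched inputs instead of being fixed to a batch of one. A minimal sketch mirroring the call site in models.js, assuming `config` is a loaded model config whose `normalized_config` has been populated:

// Allocate empty KV-cache tensors for a batch of 4 sequences.
const batch_size = 4;
const shapes = getKeyValueShapes(config, { batch_size });
// For a decoder-only model, each shape is e.g. [4, num_heads, 0, dim_kv],
// so zero-length data is valid for the empty cache.
const decoderFeeds = {};
for (const name in shapes) {
    decoderFeeds[name] = new Tensor('float32', [], shapes[name]);
}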

src/models.js

Lines changed: 183 additions & 3 deletions

@@ -61,7 +61,6 @@ import {
 } from './utils/generic.js';
 
 import {
-    isIntegralNumber,
     mergeArrays,
     pick,
 } from './utils/core.js';
@@ -99,6 +98,7 @@ import {
 
 import {
     cat,
+    full,
     full_like,
     mean,
     ones,
@@ -108,6 +108,7 @@ import {
     Tensor,
     zeros_like,
 } from './utils/tensor.js';
+import { RawImage } from './utils/image.js';
 
 import { dynamic_time_warping, medianFilter } from './utils/maths.js';
 import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
@@ -128,6 +129,7 @@ const MODEL_TYPES = {
     MaskGeneration: 5,
     ImageTextToText: 6,
     Musicgen: 7,
+    MultiModality: 8,
 }
 //////////////////////////////////////////////////
 
@@ -386,7 +388,7 @@ async function sessionRun(session, inputs) {
     } catch (e) {
         // This usually occurs when the inputs are of the wrong type.
         console.error(`An error occurred during model execution: "${e}".`);
-        console.error('Inputs given to model:', checkedInputs);
+        console.error('Inputs given to model:', checkedInputs)
         throw e;
     }
 }
@@ -716,6 +718,52 @@ function image_text_to_text_prepare_inputs_for_generation(self, ...args) {
     }
 }
 
+function multimodality_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
+    const has_past_key_values = !!model_inputs.past_key_values;
+
+    if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) {
+        if (has_past_key_values) {
+            model_inputs.input_ids = cat([
+                model_inputs.input_ids,
+                model_inputs.input_ids,
+            ], 0)
+            // NOTE: attention_mask handled in generation
+        } else {
+            model_inputs.input_ids = cat([
+                model_inputs.input_ids,
+                full_like(model_inputs.input_ids, BigInt(generation_config.pad_token_id)),
+            ], 0);
+            model_inputs.attention_mask = cat([
+                model_inputs.attention_mask,
+                full_like(model_inputs.attention_mask, 0n),
+            ], 0);
+        }
+    }
+
+    if (has_past_key_values || !model_inputs.pixel_values) {
+        model_inputs.pixel_values = full([0, 0, 3, 384, 384], 1.0);
+    }
+
+    if (has_past_key_values) {
+        const num_img_tokens = 0;
+        const num_text_tokens = 1;
+        const has_image = num_img_tokens > 0 ? 1 : 0;
+
+        const batch_size = 1;
+        model_inputs.images_seq_mask = new Tensor(
+            'bool',
+            new Array(num_img_tokens + num_text_tokens).fill(true).fill(false, 0, num_text_tokens),
+            [batch_size, num_img_tokens + num_text_tokens],
+        );
+        model_inputs.images_emb_mask = new Tensor(
+            'bool',
+            new Array(num_img_tokens).fill(!!has_image),
+            [batch_size, 1, num_img_tokens],
+        );
+    }
+    return model_inputs;
+}
+
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
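
The `guidance_scale` branch implements classifier-free guidance by doubling the batch: the first half carries the conditional prompt and the second half an unconditional copy filled with the pad token (with its attention masked out), so a single forward pass yields both sets of logits. A minimal sketch of the batching effect, using the same tensor utilities:

// Sketch: a conditional prompt of length 3 for one sequence.
const input_ids = new Tensor('int64', BigInt64Array.from([100n, 200n, 300n]), [1, 3]);
const pad_token_id = 0; // illustrative value

// Unconditional branch: same shape, every position filled with the pad token.
const uncond_ids = full_like(input_ids, BigInt(pad_token_id));
const cfg_input_ids = cat([input_ids, uncond_ids], 0); // dims: [2, 3]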
@@ -769,6 +817,11 @@ export class PreTrainedModel extends Callable {
                 this._prepare_inputs_for_generation = image_text_to_text_prepare_inputs_for_generation;
                 break;
 
+            case MODEL_TYPES.MultiModality:
+                this.can_generate = true;
+                this._prepare_inputs_for_generation = multimodality_prepare_inputs_for_generation;
+                break;
+
             default:
                 // should be MODEL_TYPES.EncoderOnly
                 this._forward = encoderForward;
@@ -912,6 +965,21 @@ export class PreTrainedModel extends Callable {
             }, options),
         ]);
 
+        } else if (modelType === MODEL_TYPES.MultiModality) {
+            info = await Promise.all([
+                constructSessions(pretrained_model_name_or_path, {
+                    prepare_inputs_embeds: 'prepare_inputs_embeds',
+                    model: 'language_model',
+                    lm_head: 'lm_head',
+                    gen_head: 'gen_head',
+                    gen_img_embeds: 'gen_img_embeds',
+                    image_decode: 'image_decode',
+                }, options),
+                getOptionalConfigs(pretrained_model_name_or_path, {
+                    generation_config: 'generation_config.json',
+                }, options),
+            ]);
+
         } else { // should be MODEL_TYPES.EncoderOnly
             if (modelType !== MODEL_TYPES.EncoderOnly) {
                 console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`)
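
Note that this path loads six separate ONNX subgraphs rather than the usual one or two. Loading might then look like the sketch below; the model id is purely illustrative, since this WIP commit does not name an exported repo:

import { MultiModalityCausalLM } from '@huggingface/transformers';

// Hypothetical model id, for illustration only. The repo is expected to
// provide the six subgraphs mapped above (prepare_inputs_embeds,
// language_model, lm_head, gen_head, gen_img_embeds, image_decode).
const model = await MultiModalityCausalLM.from_pretrained('onnx-community/Janus-1.3B-ONNX');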
@@ -1658,7 +1726,8 @@ export class PreTrainedModel extends Callable {
         const dtype = session?.config?.kv_cache_dtype ?? 'float32';
         const empty = (dtype === 'float16') ? new Uint16Array() : [];
 
-        const shapes = getKeyValueShapes(this.config);
+        const batch_size = decoderFeeds[this.main_input_name].dims[0];
+        const shapes = getKeyValueShapes(this.config, { batch_size });
 
         for (const name in shapes) {
             decoderFeeds[name] = new Tensor(dtype, empty, shapes[name]);
@@ -5954,6 +6023,111 @@ export class DecisionTransformerModel extends DecisionTransformerPreTrainedModel
 
 //////////////////////////////////////////////////
 
+export class MultiModalityPreTrainedModel extends PreTrainedModel { }
+export class MultiModalityCausalLM extends MultiModalityPreTrainedModel {
+    forward_params = [
+        // prepare_inputs_embeds
+        'input_ids',
+        'pixel_values',
+        'images_seq_mask',
+        'images_emb_mask',
+
+        // language_model
+        'attention_mask',
+        'position_ids',
+        'past_key_values',
+    ];
+
+    constructor(...args) {
+        super(...args);
+
+        // State-based approach to switch out which heads to use during generation
+        this._generation_mode = 'text';
+    }
+
+    async forward(model_inputs) {
+        const mode = this._generation_mode ?? 'text';
+
+        // TODO support re-using PKVs for input_ids.dims[1] !== 1
+        // if (model_inputs.past_key_values) {
+        //     // && model_inputs.input_ids.dims[1] === 1
+        // }
+
+        let output_1;
+        if (mode === 'text' || !model_inputs.past_key_values) {
+            const session = this.sessions['prepare_inputs_embeds'];
+            const prep_inputs = pick(model_inputs, session.inputNames);
+            output_1 = await sessionRun(session, prep_inputs);
+        } else {
+            const session = this.sessions['gen_img_embeds'];
+            const prep_inputs = pick({
+                image_ids: model_inputs.input_ids,
+            }, session.inputNames);
+            output_1 = await sessionRun(session, prep_inputs);
+        }
+
+        const input_2 = { ...model_inputs, ...output_1 }
+        const output_2 = await decoderForward(this, input_2);
+
+        const head = this.sessions[
+            mode === 'text'
+                ? 'lm_head'
+                : 'gen_head'
+        ];
+        if (!head) {
+            throw new Error(`Unable to find "${head}" generation head`);
+        }
+
+        const output_3 = await sessionRun(head, pick(output_2, head.inputNames))
+
+        return {
+            ...output_1,
+            ...output_2,
+            ...output_3,
+        };
+    }
+
+    /**
+     * @param {import('./generation/parameters.js').GenerationFunctionParameters} options
+     */
+    async generate(options) {
+        this._generation_mode = 'text';
+        return super.generate(options);
+    }
+
+    /**
+     * @param {import('./generation/parameters.js').GenerationFunctionParameters} options
+     */
+    async generate_images(options) {
+        this._generation_mode = 'image';
+
+        const start_num_tokens = (options.inputs ?? options[this.main_input_name]).dims[1];
+        const all_tokens = await super.generate(options);
+
+        const generated_tokens = (/** @type {Tensor} */(all_tokens)).slice(null, [start_num_tokens, null])
+
+        const image_decode = this.sessions['image_decode'];
+        const { decoded_image } = await sessionRun(image_decode, {
+            generated_tokens,
+        });
+
+        // Equivalent to `np.clip((dec + 1) / 2 * 255, 0, 255)`
+        const clamped = decoded_image
+            .add_(1)
+            .mul_(255 / 2)
+            .clamp_(0, 255)
+            .to('uint8');
+
+        // Return as a list of images
+        const images = [];
+        for (const tensor of clamped) {
+            const img = RawImage.fromTensor(tensor);
+            images.push(img);
+        }
+        return images;
+    }
+}
+
 //////////////////////////////////////////////////
 // AutoModels, used to simplify construction of PreTrainedModels
 // (uses config to instantiate correct class)
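
Usage of the two generation modes might look like the sketch below. The model id, prompt handling, and token count are assumptions for illustration; `generate_images` itself returns the `RawImage` objects produced above:

// Hypothetical sketch - 'onnx-community/Janus-1.3B-ONNX' is an illustrative id,
// and `inputs` is assumed to hold the prompt tensors (input_ids, attention_mask,
// images_seq_mask, images_emb_mask) produced by a matching processor.
const model = await MultiModalityCausalLM.from_pretrained('onnx-community/Janus-1.3B-ONNX');

// Text generation: routes logits through the `lm_head` session.
const output_ids = await model.generate({ ...inputs, max_new_tokens: 128 });

// Image generation: samples image tokens via `gen_head`, then decodes them
// with the `image_decode` session into RawImage objects.
const images = await model.generate_images({
    ...inputs,
    max_new_tokens: 576, // assumption: a 24x24 grid of image tokens for a 384x384 output
    do_sample: true,
});
await images[0].save('generated.png');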
@@ -6232,6 +6406,11 @@ const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([
     ['stablelm', ['StableLmForCausalLM', StableLmForCausalLM]],
 ]);
 
+const MODEL_FOR_MULTIMODALITY_MAPPING_NAMES = new Map([
+    ['multi_modality', ['MultiModalityCausalLM', MultiModalityCausalLM]],
+]);
+
+
 const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([
     ['bert', ['BertForMaskedLM', BertForMaskedLM]],
     ['roformer', ['RoFormerForMaskedLM', RoFormerForMaskedLM]],
@@ -6404,6 +6583,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
     [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
     [MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
     [MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.DecoderOnly],
+    [MODEL_FOR_MULTIMODALITY_MAPPING_NAMES, MODEL_TYPES.MultiModality],
     [MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Vision2Seq],

src/models/image_processors.js

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@ export * from './donut/image_processing_donut.js'
 export * from './dpt/image_processing_dpt.js'
 export * from './efficientnet/image_processing_efficientnet.js'
 export * from './glpn/image_processing_glpn.js'
+export * from './janus/image_processing_janus.js'
 export * from './jina_clip/image_processing_jina_clip.js'
 export * from './mask2former/image_processing_mask2former.js'
 export * from './maskformer/image_processing_maskformer.js'
src/models/janus/image_processing_janus.js

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+
+import {
+    ImageProcessor,
+} from "../../base/image_processors_utils.js";
+
+export class VLMImageProcessor extends ImageProcessor {
+    constructor(config) {
+        super({
+            do_pad: true,
+            pad_size: {
+                width: config.image_size,
+                height: config.image_size,
+            },
+            ...config,
+        });
+        this.constant_values = this.config.background_color.map(x => x * this.rescale_factor)
+    }
+
+    pad_image(pixelData, imgDims, padSize, options) {
+        return super.pad_image(pixelData, imgDims, padSize, {
+            constant_values: this.constant_values,
+            center: true,
+            ...options,
+        });
+    }
+}
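
The processor pads non-square images to a centered `image_size` x `image_size` square, filling the border with the rescaled background color instead of distorting the aspect ratio. A hypothetical usage sketch; the config values are assumptions (Janus uses 384x384 inputs), not values taken from this commit:

import { RawImage } from '@huggingface/transformers';

// Hypothetical config - values are illustrative assumptions.
const processor = new VLMImageProcessor({
    image_size: 384,
    background_color: [127, 127, 127], // 0-255 RGB, scaled to [0, 1] via rescale_factor
    rescale_factor: 1 / 255,
    // ...remaining ImageProcessor options (resizing, normalization, etc.)
});

const image = await RawImage.read('cat.png'); // e.g. a non-square input
const { pixel_values } = await processor(image);
// pixel_values.dims -> [1, 3, 384, 384], with the image centered and the
// remainder filled with the constant background color.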
