
Commit a4057ab

HaFred and SamitHuang authored
Janus-Pro Mixed-task SFT Training (mindspore-lab#911)
* fix unnecessary flake8 check about the colon space
* fix typo and refactor readme
* fix mul issues during conversion tests
* fix mul issues during conversion tests
* fix redundancy
* sv3d conversion test
* pdat
* update linting
* rm flake8 change
* update sampling script: rm unnecessary device_id arg
* update sampling script: rm unnecessary device_id arg
* init janus
* update
* add vq16
* init januspro models ph
* vq load ckpt
* init commit 4 siglip
* update siglip
* same
* fix interpolate
* test vq fp32 mre<1e-5
* fix norm, decode image ok in bf16
* update siglip
* siglip testing, mae around 3e-2
* for merge
* same
* add VLChatProcessor and test ok
* add chat process test
* use pil resize
* support loading model.bin directly from `from_pretrained`, vlm under test
* add mlp projector (und. vision aligner)
* vlm under test: LlamaForCausalLM discrepancy with the hf
* add vqa test
* fix conflicts
* get correct GELU setup
* LlamaModel return_dict=True
* Merge from HaFred/mindone/janus
* gen_inf done, w/o kv cache, siglip precision needs to be aligned
* tqdm better vis
* vqa infer runnable, generated answer not complete
* reshape tensor
* fix length and graph mode for VQA, answer is complete and better aligned
* setup static cache, but self(input_embeds) output tokens still not correct
* fix t2v w/o kv cache, gen ok
* fix dim
* housekeeping t2i
* fix temperature=0 bug
* add readme and slightly refactor inference
* Update README.md
* rm file
* add file
* support `use_cache==True` with an explicit init of static cache in tuples, but the speed is even slower than `use_cache==False`
* add gradio and pyproject.toml
* add file for gradio
* update readme and fix gradio
* Update README.md
* small fix
* fix bug in graph mode vqa
* rm formula example
* patching others
* support loading sharded ckpt, undo saving .safetensors
* add vit block test
* rm attrdict and use addict for py3.10
* use ms.Tensor instead of mint.Tensor for compat
* fix vanilla attn in siglip
* use mint GELU in sigLIP
* fix above
* add vit block test
* align the naming of vqa_inference
* update readme
* add throughput calc
* t2i support graph mode w/ dynamic shape input
* make t2i generation print time traces, and support graph mode now with ops.multinomial
* refine multinomial
* add multinomial test
* padded kv cache for graph mode t2i
* auto select multinomial
* fix transformers FA dtype
* fix merge conflict
* fix vision encode precision, formula parsing ok
* revert llama min_dtype to eliminate en/zh mix in vqa
* for t2i llama_model forward, init a static cache outside the comp graph
* fp32, this is for the prev commit
* fix typo
* fix typo
* fix typo
* llamamodel.construct() handling dynamic shape in graph mode
* fix ms2.5 pynative error
* acc infer by mint.topk
* transformers use mint.multinomial for ms2.5
* better precision aligned (mint.nn.Conv2d), requires ms2.5
* update readme
* support t2i graph mode kv cache, bringing a speed boost
* fix an initial negligence
* t2i graph mode only `set_inputs` for no-cache cases, because with cache it's slower to `set_inputs` for dynamic shape
* fix typo and default
* logging the token/s for vqa excluding graph compilation time
* include performance tables
* update the correct performance for the graph mode-compatible code
* revert: no more .bin loading
* update readme, ms_compiler_cache enabled
* lint files
* lint files
* update main page readme
* lint files
* lint files
* linting
* update both inference files for faster model init
* fix linting
* update readme
* merge samit code
* match with new np
* dev loss
* mixed-task training `bs=6` ok, `bs=8` oom
* mixed-task sft training, support slice dataset or weightrandsamp dataset
* mixed-task sft training, support slice dataset or weightrandsamp dataset
* lint
* update comment
* support graph mode mixed-data sft, with `use_value_and_grad==False`
* support graph mode mixed-data sft, with `use_value_and_grad==False`
* rm redundant code in siglip
* rm no grad trunc normal used in siglip
* lint
* graph mode sft launch script
* lint
* lint
* lint
* lint
* rm redundancy
* return precommit
* update readme and set vlm.construct() default as pynative version for any-task training
* fix readme
* update discrepancy
* train mode selection
* update modeling_vlm which already works under graph to further support pynative at the same time
* fix ci
* fix ci
* update for graphmode mixed-task sft
* update req
* update datasets
* update training script to support all testing examples in training.md
* fix lint
* fix lint
* refactor
* update readme, train script for better readability
* update readme, train script for better readability
* doc: put into the right order for doing single task sft graph mode patching

---------

Co-authored-by: SamitHuang <285365963@qq.com>
1 parent 6df5641 commit a4057ab

21 files changed (+858, −256 lines)

examples/janus/README.md

Lines changed: 3 additions & 3 deletions
```diff
@@ -27,12 +27,12 @@
 <!-- 🤗 Online Demo (<a href="https://huggingface.co/spaces/deepseek-ai/Janus-Pro-7B"><b>Janus-Pro-7B</b></a>, <a href="https://huggingface.co/spaces/deepseek-ai/Janus-1.3B"><b>Janus</b></a>, <a href="https://huggingface.co/spaces/deepseek-ai/JanusFlow-1.3B"><b>JanusFlow</b></a>) -->
 </p>

-We provide an efficient MindSpore implementation of [JanusPro](https://github.com/deepseek-ai/Janus). This repository is built on the models and code released by DeepSeek. We are grateful for their exceptional work and generous contribution to open source.
+We provide an efficient MindSpore implementation of [Janus-Pro](https://github.com/deepseek-ai/Janus). This repository is built on the models and code released by DeepSeek. We are grateful for their exceptional work and generous contribution to open source.


 ## News

-**2025.03.12**: We have reproduced the multi-modal training pipelines referring to the JanusPro [paper](https://github.com/deepseek-ai/Janus), see [docs/training.md](docs/training.md).
+**2025.03.12**: We have reproduced the multi-modal training pipelines referring to the Janus-Pro [paper](https://github.com/deepseek-ai/Janus), see [docs/training.md](docs/training.md).

 **2025.02.10**: MindSpore implementation of Janus-Pro is released, supporting both multimodal understanding and visual generation on Ascend NPU.

@@ -51,7 +51,7 @@ Generation with Data and Model Scaling</b></a>

 ## 2. Model Download

-JanusPro is available to the public to support a broader and more diverse range of research within both academic and commercial communities.
+Janus-Pro is available to the public to support a broader and more diverse range of research within both academic and commercial communities.
 Please note that the use of this model is subject to the terms outlined in [License section](#5-license). Commercial usage is
 permitted under these terms.
```

examples/janus/docs/training.md

Lines changed: 64 additions & 9 deletions
````diff
@@ -1,4 +1,4 @@
-# JanusPro Training
+# Janus-Pro Training

 ## Requirements

@@ -21,12 +21,32 @@ huggingface-cli download jasonhuang23/artwork --repo-type dataset --local-dir d
 huggingface-cli download rbojja/medical-vqa --repo-type dataset --local-dir datasets/medical-vqa
 ```

-## Run Training
+Before launching SFT training with the scripts under [../scripts/](../scripts/), set the env vars `YOUR_DATA_PATH` and `YOUR_DOWNLOADED_JANUS_CKPT_PATH` in each script.
+
+## Run Training for Single Task
+After setting up paths as above, you are good to go.
+
+- Multimodal Understanding Task (VQA)
+
+```shell
+bash scripts/run_sft_vqa.sh
+```

 - Text Generation Task

 ```shell
-bash scripts/run_sft_text.sh
+bash scripts/run_sft_text.sh # if no manual patching, by default it should be changed into pynative
+```
+
+Patching `janus/models/modeling_vlm.py`: **Single task for pure text**
+```diff
+# @ L428
+-- def construct(
+++ # def construct( # just comment the whole function out
+
+# @ L476
+-- def construct_graph_single_task(
+++ def construct(
 ```

 - Text-to-Image Generation Task (T2I)
@@ -35,22 +55,50 @@ bash scripts/run_sft_text.sh
 bash scripts/run_sft_t2i.sh
 ```

-- Multimodal Understanding Task (VQA)
+The default training stage is stage 3, that is, all modules are trainable except for VQ16 for image token decoding. To switch to another stage, modify the `--stage` argument in the training script.
+
+For more detailed arguments, please run `python train.py -h`.
+
+### Multi-task Supervised Fine-tuning (Mixed-SFT)

 ```shell
-bash scripts/run_sft_vqa.sh
+bash scripts/run_sft_mixed_graph.sh
 ```

-The default training stage is stage 3, that is, all modules are trainable except for VQ16 for image token decoding. To switch to other stage, you can modify the `--stage` argument in the training script.
+We also implemented **a stage-3 SFT on medical data, aiming to build a radiology expert model**. The datasets can be retrieved from the following HuggingFace repos.

-For more detailed arguments, please run `python train.py -h`.
+| | #Data Samples | HuggingFace Source |
+| --- | --- | --- |
+| VQA | 100 | rbojja/medical-vqa |
+| pure-text | 20 | qiaojin/PubMedQA |
+| T2I | 80 | mdwiratathya/ROCO-radiology |

+#### Graph Mode SFT Training for Mixed Tasks

-- Multi-task Fune-tuning
+> [!NOTE]
+> We achieve higher training throughput by enabling graph mode compute. However, to do that we need to predefine a compute graph of the vlm for each of the three tasks, since each task feeds the vlm a different set of input args.
+>
+> To run `scripts/run_sft_mixed_graph.sh`, simply go into `janus/models/modeling_vlm.py`, and patch `construct_*()` into `construct()` as follows.
+```diff
+# @ L428
+-- def construct(
+++ # def construct( # just comment the whole function out

-Comming soon
+# @ L570
+-- def construct_graph_mixed_task(
+++ def construct(
+```

+#### Pynative Mode SFT Training for Mixed Tasks
+```diff
+# @ L428
+-- def construct(
+++ # def construct( # just comment the whole function out

+# @ L516
+-- def construct_pynative_mixed_task(
+++ def construct(
+```

 ## Performance

@@ -64,3 +112,10 @@ Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pyn
 | Janus-Pro-7B | T2I | 1 | 384x384 | 1024 | 1 | 0.49 |
 | Janus-Pro-7B | VQA | 1 | 384x384 | 1024 | 1 | 0.66 |
 | Janus-Pro-7B | Text | 1 | n.a. | 512 | 1 | 0.53 |
+
+For mixed-SFT:
+
+| model | task | ms_mode | # card(s) | image size | max_length | batch size | step time (s/step) |
+|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
+| Janus-Pro-1B | mixed | pynative | 1 | 384x384 | 1024 | 6 | 3.05 |
+| Janus-Pro-1B | mixed | graph | 1 | 384x384 | 1024 | 6 | 2.36 |
````
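The manual patching described above amounts to choosing which `construct_*` variant owns the name `construct` before the cell is compiled, since graph mode traces a single `construct()` and cannot branch on a tensor-valued `task_type` at run time. A toy sketch of the same selection done by attribute rebinding (class and method names here are illustrative, not the repo's API):

```python
# Why the patch works: MindSpore graph mode compiles `construct()` into one
# static graph, so a tensor-valued `task_type` cannot drive Python if/elif
# branches at run time. The repo therefore ships one `construct_*` variant per
# task and asks you to rename the one you need to `construct`. Rebinding the
# class attribute has the same effect as the manual edit.

class VLMSketch:  # stands in for the real vlm cell
    def construct(self, task_type, **inputs):
        raise NotImplementedError("patch a task-specific variant over me")

    def construct_graph_single_task(self, task_type, **inputs):
        return "text-only loss"

    def construct_graph_mixed_task(self, task_type, **inputs):
        return "vqa+text+t2i loss"

# Equivalent of the `# @ L428` / `# @ L570` patch in training.md:
VLMSketch.construct = VLMSketch.construct_graph_mixed_task

model = VLMSketch()
print(model.construct(task_type=None))  # -> vqa+text+t2i loss
```

Only one variant can own the name `construct` at a time, which is why the original `construct()` must be commented out (or shadowed, as here) before a different variant is promoted.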

examples/janus/janus/__init__.py

Whitespace-only changes.

examples/janus/janus/models/modeling_vlm.py

Lines changed: 146 additions & 13 deletions
```diff
@@ -289,7 +289,7 @@ def gen_with_loss(
             attention_mask: shape (bs seq_len), where 1 for valid input seq, 0 for padded seq
             image_seq_mask: 1 - image tokens (exclude BOI and EOI)
             pixel_values: images resized to (384, 384), shape (bs n_images 3 h w)
-            image_tokens: image tokens encoded and quantized by VQ16, shape (bs n_images per_img_seq_len)
+            image_tokens: deprecated, image tokens encoded and quantized by VQ16, shape (bs n_images per_img_seq_len)

         Note: pre-compute VQ encoded tokens for efficiency
         """
@@ -321,12 +321,14 @@
         # these reshape ops is to solve the wierd error in InferShape in MS
         inputs_embeds = inputs_embeds.reshape(-1, D)  # (B, S, D) -> (B * S, D)
         image_seq_mask = image_seq_mask.reshape(-1)  # (B, S) -> (B * S)
-        image_embeds = image_embeds.reshape(-1, D)  # (B, S, D) -> (B * S, D)
+        image_embeds = image_embeds.reshape(-1, D)  # (B, T, D) -> (B * T, D)

-        # another way: inputs_embeds = inputs_embeds * (1 - image_seq_mask) + ops.stop_gradient(image_embeds) * image_seq_mask.to(ms.int)
-        # FIXME: this inplace op doens't support in graph mode
-        # FIXME: check whether need to bprop the graident from image_embedding to LlamModel.embed_tokens (nn.Embedding)
-        inputs_embeds[image_seq_mask] = ops.stop_gradient(image_embeds)
+        # FIXME ms2.5.0 graph mode does not support _tensor_setitem_by_bool_tensor_with_tensor().
+        # Workaround: _tensor_setitem_by_int_tensor_with_tensor()
+        _image_seq_mask = image_seq_mask.nonzero().squeeze()
+        # the tensor.squeeze() above does not work under pynative, reason unknown...
+        # _image_seq_mask = image_seq_mask.nonzero().reshape(-1)  # workaround for both pynative & graph: force flatten
+        inputs_embeds[_image_seq_mask] = image_embeds

         inputs_embeds = inputs_embeds.reshape(B, S, D)
         image_seq_mask = image_seq_mask.reshape(B, S)
@@ -342,7 +344,7 @@
         # 4. gen head projection
         # since Janus use decouple heads for image and text, only image seq is meaningful input to gen head. mask before linear should save compute cost.
         # TODO: tbc influence on gradient ?
-        image_hidden_states = hidden_states[image_seq_mask].reshape(B, -1, D)
+        image_hidden_states = hidden_states[image_seq_mask].reshape(B, T, D)
         logits = self.gen_head(image_hidden_states)

         # 5. loss compute
@@ -404,13 +406,13 @@
         # these reshape ops is to solve the wierd error in InferShape in MS
         inputs_embeds = inputs_embeds.reshape(-1, D)  # (B, S, D) -> (B * S, D)
         image_seq_mask = image_seq_mask.reshape(-1)  # (B, S) -> (B * S)
-        image_embeds = image_embeds.reshape(-1, D)  # (B, S, D) -> (B * S, D)
+        image_embeds = image_embeds.reshape(-1, D)  # (B, T, D) -> (B * T, D)

-        # FIXME: fix as gen_with_loss to support graph mode
-        inputs_embeds[image_seq_mask] = image_embeds  # ops.stop_gradient(image_embeds)
+        # FIXME same workaround as above, for the ms2.5.0 graph mode constraint
+        image_seq_mask = image_seq_mask.nonzero().squeeze()
+        inputs_embeds[image_seq_mask] = image_embeds

         inputs_embeds = inputs_embeds.reshape(B, S, D)
-        image_seq_mask = image_seq_mask.reshape(B, S)

         # 3. LlamaForCausalLM forward with loss
         output = self.language_model(
@@ -420,7 +422,6 @@
             return_dict=False,
         )
         loss = output[0]
-        # logit = output[1]

         return loss

```
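The boolean-to-integer indexing workaround in the hunks above can be demonstrated outside MindSpore; a minimal NumPy sketch with made-up shapes (note that NumPy's `nonzero()` returns a tuple of index arrays, whereas MindSpore's returns an (N, 1) tensor for a 1-D mask, hence the `.squeeze()` in the diff):

```python
import numpy as np

# Instead of boolean-mask assignment (rejected by MindSpore 2.5.0 graph mode),
# convert the mask to integer row indices with nonzero() and assign through
# integer indexing, which lowers to the supported setitem kernel.
B, S, D = 2, 4, 3
inputs_embeds = np.zeros((B * S, D))                             # flattened (B, S, D) -> (B*S, D)
image_seq_mask = np.array([0, 1, 1, 0, 1, 0, 0, 0], dtype=bool)  # (B*S,), 1 marks image tokens
image_embeds = np.ones((int(image_seq_mask.sum()), D))           # (T_total, D)

# boolean route (what graph mode rejects): inputs_embeds[image_seq_mask] = image_embeds
# integer route (the workaround):
idx = image_seq_mask.nonzero()[0]                                # -> array([1, 2, 4])
inputs_embeds[idx] = image_embeds

# both routes scatter the image embeddings into exactly the masked rows
assert (inputs_embeds[image_seq_mask] == 1).all()
assert (inputs_embeds[~image_seq_mask] == 0).all()
```

The two routes write the same rows; only the index dtype handed to the setitem op differs, which is what the graph-mode constraint cares about.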
```diff
@@ -435,7 +436,7 @@
         image_tokens: Optional[Tensor] = None,
     ):
         r"""
-        Added for training, and only used in training!
+        Implemented for single-task pynative training. Supports branch control for a SINGLE task in task_type.
         Args:
             input_ids: input sequence of tokens, shape (bs seq_len). see transformers docstring for details
             task_type: shape (bs,), 0 - pure text, 1 - vqa, 2 - t2i
@@ -472,6 +473,138 @@

         return loss

+    def construct_graph_single_task(
+        self,
+        task_type: Tensor = None,
+        input_ids: Tensor = None,
+        labels: Optional[Tensor] = None,
+        attention_mask: Optional[Tensor] = None,
+        image_seq_mask: Optional[Tensor] = None,
+        pixel_values: Optional[Tensor] = None,
+        image_tokens: Optional[Tensor] = None,
+    ):
+        """
+        Implemented for single-task graph mode sft.
+        As a task_type tensor cannot be used for branch control, this method implements a per-task forward.
+        """
+
+        # text
+        loss = self.language_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            labels=labels,
+        )[0]
+        # # vqa
+        # loss = self.und_with_loss(
+        #     input_ids=input_ids,
+        #     attention_mask=attention_mask,
+        #     labels=labels,
+        #     image_seq_mask=image_seq_mask,
+        #     pixel_values=pixel_values,
+        # )
+        # # t2i
+        # loss = self.gen_with_loss(
+        #     input_ids=input_ids,
+        #     attention_mask=attention_mask,
+        #     image_seq_mask=image_seq_mask,
+        #     pixel_values=pixel_values,
+        #     image_tokens=image_tokens,
+        #     # labels,
+        # )
+        return loss
+
+    def construct_pynative_mixed_task(
+        self,
+        task_type: Tensor = None,
+        input_ids: Tensor = None,
+        labels: Optional[Tensor] = None,
+        attention_mask: Optional[Tensor] = None,
+        image_seq_mask: Optional[Tensor] = None,
+        pixel_values: Optional[Tensor] = None,
+        image_tokens: Optional[Tensor] = None,
+    ):
+        """Implemented for mixed-task pynative mode sft. Supports branch control for mixed tasks. Go with this if you need MULTIPLE-task sft."""
+
+        losses = []
+        for ti, task in enumerate(task_type):
+            _input_ids = input_ids[ti][None, ...]
+            _labels = labels[ti][None, ...]
+            _attention_mask = attention_mask[ti][None, ...]
+            _image_seq_mask = image_seq_mask[ti][None, ...]
+            _pixel_values = pixel_values[ti][None, ...]
+            if task == 0:
+                # mm understand
+                loss = self.und_with_loss(
+                    input_ids=_input_ids,
+                    attention_mask=_attention_mask,
+                    labels=_labels,
+                    image_seq_mask=_image_seq_mask,
+                    pixel_values=_pixel_values,
+                )
+            elif task == 1:
+                # text
+                loss = self.language_model(
+                    input_ids=_input_ids,
+                    attention_mask=_attention_mask,
+                    labels=_labels,
+                )[0]
+            elif task == 2:
+                # t2i
+                loss = self.gen_with_loss(
+                    input_ids=_input_ids,
+                    attention_mask=_attention_mask,
+                    image_seq_mask=_image_seq_mask,
+                    pixel_values=_pixel_values,
+                    # image_tokens=image_tokens,
+                    # labels,
+                )
+            else:
+                raise ValueError(f"task type should be one of [0, 1, 2], but got {task_type}")
+
+            losses.append(loss)
+
+        loss = mint.mean(mint.stack(losses))
+
+        return loss
+
+    def construct_graph_mixed_task(
+        self,
+        task_type: Tensor = None,
+        input_ids: Tensor = None,
+        labels: Optional[Tensor] = None,
+        attention_mask: Optional[Tensor] = None,
+        image_seq_mask: Optional[Tensor] = None,
+        pixel_values: Optional[Tensor] = None,
+    ):
+        """Implemented for mixed-task graph mode sft. Supports branch control for mixed tasks under graph mode."""
+
+        is_vqa_index = (task_type == 0).nonzero().squeeze(-1)
+        loss_vqa = self.und_with_loss(
+            input_ids=input_ids[is_vqa_index],
+            attention_mask=attention_mask[is_vqa_index],
+            labels=labels[is_vqa_index],
+            image_seq_mask=image_seq_mask[is_vqa_index],
+            pixel_values=pixel_values[is_vqa_index],
+        )
+
+        is_text_index = (task_type == 1).nonzero().squeeze(-1)
+        loss_text = self.language_model(
+            input_ids=input_ids[is_text_index],
+            attention_mask=attention_mask[is_text_index],
+            labels=labels[is_text_index],
+        )[0]
+
+        is_t2i_index = (task_type == 2).nonzero().squeeze(-1)
+        loss_t2i = self.gen_with_loss(
+            input_ids=input_ids[is_t2i_index],
+            attention_mask=attention_mask[is_t2i_index],
+            image_seq_mask=image_seq_mask[is_t2i_index],
+            pixel_values=pixel_values[is_t2i_index],
+        )
+
+        loss = (loss_vqa + loss_text + loss_t2i) / 3
+        return loss
+

 AutoConfig.register("vision", VisionConfig)
 AutoConfig.register("aligner", AlignerConfig)
```
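One consequence of the two mixed-task variants added above: the pynative loop averages the loss over samples, while the graph variant averages the three per-task mean losses, so in an unbalanced batch the two weight tasks differently. A small NumPy check with made-up per-sample loss values:

```python
import numpy as np

# One mixed batch: 4 VQA samples, 1 text, 1 T2I (loss values are made up).
task = np.array([0, 0, 0, 0, 1, 2])
loss = np.array([1.0, 1.0, 1.0, 1.0, 4.0, 7.0])

# pynative variant: mean over samples, like mint.mean(mint.stack(losses))
per_sample = loss.mean()
# graph variant: mean of the three per-task means, like (loss_vqa + loss_text + loss_t2i) / 3
per_task = np.mean([loss[task == t].mean() for t in (0, 1, 2)])

print(per_sample)  # 2.5
print(per_task)    # 4.0 -> rarer tasks carry proportionally more weight
```

The two coincide only when each task contributes the same number of samples, which a weighted random sampler (as in the mixed-task dataset options mentioned in the commit list) can approximate but not guarantee per batch.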

0 commit comments
