
Commit 8130a3f

Authored by: jilongW, ZePan110, ramesh-katkuri (Ramesh Katkuri), pre-commit-ci[bot]
[Xtune]Update CNclip & Qwen2.5-VL doc (#1958)
* Fix build issues (#1937)
  * Fix build issues Add docling in requirements.in Change pathway version to fix dependency conflict. Signed-off-by: ZePan110 <ze.pan@intel.com>
  ---------
  Signed-off-by: ZePan110 <ze.pan@intel.com>
  Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
* Add Arbitration Post-Hearing Component with LLM-Based Entity Extraction (#1938)
  * initial commit for arbitratory micro service
  * test cases added
  * test cases added
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * renamed test file as per review comment
  * updated path of SCRIPT_DIR in to resolve microservice test
  * resolved comments wrt license header and env configs in compose file
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * removed unused MODEL env variable Signed-off-by: Author Name <c.noeljaymon@zensar.com>
  * removed space and added sign off Signed-off-by: Noel Jaymon <c.noeljaymon@zensar.com>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * resolved ci issues for compose file name and readme reference paths Signed-off-by: Noel Jaymon <c.noeljaymon@zensar.com>
  * added arb_post_hearing_assistant-compose.yaml file in .github folder Signed-off-by: Noel Jaymon <c.noeljaymon@zensar.com>
  * deleted unused file redis-values.yaml Signed-off-by: Noel Jaymon <c.noeljaymon@zensar.com>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * made service and image name same for arb_post_hearing_assistant-compose.yaml in .github folder Signed-off-by: Noel Jaymon <c.noeljaymon@zensar.com>
  * DCO remediation: adding missing Signed-off-by lines Signed-off-by: Ramesh <katkuri.ramesh@zensar.com>
  * fixed the micro service build issue Signed-off-by: Ramesh <katkuri.ramesh@zensar.com>
  * microservice container not found fixed Signed-off-by: Ramesh <katkuri.ramesh@zensar.com>
  * removed airgap code Signed-off-by: Ramesh <katkuri.ramesh@zensar.com>
  * fixed pre-commit check issue
  ---------
  Signed-off-by: Noel Jaymon <c.noeljaymon@zensar.com>
  Signed-off-by: Ramesh <katkuri.ramesh@zensar.com>
  Co-authored-by: Ramesh Katkuri <rameshkatkuri@Rameshs-MacBook-Air.local>
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  Co-authored-by: Noel Jaymon <c.noeljaymon@zensar.com>
  Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
* add funasr paraformer asr service impl (#1914)
  * add funasr paraformer asr service impl
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * fix requirements deps; modify ASR READMEs
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * add funasr-paraformer dockerfile in github workflow
  ---------
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
* enable HF_TOKEN to be defined in request (#1940) Signed-off-by: wwanarif <wan.abdul.hakim.b.wan.arif@intel.com> Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
* fix the source of LibreOffice (#1942) Signed-off-by: zhihang <zhihangdeng@link.cuhk.edu.cn> Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
* Fix CD issue and llms build failure (#1943)
  * Fix permissions issue Signed-off-by: ZePan110 <ze.pan@intel.com>
  * Fix issue Signed-off-by: ZePan110 <ze.pan@intel.com>
  * Test Signed-off-by: ZePan110 <ze.pan@intel.com>
  * Revert "Test" This reverts commit 1df372c.
  ---------
  Signed-off-by: ZePan110 <ze.pan@intel.com>
  Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
* update vllm-ipex, boost servi performance (#1941) Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
* Add openGauss support to dataprep microservice and update related doc… (#1945)
  * add openGauss support for dataprep Signed-off-by: sunshuang1866 <sunshuang1866@outlook.com>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * fix the healthcheck for openGauss Signed-off-by: sunshuang1866 <sunshuang1866@outlook.com>
  * update README.md for openGauss Signed-off-by: sunshuang1866 <sunshuang1866@outlook.com>
  ---------
  Signed-off-by: sunshuang1866 <sunshuang1866@outlook.com>
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
* Add openGauss support to retrievers and update related doc… (#1949)
  * add openGauss support for retrievers Signed-off-by: sunshuang1866 <sunshuang1866@outlook.com>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  ---------
  Signed-off-by: sunshuang1866 <sunshuang1866@outlook.com>
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
* update package version to fit B60 Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
* enable CnClip B/16&L/14 and enable flickr30kcn dataset Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
* update docs about qwen2-vl &qwen2.5-vl & cnclip, add qwen-vl configs Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* fix bug Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
* fix bug Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
* fix for doc Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
* fix gradio verion Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
* retrigger test Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
* added test for cnclip Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
---------
Signed-off-by: ZePan110 <ze.pan@intel.com>
Signed-off-by: sunzhonghua2004 <jilong.wang@intel.com>
Signed-off-by: Noel Jaymon <c.noeljaymon@zensar.com>
Signed-off-by: Ramesh <katkuri.ramesh@zensar.com>
Signed-off-by: wwanarif <wan.abdul.hakim.b.wan.arif@intel.com>
Signed-off-by: zhihang <zhihangdeng@link.cuhk.edu.cn>
Signed-off-by: sunshuang1866 <sunshuang1866@outlook.com>
Co-authored-by: ZePan110 <ze.pan@intel.com>
Co-authored-by: ramesh-katkuri <katkuri.ramesh@zensar.com>
Co-authored-by: Ramesh Katkuri <rameshkatkuri@Rameshs-MacBook-Air.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Noel Jaymon <c.noeljaymon@zensar.com>
Co-authored-by: LIU Lin <107393642+llin60@users.noreply.github.com>
Co-authored-by: wanhakim <wanhakim92@gmail.com>
Co-authored-by: zhihang <zhihangdeng@link.cuhk.edu.cn>
Co-authored-by: linjiaojiao <jiaojiao.lin@intel.com>
Co-authored-by: sunshuang1866 <sunshuang1866@outlook.com>
Co-authored-by: Xueshu Wang <xueshu.wang@intel.com>
1 parent 3a50559 commit 8130a3f

18 files changed: +860 / -423 lines


comps/finetuning/src/integrations/xtune/README.md

Lines changed: 144 additions & 20 deletions
@@ -4,11 +4,11 @@
 
 > [!NOTE]
 >
-> - _`Xtune`_ incorporates with Llama-Factory to offer various methods for finetuning visual models (CLIP, AdaCLIP), LLM and Multi-modal models​. It makes easier to choose the method and to set fine-tuning parameters.
+> - _`Xtune`_ incorporates with Llama-Factory to offer various methods for finetuning visual models (CLIP, CnCLIP, AdaCLIP), LLM and Multi-modal models​. It makes easier to choose the method and to set fine-tuning parameters.
 
 The core features include:
 
-- Four finetune method for CLIP, details in [CLIP](./doc/key_features_for_clip_finetune_tool.md)
+- Four finetune method for CLIP & CnCLIP, details in [CLIP](./doc/key_features_for_clip_finetune_tool.md)
 - Three finetune method for AdaCLIP, details in [AdaCLIP](./doc/adaclip_readme.md)
 - Automatic hyperparameter searching enabled by Optuna [Optuna](https://github.com/optuna/optuna)
 - Distillation from large models with Intel ARC GPU​
@@ -59,8 +59,8 @@ Blow command is in prepare_xtune.sh. You can ignore it if you don't want to upda
 conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
 # else run on A770
 # You can refer to https://github.com/intel/intel-extension-for-pytorch for latest command to update lib
-python -m pip install torch==2.5.1+cxx11.abi torchvision==0.20.1+cxx11.abi torchaudio==2.5.1+cxx11.abi --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-python -m pip install intel-extension-for-pytorch==2.5.10+xpu oneccl_bind_pt==2.5.0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+python -m pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/xpu
+python -m pip install intel-extension-for-pytorch==2.8.10+xpu oneccl_bind_pt==2.8.0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 ```
 
 ### 2. Install xtune on docker
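As a quick sanity check that the updated wheels above installed correctly, a minimal sketch (assuming the standard `torch.xpu` and IPEX import interfaces; not part of this commit):

```python
# Hedged post-install smoke test for the XPU stack (not part of this diff).
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  # confirms the IPEX XPU wheel imports cleanly

print(torch.__version__)         # expected to report the 2.8.0 XPU build installed above
print(torch.xpu.is_available())  # True when the Arc GPU and driver stack are visible
print(torch.xpu.device_count())  # number of visible XPU devices
```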
@@ -106,20 +106,34 @@ then make `dataset_info.json` in your dataset directory
 {
   "caltech101": {
     "file_name": "caltech101.json"
+  },
+  "flickr30k": {
+    "file_name": "flickr30k.json"
   }
 }
 ```
 
-## Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
-
-> [!NOTE] We don't support multi-card in GUI now, will add it later.
+The directory structure should look like
 
-When run with prepare_xtune.sh, it will automatic run ZE_AFFINITY_MASK=0 llamafactory-cli webui.
+```
+$DATA/
+|-- caltech-101/
+| |-- 101_ObjectCategories/
+| | split_zhou_Caltech101.json
+|-- flickr/
+| |–– flickr30k-images/
+| | |-- *.jpg
+| |-- train_texts.jsonl
+| |-- val_texts.jsonl
+| |-- test_texts.jsonl
+|-- dataset_info.json
+|-- caltech101.json
+|-- flickr30k.json
+```
 
-If you see "server start successfully" in terminal.
-You can access in web through http://localhost:7860/
+## Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
 
-The UI component information can be seen in doc/ui_component.md after run with prepare_xtune.sh.
+> [!NOTE] We don't support multi-card in GUI now, will add it later.
 
 When run with prepare_xtune.sh, it will automatic run ZE_AFFINITY_MASK=0 llamafactory-cli webui.
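The `flickr30k` entry added above can also be appended to an existing `dataset_info.json` programmatically; a minimal sketch under the layout shown above (an illustration, not a script from this repo):

```python
# Minimal sketch: register the flickr30k entry shown above in $DATA/dataset_info.json.
import json
import os

data_dir = os.environ.get("DATA", ".")  # the directory shown as $DATA/ above
info_path = os.path.join(data_dir, "dataset_info.json")

with open(info_path, encoding="utf-8") as f:
    info = json.load(f)

info["flickr30k"] = {"file_name": "flickr30k.json"}  # mirrors the added JSON entry

with open(info_path, "w", encoding="utf-8") as f:
    json.dump(info, f, indent=2, ensure_ascii=False)
```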

@@ -137,6 +151,58 @@ The UI component information can be seen in doc/ui_component.md after run with p
 Then access in web through http://localhost:7860/
 ```
 
+### GUI using guide
+
+#### CLIP & CnCLIP
+
+![clip ui guide](./pics/clip_ui.png)
+
+- Must be set to the specified parameter values below:
+| Parameter | Choose Value |
+| ------------------- | -------------------------------------------- |
+| `Model name` | `CnVit-B/16` / `CnVit-L/14` /`Vit-B/16` /`Vit-L/14` |
+| `Model path` | Must be the detail configuration name under `src/llamafactory/clip_finetune/configs/trainers/clip_finetune/`|
+| `Finetuning method` | clip |
+| `Stage` | clip|
+| `Data dir` | Where you put `dataset_info.json`.|
+| `Method Group` |Finetune|
+| `clip_finetune method` | `CLIP_Adapter_hf`/ `CLIP_Bias_hf`/ `CLIP_VPT_hf` /`CLIP_Fullfinetune_hf`, must match with `Model name`(configuration name).|
+
+- The matching relationship between `Model name`(configuration name) and `clip_finetune method`:
+
+| clip_finetune method | `Model name`(configuration name) |
+| -------------------- | ------------------------------------- |
+| CLIP_Adapter_hf | xx_xx(e.g.,`cnvit_b16`) |
+| CLIP_Bias_hf | xx_xx_bias(e.g.,`cnvit_b16_bias`) |
+| CLIP_VPT_hf | xx_xx_prompt(e.g.,`cnvit_b16_prompt`) |
+| CLIP_Fullfinetune_hf | xx_xx_ori(e.g.,`cnvit_b16_ori`) |
+
+#### AdaCLIP
+
+![adaclip ui guide](./pics/adaclip_ui.png)
+
+- Must be set to the specified parameter values below:
+| Parameter | Choose Value |
+| ------------------- | -------------------------------------------- |
+| `Model name` | Custom |
+| `Model path` | Adaclip model path|
+| `Finetuning method` | Adaclip|
+| `Stage` | Adaclip|
+| `Data dir` | Where you put `dataset_info.json`|
+
+#### Qwen2-VL & Qwen2.5-VL
+
+![qwen-vl ui guide](./pics/qwen_vl_ui.png)
+
+- Must be set to the specified parameter values below:
+| Parameter | Choose Value |
+| ------------------- | -------------------------------------------- |
+| `Model name` | Select Qwen2-VL or Qwen2.5-VL model |
+| `Model path` | Will be set automatically after setting Model name, you can use your local model path,too.|
+| `Finetuning method` | lora|
+| `Stage` | Supervised Fine-Tuning|
+| `Data dir` | Where you put `dataset_info.json`, can use `data` as default, and update your own data in `data/dataset_info.json`|
+
 ## Fine-Tuning with Shell instead of GUI
 
 After run `prepare_xtune.sh`, it will download all related file. And open webui as default.
@@ -154,6 +220,15 @@ cd src/llamafactory/clip_finetune
 # Please see README.md in src/llamafactory/clip_finetune for detail
 ```
 
+### CnCLIP
+
+Please see [doc](./doc/key_features_for_clip_finetune_tool.md) for how to config feature
+
+```bash
+cd src/llamafactory/clip_finetune
+# Please see README.md in src/llamafactory/clip_finetune for detail
+```
+
 ### AdaCLIP
 
 ```bash
@@ -164,22 +239,24 @@ cd src/llamafactory/adaclip_finetune
 ### Qwen2-VL Training and Hyperparameter Optimization
 
 ```bash
-# Please see Qwen2-VL_README.md in doc for detail, bolow are simple use
+# Please see Qwen-VL_README.md in doc to use more automated fine-tuning methods and hyperparameter tuning, bolow are simple use:
 ```
 
-#### Step 1: Finetune qwen2-vl with logging eval loss
+#### Finetune Qwen2-VL & Qwen2.5-VL with logging eval loss
 
 If you want to finetune with plotting eval loss, please set eval_strategy as steps, eval_stepsand eval_dataset:
 
-```
-# Finetune qwen2-vl with logging eval loss
+##### Qwen2-VL
+
+```bash
 export DATA='where you can find dataset_info.json'
-export dataset=activitynet_qa_2000_limit_20s # to point which dataset llamafactory will use
+#To point which dataset llamafactory will use, have to add the datasets into dataset_info.json before finetune.
+export dataset=activitynet_qa_2000_limit_20s
 export eval_dataset=activitynet_qa_val_500_limit_20s
 llamafactory-cli train \
 --stage sft \
 --do_train True \
---model_name_or_path $models/Qwen2-VL-7B-Instruct-GPTQ-Int8 \
+--model_name_or_path /model/Qwen2-VL-7B-Instruct-GPTQ-Int8 \
 --preprocessing_num_workers 16 \
 --finetuning_type lora \
 --template qwen2_vl \
@@ -196,10 +273,10 @@ llamafactory-cli train \
 --max_grad_norm 1.0 \
 --logging_steps 10 \
 --save_steps 100 \
---warmup_steps 100 \
+--warmup_steps 0 \
 --packing False \
 --report_to none \
---output_dir saves/Qwen2-VL-7B-Instruct-GPTQ-Int8/lora/finetune_test_valmetrics_evalstep8 \
+--output_dir saves/Qwen2-VL-7B-Instruct-GPTQ-Int8/lora/finetune_qwen2vl \
 --bf16 True \
 --plot_loss True \
 --ddp_timeout 180000000 \
@@ -216,7 +293,54 @@ llamafactory-cli train \
 --lora_target all
 ```
 
-#### step 2: Evaluation metrics calculation and plotting
+#### Qwen2.5-VL
+
+```bash
+export DATA='where you can find dataset_info.json'
+#To point which dataset llamafactory will use, have to add the datasets into dataset_info.json before finetune.
+export dataset=activitynet_qa_1000_limit_20s
+export eval_dataset=activitynet_qa_val_250_limit_20s
+llamafactory-cli train \
+--stage sft \
+--do_train True \
+--model_name_or_path /home/edgeai/wxs/workspace/models/Qwen2.5-VL-7B-Instruct \
+--preprocessing_num_workers 16 \
+--finetuning_type lora \
+--template qwen2_vl \
+--flash_attn auto \
+--dataset_dir $DATA \
+--dataset $dataset \
+--cutoff_len 2048 \
+--learning_rate 5e-05 \
+--num_train_epochs 2 \
+--max_samples 100000 \
+--per_device_train_batch_size 2 \
+--gradient_accumulation_steps 4 \
+--lr_scheduler_type cosine \
+--max_grad_norm 1.0 \
+--logging_steps 10 \
+--save_steps 100 \
+--warmup_steps 0 \
+--packing False \
+--report_to none \
+--output_dir saves/Qwen2.5-VL-7B-Instruct/lora/finetune_qwen2.5vl \
+--bf16 True \
+--plot_loss True \
+--ddp_timeout 180000000 \
+--optim adamw_torch \
+--video_fps 0.05 \
+--per_device_eval_batch_size 1 \
+--eval_strategy steps \
+--eval_steps 100 \
+--eval_dataset $eval_dataset \
+--predict_with_generate true \
+--lora_rank 8 \
+--lora_alpha 16 \
+--lora_dropout 0 \
+--lora_target all
+```
+
+#### Calculation and Plotting of Evaluation Metrics During Fine-Tuning
 
 If you want to plot eval metrics:
 Change `MODEL_NAME`,`EXPERIENT_NAME`,`EVAL_DATASET` as you need and run evaluation metrics calculation sctrpt:
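As a rough illustration of the eval-loss plotting mentioned above, a hedged sketch that reads the standard Hugging Face `trainer_state.json` written to `--output_dir` (an assumption about the output layout; this is not the repository's metrics script):

```python
# Hedged sketch: plot eval_loss recorded during the fine-tune runs above.
# Assumes the Hugging Face Trainer's trainer_state.json layout in --output_dir.
import json

import matplotlib.pyplot as plt

state_path = "saves/Qwen2.5-VL-7B-Instruct/lora/finetune_qwen2.5vl/trainer_state.json"
with open(state_path) as f:
    log_history = json.load(f)["log_history"]

steps = [entry["step"] for entry in log_history if "eval_loss" in entry]
losses = [entry["eval_loss"] for entry in log_history if "eval_loss" in entry]

plt.plot(steps, losses, marker="o")
plt.xlabel("step")
plt.ylabel("eval_loss")
plt.title("Eval loss during fine-tuning")
plt.savefig("eval_loss.png")
```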

comps/finetuning/src/integrations/xtune/clip_finetune/trainers/clip_adapter_hf.py

Lines changed: 62 additions & 11 deletions
@@ -12,6 +12,13 @@
 from torch.nn import functional as F
 from transformers import CLIPModel, CLIPProcessor
 
+try:
+    from transformers import ChineseCLIPModel, ChineseCLIPProcessor
+
+    CHINESE_CLIP_AVAILABLE = True
+except ImportError:
+    CHINESE_CLIP_AVAILABLE = False
+
 CUSTOM_TEMPLATES = {
     "OxfordPets": "a photo of a {}, a type of pet.",
     "OxfordFlowers": "a photo of a {}, a type of flower.",
@@ -30,22 +37,31 @@
     "ImageNetA": "a photo of a {}.",
     "ImageNetR": "a photo of a {}.",
     "ITC_Flickr": "{}.",
+    "ITC_FlickrCN": "{}.",
     "ITC_Flickr5k": "{}.",
     "ITC_Mscoco": "{}.",
 }
 _MODELS = {
     "ViT-B/16": "openai/clip-vit-base-patch16",
     "ViT-B/32": "openai/clip-vit-base-patch32",
     "ViT-L/14": "openai/clip-vit-large-patch14",
+    "CnViT-B/16": "OFA-Sys/chinese-clip-vit-base-patch16",
+    "CnViT-L/14": "OFA-Sys/chinese-clip-vit-large-patch14",
 }
 
 
 def load_clip_to_cpu(cfg):
     backbone_name = cfg.MODEL.BACKBONE.NAME
     url = _MODELS[backbone_name]
 
-    model = CLIPModel.from_pretrained(url)
-    processor = CLIPProcessor.from_pretrained(url)
+    # Check if it's a Chinese CLIP model
+    if backbone_name.startswith("CnViT") and CHINESE_CLIP_AVAILABLE:
+        model = ChineseCLIPModel.from_pretrained(url)
+        processor = ChineseCLIPProcessor.from_pretrained(url)
+    else:
+        model = CLIPModel.from_pretrained(url)
+        processor = CLIPProcessor.from_pretrained(url)
+
     # model.initialize_parameters()
 
     return model, processor
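The new `CnViT-*` entries point at the OFA-Sys Chinese-CLIP checkpoints on the Hugging Face Hub. A short standalone sketch of the same loading path (an illustration, not code from this file):

```python
# Standalone sketch of the CnViT-B/16 branch taken by load_clip_to_cpu() above.
from transformers import ChineseCLIPModel, ChineseCLIPProcessor

url = "OFA-Sys/chinese-clip-vit-base-patch16"
model = ChineseCLIPModel.from_pretrained(url)
processor = ChineseCLIPProcessor.from_pretrained(url)

# Both CLIPModel and ChineseCLIPModel expose a visual_projection layer, which is
# what the adapter-sizing change further down reads instead of a hard-coded 512.
print(type(model).__name__, model.visual_projection.out_features)
```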
@@ -67,7 +83,6 @@ def forward(self, x):
         return x
 
 
-# use clip textencode
 class TextEncoder(nn.Module):
 
     def __init__(self, cfg, classnames, clip_model, processor):
@@ -77,6 +92,8 @@ def __init__(self, cfg, classnames, clip_model, processor):
         self.clip_model = clip_model
         self.tokenizer = processor.tokenizer
         self.dtype = clip_model.dtype
+        # Check if it's Chinese CLIP model by checking model type
+        self.is_chinese_clip = type(clip_model).__name__ == "ChineseCLIPModel"
 
     def forward(self, classname=None):
         # for small dataset, we tokenize all prompt ------- if classname is None
@@ -88,12 +105,26 @@ def forward(self, classname=None):
             temp = CUSTOM_TEMPLATES[self.cfg.DATASET.NAME]
             prompts = [temp.format(c.replace("_", " ")) for c in classname]
 
-        prompts = self.tokenizer(prompts, return_tensors="pt", padding=True)["input_ids"]
+        # Use tokenizer for both models (same interface)
+        # Set max_length to prevent sequence length errors
+        tokenized = self.tokenizer(prompts, return_tensors="pt", padding=True)
+
         if self.cfg.TRAINER.COOP.XPU:
-            prompts = prompts.to(self.cfg.TRAINER.COOP.XPU_ID)
+            tokenized = {k: v.to(self.cfg.TRAINER.COOP.XPU_ID) for k, v in tokenized.items()}
+        else:
+            tokenized = {k: v.to(self.cfg.TRAINER.COOP.CUDA_ID) for k, v in tokenized.items()}
+
+        # Handle different model architectures
+        text_outputs = self.clip_model.text_model(**tokenized)
+
+        if text_outputs.pooler_output is not None:
+            # Standard CLIP has pooler_output
+            text_features = text_outputs.pooler_output
         else:
-            prompts = prompts.to(self.cfg.TRAINER.COOP.CUDA_ID)
-        text_features = self.clip_model.text_model(prompts)[1]
+            # Chinese CLIP doesn't have pooler_output, use last hidden state's first token
+            # Use [CLS] token
+            text_features = text_outputs.last_hidden_state[:, 0, :]
+
         text_features = self.clip_model.text_projection(text_features)
         return text_features
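For context on the `[CLS]`-token branch above: Chinese-CLIP's text tower is BERT-style, and its projected text features come from the first token of `last_hidden_state`. A hedged standalone sketch of that same computation (the example prompt is illustrative):

```python
# Hedged sketch of the Chinese-CLIP [CLS] pooling path used in forward() above.
import torch
from transformers import ChineseCLIPModel, ChineseCLIPProcessor

name = "OFA-Sys/chinese-clip-vit-base-patch16"
model = ChineseCLIPModel.from_pretrained(name)
tokenizer = ChineseCLIPProcessor.from_pretrained(name).tokenizer

tokenized = tokenizer(["一张猫的照片。"], return_tensors="pt", padding=True)
with torch.no_grad():
    text_outputs = model.text_model(**tokenized)
    cls_token = text_outputs.last_hidden_state[:, 0, :]  # [CLS] token, as in the trainer
    text_features = model.text_projection(cls_token)
print(text_features.shape)  # (1, projection_dim)
```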

@@ -110,11 +141,22 @@ def __init__(self, cfg, classnames, clip_model, processor):
         self.text_encoder = TextEncoder(cfg, classnames, clip_model, processor)
         self.logit_scale = clip_model.logit_scale
         self.dtype = clip_model.dtype
-        # init adapter
-        self.adapter = Adapter(512, 4).to(clip_model.dtype)
+        # Check if it's Chinese CLIP model
+        self.is_chinese_clip = type(clip_model).__name__ == "ChineseCLIPModel"
+        projection_dim = clip_model.visual_projection.out_features
+        self.adapter = Adapter(projection_dim, 4).to(clip_model.dtype)
 
     def forward(self, image, classname=None):
-        image_features = self.image_encoder(image.type(self.dtype))[1]
+        # Handle different vision model outputs
+        vision_outputs = self.image_encoder(image.type(self.dtype))
+        if hasattr(vision_outputs, "pooler_output") and vision_outputs.pooler_output is not None:
+            image_features = vision_outputs.pooler_output
+        elif isinstance(vision_outputs, tuple) and len(vision_outputs) > 1:
+            image_features = vision_outputs[1]  # pooled output
+        else:
+            # Fallback: use last hidden state
+            image_features = vision_outputs.last_hidden_state.mean(dim=1)
+
         image_features = self.visual_projection(image_features)
         # apply adapter in ViT
         x = self.adapter(image_features)
@@ -206,7 +248,16 @@ def get_text_embeds(self, text):
         return text_features
 
     def get_img_embeds(self, image):
-        image_features = self.model.image_encoder(image.type(self.model.dtype))[1]
+        # Handle different vision model outputs
+        vision_outputs = self.model.image_encoder(image.type(self.model.dtype))
+        if hasattr(vision_outputs, "pooler_output") and vision_outputs.pooler_output is not None:
+            image_features = vision_outputs.pooler_output
+        elif isinstance(vision_outputs, tuple) and len(vision_outputs) > 1:
+            image_features = vision_outputs[1]  # pooled output
+        else:
+            # Fallback: use last hidden state
+            image_features = vision_outputs.last_hidden_state.mean(dim=1)
+
         image_features = self.model.visual_projection(image_features)
         x = self.model.adapter(image_features)