Commit ecddd5c

add gradio demo

1 parent 0275b50 commit ecddd5c

File tree: 5 files changed, +272 −9 lines changed

skycaptioner_v1/README.md

Lines changed: 19 additions & 3 deletions
@@ -1,15 +1,15 @@
 # SkyCaptioner-V1: A Structural Video Captioning Model

 <p align="center">
-📑 <a href="https://arxiv.org/pdf/2504.13074">Technical Report</a> · 👋 <a href="https://www.skyreels.ai/home?utm_campaign=github_SkyReels_V2" target="_blank">Playground</a> · 💬 <a href="https://discord.gg/PwM6NYtccQ" target="_blank">Discord</a> · 🤗 <a href="https://huggingface.co/Skywork/SkyCaptioner-V1" target="_blank">Hugging Face</a> · 🤖 <a href="https://modelscope.cn/collections/SkyReels-V2-f665650130b144">ModelScope</a></a>
+📑 <a href="https://arxiv.org/pdf/2504.13074">Technical Report</a> · 👋 <a href="https://www.skyreels.ai/home?utm_campaign=github_SkyReels_V2" target="_blank">Playground</a> · 💬 <a href="https://discord.gg/PwM6NYtccQ" target="_blank">Discord</a> · 🤗 <a href="https://huggingface.co/Skywork/SkyCaptioner-V1" target="_blank">Hugging Face</a> · 🤖 <a href="https://modelscope.cn/collections/SkyReels-V2-f665650130b144">ModelScope</a> · 🚀 <a href="https://huggingface.co/spaces/Skywork/SkyCaptioner-V1">Demo</a>
 </p>

 ---

 Welcome to the SkyCaptioner-V1 repository! Here, you'll find the structural video captioning model weights and inference code for our video captioner that labels the video data efficiently and comprehensively.

 ## 🔥🔥🔥 News!!
-
+* May 07, 2025: 🚀 Added a web demo implementation based on Gradio and the [online demo](https://huggingface.co/spaces/Skywork/SkyCaptioner-V1) is now available!
 * Apr 21, 2025: 👋 We release the [vllm](https://github.com/vllm-project/vllm) batch inference code for SkyCaptioner-V1 Model and caption fusion inference code.
 * Apr 21, 2025: 👋 We release the first shot-aware video captioning model [SkyCaptioner-V1 Model](https://huggingface.co/Skywork/SkyCaptioner-V1). For more details, please check our [paper](https://arxiv.org/pdf/2504.13074).

@@ -20,7 +20,7 @@ Welcome to the SkyCaptioner-V1 repository! Here, you'll find the structural vide
 - [x] Checkpoints
 - [x] Batch Inference Code
 - [x] Caption Fusion Method
-- [ ] Web Demo (Gradio)
+- [x] Web Demo (Gradio)

 ## 🌟 Overview

@@ -241,6 +241,22 @@ python scripts/vllm_fusion_caption.py \
 > **Note**:
 > - If you want to get i2v caption, just change the `--task t2v` to `--task i2v` in your Command.

+#### Gradio Web Demo
+Launch the Gradio web demo for SkyCaptioner-V1:
+```shell
+export SkyCaptioner_V1_Model_PATH="/path/to/your_local_model_path"
+python scripts/gradio_struct_caption.py \
+    --skycaptioner_model_path ${SkyCaptioner_V1_Model_PATH}
+```
+
+Launch the Gradio web demo for Caption Fusion:
+```shell
+export LLM_MODEL_PATH="/path/to/your_local_model_path2"
+python scripts/gradio_fusion_caption.py \
+    --fusioncaptioner_model_path ${LLM_MODEL_PATH}
+```
+
+
 ## Acknowledgements

 We would like to thank the contributors of <a href="https://github.com/QwenLM/Qwen2.5-VL">Qwen2.5-VL</a>, <a href="https://github.com/bytedance/tarsier">tarsier2</a> and <a href="https://github.com/vllm-project/vllm">vllm</a> repositories, for their open research and contributions.
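As a convenience pointer (the ports come from the `demo.launch(...)` calls in the two new scripts below, and `webbrowser` is only a hypothetical way to open them), once a demo is running on the launch machine you could browse to it like so:

```python
# Ports taken from demo.launch(...) in the new scripts; adjust if you change them.
import webbrowser

webbrowser.open("http://localhost:7862")  # structural caption demo (gradio_struct_caption.py)
webbrowser.open("http://localhost:7863")  # caption fusion demo (gradio_fusion_caption.py)
```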
skycaptioner_v1/scripts/gradio_fusion_caption.py

Lines changed: 146 additions & 0 deletions (new file)

import json
import argparse

import pandas as pd
import gradio as gr
from vllm import LLM, SamplingParams

from vllm_fusion_caption import StructuralCaptionDataset

parser = argparse.ArgumentParser()
parser.add_argument("--fusioncaptioner_model_path", default=None, type=str)
parser.add_argument("--tensor_parallel_size", type=int, default=2)
args = parser.parse_args()

example_input = """
{
  "subjects": [
    {
      "TYPES": {
        "type": "Human",
        "sub_type": "Woman"
      },
      "appearance": "Long, straight black hair with bangs, wearing a sparkling choker necklace and a dark-colored top or dress with a visible strap over her shoulder.",
      "action": "A woman wearing a sparkling choker necklace and earrings is sitting in a car, looking to her left and speaking. A man, dressed in a suit, is sitting next to her, attentively watching her.",
      "expression": "The individual in the video exhibits a neutral facial expression, characterized by slightly open lips and a gentle, soft-focus gaze. There are no noticeable signs of sadness or distress evident in their demeanor.",
      "position": "Seated in the foreground of the car, facing slightly to the right.",
      "is_main_subject": true
    },
    {
      "TYPES": {
        "type": "Human",
        "sub_type": "Man"
      },
      "appearance": "Short hair, wearing a dark-colored suit with a white shirt.",
      "action": "",
      "expression": "",
      "position": "Seated in the background of the car, facing the woman.",
      "is_main_subject": false
    }
  ],
  "shot_type": "close_up",
  "shot_angle": "eye_level",
  "shot_position": "side_view",
  "camera_motion": "",
  "environment": "Interior of a car with a dark color scheme.",
  "lighting": "Soft and natural lighting, suggesting daytime."
}
"""


class FusionCaptioner:
    def __init__(self, model_path, tensor_parallel_size):
        self.model = LLM(model=model_path,
                         gpu_memory_utilization=0.9,
                         max_model_len=4096,
                         tensor_parallel_size=tensor_parallel_size)
        self.sampling_params = SamplingParams(
            temperature=0.1,
            max_tokens=512,
            stop=['\n\n']
        )
        self.model_path = model_path

    def __call__(self, structural_caption, task='t2v'):
        # Normalize dict or JSON-string input to a compact JSON string.
        if isinstance(structural_caption, dict):
            structural_caption = json.dumps(structural_caption, ensure_ascii=False)
        else:
            structural_caption = json.dumps(json.loads(structural_caption), ensure_ascii=False)
        meta = pd.DataFrame([structural_caption], columns=['structural_caption'])
        print(f'structural_caption: {structural_caption}')
        print(f'task: {task}')
        dataset = StructuralCaptionDataset(meta, self.model_path, task)
        _, fusion_by_llm, text, original_text, camera_movement = dataset[0]
        if not fusion_by_llm:
            # This caption skips LLM fusion; just append the camera movement.
            caption = original_text + " " + camera_movement
            return caption
        try:
            outputs = self.model.generate([text], self.sampling_params, use_tqdm=False)
            result = outputs[0].outputs[0].text
        except Exception as e:
            # Fall back to the unfused caption if generation fails
            # (the original assigned an empty list here, which would break
            # the string concatenation below).
            print(f'fusion generation failed: {e}')
            result = original_text

        llm_caption = result + " " + camera_movement
        return llm_caption


def main():
    fusion_captioner = FusionCaptioner(args.fusioncaptioner_model_path, args.tensor_parallel_size)

    def fusion_caption(structural_caption, task):
        caption = fusion_captioner(structural_caption, task)
        return caption

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            <h1 style="text-align: center; font-size: 2em;">SkyCaptioner</h1>
            """,
            elem_id="header"
        )

        with gr.Row():
            with gr.Column(visible=True):
                with gr.Row():
                    json_input = gr.Code(
                        label="Structural Caption",
                        language="json",
                        lines=25,
                        interactive=True
                    )
                with gr.Row():
                    task_input = gr.Radio(
                        label="Task",
                        choices=["t2v", "i2v"],
                        value="t2v",
                        interactive=True
                    )

            with gr.Column(visible=True):
                text_output = gr.Textbox(
                    label="Fusion Caption",
                    lines=25,
                    interactive=False,
                    autoscroll=True
                )

        gr.Button("Generate").click(
            fn=fusion_caption,
            inputs=[json_input, task_input],
            outputs=text_output
        )
        with gr.Row():
            gr.Examples(
                examples=[
                    [example_input, "t2v"],
                ],
                inputs=[json_input, task_input],
                label="Example Input"
            )

    demo.launch(
        server_name="0.0.0.0",
        server_port=7863,
        share=False
    )


if __name__ == '__main__':
    main()
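For reference, a minimal headless sketch of driving `FusionCaptioner` without the UI; the checkpoint path is a placeholder. One caveat: this script calls `parse_args()` at module level, so importing it is only safe when the host process's argv carries no flags argparse would reject.

```python
# Hypothetical headless use of FusionCaptioner; the model path is a placeholder.
from gradio_fusion_caption import FusionCaptioner, example_input

captioner = FusionCaptioner("/path/to/your_local_model_path2", tensor_parallel_size=2)
# Accepts a dict or a JSON string; task is "t2v" or "i2v".
print(captioner(example_input, task="t2v"))
```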
skycaptioner_v1/scripts/gradio_struct_caption.py

Lines changed: 92 additions & 0 deletions (new file)

import json
import argparse

import pandas as pd
import gradio as gr
from vllm import LLM, SamplingParams

from vllm_struct_caption import VideoTextDataset


class StructCaptioner:
    def __init__(self, model_path, tensor_parallel_size):
        self.model = LLM(model=model_path,
                         gpu_memory_utilization=0.6,
                         max_model_len=31920,
                         tensor_parallel_size=tensor_parallel_size)

        self.model_path = model_path
        self.sampling_params = SamplingParams(temperature=0.05, max_tokens=2048)

    def __call__(self, video_path):
        # Wrap the single uploaded video in a one-row DataFrame for the dataset.
        meta = pd.DataFrame([video_path], columns=['path'])
        dataset = VideoTextDataset(meta, self.model_path)
        item = dataset[0]['input']
        batch_user_inputs = [{
            'prompt': item['prompt'],
            'multi_modal_data': {'video': item['multi_modal_data']['video'][0]},
        }]
        outputs = self.model.generate(batch_user_inputs, self.sampling_params, use_tqdm=False)
        caption = outputs[0].outputs[0].text
        # Re-serialize the model's JSON output with indentation for display.
        caption = json.loads(caption)
        caption = json.dumps(caption, indent=4, ensure_ascii=False)
        return caption


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--skycaptioner_model_path", required=True, type=str)
    parser.add_argument("--tensor_parallel_size", type=int, default=2)
    args = parser.parse_args()

    struct_captioner = StructCaptioner(args.skycaptioner_model_path, args.tensor_parallel_size)

    def generate_caption(video_path):
        caption = struct_captioner(video_path)
        return caption

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            <h1 style="text-align: center; font-size: 2em;">SkyCaptioner</h1>
            """,
            elem_id="header"
        )

        with gr.Row():
            with gr.Column(visible=True, scale=0.5):
                with gr.Row():
                    video_input = gr.Video(
                        label="Upload Video",
                        interactive=True,
                        format="mp4",
                    )

            with gr.Column(visible=True):
                json_output = gr.Code(
                    label="Caption",
                    language="json",
                    lines=25,
                    interactive=False
                )

        gr.Button("Generate").click(
            fn=generate_caption,
            inputs=video_input,
            outputs=json_output
        )

        gr.Examples(
            examples=[
                ["./examples/data/1.mp4"],
                ["./examples/data/2.mp4"],
            ],
            inputs=video_input,
            label="Example Videos"
        )

    demo.launch(
        server_name="0.0.0.0",
        server_port=7862,
        share=False
    )


if __name__ == '__main__':
    main()
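A matching headless sketch for this script (safe to import, since its argparse lives inside `main()`); the model path and clip are placeholders taken from the bundled examples:

```python
# Hypothetical headless use of StructCaptioner; paths are placeholders.
from gradio_struct_caption import StructCaptioner

captioner = StructCaptioner("/path/to/your_local_model_path", tensor_parallel_size=2)
print(captioner("./examples/data/1.mp4"))  # pretty-printed structural-caption JSON
```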

skycaptioner_v1/scripts/vllm_fusion_caption.py

Lines changed: 11 additions & 5 deletions
@@ -64,9 +64,15 @@
 class StructuralCaptionDataset(torch.utils.data.Dataset):
-    def __init__(self, input_csv, model_path):
-        self.meta = pd.read_csv(input_csv)
-        self.task = args.task
+    def __init__(self, input_csv, model_path, task=None):
+        if isinstance(input_csv, pd.DataFrame):
+            self.meta = input_csv
+        else:
+            self.meta = pd.read_csv(input_csv)
+        if task is None:
+            self.task = args.task
+        else:
+            self.task = task
         self.system_prompt = SYSTEM_PROMPT_T2V if self.task == 't2v' else SYSTEM_PROMPT_I2V
         self.tokenizer = AutoTokenizer.from_pretrained(model_path)

@@ -146,8 +152,8 @@ def clean_struct_caption(self, struct_caption, task):
         shot_type = struct_caption.get('shot_type', '').replace('_', ' ')
-        if shot_type not in SHOT_TYPE_LIST:
-            struct_caption['shot_type'] = ''
+        # if shot_type not in SHOT_TYPE_LIST:
+        #     struct_caption['shot_type'] = ''

         new_struct_caption = {
             'num_subjects': len(subjects),
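The widened constructor now takes either a CSV path or an in-memory DataFrame, plus an optional task that overrides the script's global args. A small sketch of the two call patterns, with placeholder file and model paths:

```python
import pandas as pd
from vllm_fusion_caption import StructuralCaptionDataset

# Batch path (unchanged): rows come from a CSV; task falls back to the global args.
ds = StructuralCaptionDataset("captions.csv", "/path/to/your_local_model_path2")

# Demo path (new): one in-memory row, task passed explicitly by the Gradio app.
meta = pd.DataFrame(['{"shot_type": "close_up"}'], columns=["structural_caption"])
ds = StructuralCaptionDataset(meta, "/path/to/your_local_model_path2", task="i2v")
```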

skycaptioner_v1/scripts/vllm_struct_caption.py

Lines changed: 4 additions & 1 deletion
@@ -17,7 +17,10 @@
 class VideoTextDataset(torch.utils.data.Dataset):
     def __init__(self, csv_path, model_path):
-        self.meta = pd.read_csv(csv_path)
+        if isinstance(csv_path, pd.DataFrame):
+            self.meta = csv_path
+        else:
+            self.meta = pd.read_csv(csv_path)
         self._path = 'path'
         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
         self.processor = AutoProcessor.from_pretrained(model_path)
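Same pattern as above: `gradio_struct_caption.py` wraps a single uploaded video in a one-row DataFrame instead of writing a CSV first. A sketch, with placeholder paths:

```python
import pandas as pd
from vllm_struct_caption import VideoTextDataset

meta = pd.DataFrame(["./examples/data/1.mp4"], columns=["path"])
ds = VideoTextDataset(meta, "/path/to/your_local_model_path")  # no CSV round-trip
```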
