Skip to content

Commit 5f5b942

Browse files
authored
Merge pull request #1988 from feitianbubu/pr/add-sora
新增Sora视频渠道
2 parents e24f13a + b880094 commit 5f5b942

File tree

12 files changed

+481
-2
lines changed

12 files changed

+481
-2
lines changed

common/gin.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package common
33
import (
44
"bytes"
55
"io"
6+
"mime/multipart"
67
"net/http"
78
"one-api/constant"
89
"strings"
@@ -113,3 +114,26 @@ func ApiSuccess(c *gin.Context, data any) {
113114
"data": data,
114115
})
115116
}
117+
118+
func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
119+
requestBody, err := GetRequestBody(c)
120+
if err != nil {
121+
return nil, err
122+
}
123+
124+
contentType := c.Request.Header.Get("Content-Type")
125+
boundary := ""
126+
if idx := strings.Index(contentType, "boundary="); idx != -1 {
127+
boundary = contentType[idx+9:]
128+
}
129+
130+
reader := multipart.NewReader(bytes.NewReader(requestBody), boundary)
131+
form, err := reader.ReadForm(32 << 20) // 32 MB max memory
132+
if err != nil {
133+
return nil, err
134+
}
135+
136+
// Reset request body
137+
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
138+
return form, nil
139+
}

constant/channel.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ const (
5252
ChannelTypeVidu = 52
5353
ChannelTypeSubmodel = 53
5454
ChannelTypeDoubaoVideo = 54
55+
ChannelTypeSora = 55
5556
ChannelTypeDummy // this one is only for count, do not add any channel after this
5657

5758
)
@@ -112,6 +113,7 @@ var ChannelBaseURLs = []string{
112113
"https://api.vidu.cn", //52
113114
"https://llm.submodel.ai", //53
114115
"https://ark.cn-beijing.volces.com", //54
116+
"https://api.openai.com", //55
115117
}
116118

117119
var ChannelTypeNames = map[int]string{
@@ -166,6 +168,7 @@ var ChannelTypeNames = map[int]string{
166168
ChannelTypeVidu: "Vidu",
167169
ChannelTypeSubmodel: "Submodel",
168170
ChannelTypeDoubaoVideo: "DoubaoVideo",
171+
ChannelTypeSora: "Sora",
169172
}
170173

171174
func GetChannelTypeName(channelType int) string {

controller/task_video.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha
4747
if adaptor == nil {
4848
return fmt.Errorf("video adaptor not found")
4949
}
50+
info := &relaycommon.RelayInfo{}
51+
info.ChannelMeta = &relaycommon.ChannelMeta{
52+
ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
53+
}
54+
adaptor.Init(info)
5055
for _, taskId := range taskIds {
5156
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
5257
logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))

controller/video_proxy.go

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
package controller
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"net/http"
7+
"one-api/logger"
8+
"one-api/model"
9+
"time"
10+
11+
"github.com/gin-gonic/gin"
12+
)
13+
14+
func VideoProxy(c *gin.Context) {
15+
taskID := c.Param("task_id")
16+
if taskID == "" {
17+
c.JSON(http.StatusBadRequest, gin.H{
18+
"error": gin.H{
19+
"message": "task_id is required",
20+
"type": "invalid_request_error",
21+
},
22+
})
23+
return
24+
}
25+
26+
task, exists, err := model.GetByOnlyTaskId(taskID)
27+
if err != nil {
28+
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error()))
29+
c.JSON(http.StatusInternalServerError, gin.H{
30+
"error": gin.H{
31+
"message": "Failed to query task",
32+
"type": "server_error",
33+
},
34+
})
35+
return
36+
}
37+
if !exists || task == nil {
38+
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %s", taskID, err.Error()))
39+
c.JSON(http.StatusNotFound, gin.H{
40+
"error": gin.H{
41+
"message": "Task not found",
42+
"type": "invalid_request_error",
43+
},
44+
})
45+
return
46+
}
47+
48+
if task.Status != model.TaskStatusSuccess {
49+
c.JSON(http.StatusBadRequest, gin.H{
50+
"error": gin.H{
51+
"message": fmt.Sprintf("Task is not completed yet, current status: %s", task.Status),
52+
"type": "invalid_request_error",
53+
},
54+
})
55+
return
56+
}
57+
58+
channel, err := model.CacheGetChannel(task.ChannelId)
59+
if err != nil {
60+
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get channel %d: %s", task.ChannelId, err.Error()))
61+
c.JSON(http.StatusInternalServerError, gin.H{
62+
"error": gin.H{
63+
"message": "Failed to retrieve channel information",
64+
"type": "server_error",
65+
},
66+
})
67+
return
68+
}
69+
baseURL := channel.GetBaseURL()
70+
if baseURL == "" {
71+
baseURL = "https://api.openai.com"
72+
}
73+
videoURL := fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID)
74+
75+
client := &http.Client{
76+
Timeout: 60 * time.Second,
77+
}
78+
79+
req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, videoURL, nil)
80+
if err != nil {
81+
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request for %s: %s", videoURL, err.Error()))
82+
c.JSON(http.StatusInternalServerError, gin.H{
83+
"error": gin.H{
84+
"message": "Failed to create proxy request",
85+
"type": "server_error",
86+
},
87+
})
88+
return
89+
}
90+
91+
req.Header.Set("Authorization", "Bearer "+channel.Key)
92+
93+
resp, err := client.Do(req)
94+
if err != nil {
95+
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error()))
96+
c.JSON(http.StatusBadGateway, gin.H{
97+
"error": gin.H{
98+
"message": "Failed to fetch video content",
99+
"type": "server_error",
100+
},
101+
})
102+
return
103+
}
104+
defer resp.Body.Close()
105+
106+
if resp.StatusCode != http.StatusOK {
107+
logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL))
108+
c.JSON(http.StatusBadGateway, gin.H{
109+
"error": gin.H{
110+
"message": fmt.Sprintf("Upstream service returned status %d", resp.StatusCode),
111+
"type": "server_error",
112+
},
113+
})
114+
return
115+
}
116+
117+
for key, values := range resp.Header {
118+
for _, value := range values {
119+
c.Writer.Header().Add(key, value)
120+
}
121+
}
122+
123+
c.Writer.Header().Set("Cache-Control", "public, max-age=86400") // Cache for 24 hours
124+
c.Writer.WriteHeader(resp.StatusCode)
125+
_, err = io.Copy(c.Writer, resp.Body)
126+
if err != nil {
127+
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error()))
128+
}
129+
}

dto/openai_response.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,16 @@ type Usage struct {
233233
Cost any `json:"cost,omitempty"`
234234
}
235235

236+
type OpenAIVideoResponse struct {
237+
Id string `json:"id" example:"file-abc123"`
238+
Object string `json:"object" example:"file"`
239+
Bytes int64 `json:"bytes" example:"120000"`
240+
CreatedAt int64 `json:"created_at" example:"1677610602"`
241+
ExpiresAt int64 `json:"expires_at" example:"1677614202"`
242+
Filename string `json:"filename" example:"mydata.jsonl"`
243+
Purpose string `json:"purpose" example:"fine-tune"`
244+
}
245+
236246
type InputTokenDetails struct {
237247
CachedTokens int `json:"cached_tokens"`
238248
CachedCreationTokens int `json:"-"`

middleware/distributor.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,18 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
165165
}
166166
c.Set("platform", string(constant.TaskPlatformSuno))
167167
c.Set("relay_mode", relayMode)
168+
} else if strings.Contains(c.Request.URL.Path, "/v1/videos") {
169+
//curl https://api.openai.com/v1/videos \
170+
// -H "Authorization: Bearer $OPENAI_API_KEY" \
171+
// -F "model=sora-2" \
172+
// -F "prompt=A calico cat playing a piano on stage"
173+
// -F input_reference="@image.jpg"
174+
relayMode := relayconstant.RelayModeUnknown
175+
if c.Request.Method == http.MethodPost {
176+
relayMode = relayconstant.RelayModeVideoSubmit
177+
modelRequest.Model = c.PostForm("model")
178+
}
179+
c.Set("relay_mode", relayMode)
168180
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
169181
relayMode := relayconstant.RelayModeUnknown
170182
if c.Request.Method == http.MethodPost {

0 commit comments

Comments
 (0)