Skip to content

Commit dc501f0

Browse files
committed
refactor(upload): Optimize SliceUploadManager with singleflight for session management and improve session loading logic
1 parent 05882df commit dc501f0

File tree

1 file changed

+15
-122
lines changed

1 file changed

+15
-122
lines changed

internal/fs/sliceup.go

Lines changed: 15 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"github.com/OpenListTeam/OpenList/v4/internal/model/tables"
1919
"github.com/OpenListTeam/OpenList/v4/internal/op"
2020
"github.com/OpenListTeam/OpenList/v4/internal/stream"
21+
"github.com/OpenListTeam/OpenList/v4/pkg/singleflight"
2122
"github.com/OpenListTeam/OpenList/v4/pkg/utils"
2223
"github.com/google/uuid"
2324
"github.com/pkg/errors"
@@ -27,7 +28,8 @@ import (
2728

2829
// SliceUploadManager 分片上传管理器
2930
type SliceUploadManager struct {
30-
cache sync.Map // TaskID -> *SliceUploadSession
31+
sessionG singleflight.Group[*SliceUploadSession]
32+
cache sync.Map // TaskID -> *SliceUploadSession
3133
}
3234

3335
// SliceUploadSession 分片上传会话
@@ -112,7 +114,7 @@ func (m *SliceUploadManager) CreateSession(ctx context.Context, storage driver.D
112114
log.Error(err)
113115
return nil, errors.WithStack(err)
114116
}
115-
user, _ := ctx.Value(conf.UserKey).(*model.User)
117+
user := ctx.Value(conf.UserKey).(*model.User)
116118

117119
// 生成唯一的TaskID
118120
taskID := uuid.New().String()
@@ -130,9 +132,7 @@ func (m *SliceUploadManager) CreateSession(ctx context.Context, storage driver.D
130132
Overwrite: req.Overwrite,
131133
ActualPath: actualPath,
132134
AsTask: req.AsTask,
133-
}
134-
if user != nil {
135-
createsu.UserID = user.ID
135+
UserID: user.ID,
136136
}
137137
log.Infof("storage mount path %s", storage.GetStorage().MountPath)
138138

@@ -185,50 +185,26 @@ func (m *SliceUploadManager) CreateSession(ctx context.Context, storage driver.D
185185

186186
// getOrLoadSession 获取或加载会话,提高代码复用性
187187
func (m *SliceUploadManager) getOrLoadSession(taskID string) (*SliceUploadSession, error) {
188-
sa, loaded := m.cache.LoadOrStore(taskID, (*SliceUploadSession)(nil))
189-
if !loaded {
190-
// 首次加载,需要从数据库获取
191-
su, err := db.GetSliceUploadByTaskID(taskID)
192-
if err != nil {
193-
m.cache.Delete(taskID) // 清理无效的 key
194-
return nil, errors.WithMessagef(err, "failed get slice upload [%s]", taskID)
195-
}
196-
session := &SliceUploadSession{
197-
SliceUpload: su,
188+
session, err, _ := m.sessionG.Do(taskID, func() (*SliceUploadSession, error) {
189+
if s, ok := m.cache.Load(taskID); ok {
190+
return s.(*SliceUploadSession), nil
198191
}
199-
m.cache.Store(taskID, session)
200-
return session, nil
201-
}
202-
203-
// 缓存中存在,但可能是nil值,需要检查
204-
if sa == nil {
205-
// 说明之前存储了nil,需要重新从数据库加载
192+
// 首次加载,需要从数据库获取
206193
su, err := db.GetSliceUploadByTaskID(taskID)
207194
if err != nil {
208-
m.cache.Delete(taskID)
209195
return nil, errors.WithMessagef(err, "failed get slice upload [%s]", taskID)
210196
}
211-
session := &SliceUploadSession{
197+
s := &SliceUploadSession{
212198
SliceUpload: su,
213199
}
214-
m.cache.Store(taskID, session)
215-
return session, nil
216-
}
217-
218-
session := sa.(*SliceUploadSession)
219-
// 刷新数据库状态以确保数据一致性
220-
if freshSu, err := db.GetSliceUploadByTaskID(taskID); err == nil {
221-
session.mutex.Lock()
222-
session.SliceUpload = freshSu
223-
session.mutex.Unlock()
224-
}
225-
return session, nil
200+
m.cache.Store(taskID, s)
201+
return s, nil
202+
})
203+
return session, err
226204
}
227205

228206
// UploadSlice 流式上传分片 - 支持流式上传,避免表单上传的内存占用
229207
func (m *SliceUploadManager) UploadSlice(ctx context.Context, storage driver.Driver, req *reqres.UploadSliceReq, reader io.Reader) error {
230-
var err error
231-
232208
session, err := m.getOrLoadSession(req.TaskID)
233209
if err != nil {
234210
log.Errorf("failed to get session: %+v", err)
@@ -283,40 +259,7 @@ func (m *SliceUploadManager) UploadSlice(ctx context.Context, storage driver.Dri
283259
// 根据存储类型处理分片上传
284260
switch s := storage.(type) {
285261
case driver.ISliceUpload:
286-
log.Info("SliceUpload support")
287-
// 对于支持原生分片上传的驱动,我们需要将流数据缓存到临时文件中
288-
// 以支持重试和断点续传场景
289-
if err := session.ensureTmpFile(); err != nil {
290-
log.Error("ensureTmpFile error for native slice upload", req, err)
291-
return err
292-
}
293-
294-
// 将流数据写入临时文件的指定位置
295-
sw := &sliceWriter{
296-
file: session.tmpFile,
297-
offset: int64(req.SliceNum) * int64(session.SliceSize),
298-
}
299-
writtenBytes, err := utils.CopyWithBuffer(sw, reader)
300-
if err != nil {
301-
log.Error("Copy to temp file error for native slice upload", req, err)
302-
return err
303-
}
304-
log.Debugf("Written %d bytes to temp file for slice %d", writtenBytes, req.SliceNum)
305-
306-
// 从临时文件读取数据进行上传
307-
sliceSize := session.SliceSize
308-
if req.SliceNum == session.SliceCnt-1 {
309-
// 最后一个分片,计算实际大小
310-
sliceSize = session.Size - int64(req.SliceNum)*int64(session.SliceSize)
311-
}
312-
313-
sliceReader := &sliceReader{
314-
file: session.tmpFile,
315-
offset: int64(req.SliceNum) * int64(session.SliceSize),
316-
size: sliceSize,
317-
}
318-
319-
if err := s.SliceUpload(ctx, session.SliceUpload, req.SliceNum, sliceReader); err != nil {
262+
if err := s.SliceUpload(ctx, session.SliceUpload, req.SliceNum, reader); err != nil {
320263
log.Error("SliceUpload error", req, err)
321264
return err
322265
}
@@ -578,56 +521,6 @@ func (sw *sliceWriter) Write(p []byte) (int, error) {
578521
return n, err
579522
}
580523

581-
// sliceReader 用于从临时文件中读取指定分片的数据,支持断点续传
582-
type sliceReader struct {
583-
file *os.File
584-
offset int64
585-
size int64
586-
position int64 // 当前读取位置(相对于分片开始)
587-
}
588-
589-
// Read implements io.Reader interface
590-
func (sr *sliceReader) Read(p []byte) (int, error) {
591-
if sr.position >= sr.size {
592-
return 0, io.EOF
593-
}
594-
595-
// 计算实际可读取的字节数
596-
remaining := sr.size - sr.position
597-
if int64(len(p)) > remaining {
598-
p = p[:remaining]
599-
}
600-
601-
n, err := sr.file.ReadAt(p, sr.offset+sr.position)
602-
sr.position += int64(n)
603-
return n, err
604-
}
605-
606-
// Seek implements io.Seeker interface,支持重试场景
607-
func (sr *sliceReader) Seek(offset int64, whence int) (int64, error) {
608-
var newPos int64
609-
switch whence {
610-
case io.SeekStart:
611-
newPos = offset
612-
case io.SeekCurrent:
613-
newPos = sr.position + offset
614-
case io.SeekEnd:
615-
newPos = sr.size + offset
616-
default:
617-
return 0, fmt.Errorf("invalid whence value: %d", whence)
618-
}
619-
620-
if newPos < 0 {
621-
return 0, fmt.Errorf("negative position: %d", newPos)
622-
}
623-
if newPos > sr.size {
624-
newPos = sr.size
625-
}
626-
627-
sr.position = newPos
628-
return newPos, nil
629-
}
630-
631524
// recoverIncompleteUploads 恢复重启后未完成的上传任务
632525
func (m *SliceUploadManager) recoverIncompleteUploads() {
633526
defer func() {

0 commit comments

Comments
 (0)