Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 125 additions & 61 deletions drivers/baidu_netdisk/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ import (
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"io"
"net/url"
"os"
stdpath "path"
"strconv"
"strings"
"time"

"github.com/OpenListTeam/OpenList/v4/drivers/base"
Expand All @@ -31,6 +33,8 @@ type BaiduNetdisk struct {
vipType int // 会员类型,0普通用户(4G/4M)、1普通会员(10G/16M)、2超级会员(20G/32M)
}

var ErrUploadIDExpired = errors.New("uploadid expired")

func (d *BaiduNetdisk) Config() driver.Config {
return config
}
Expand Down Expand Up @@ -174,12 +178,12 @@ func (d *BaiduNetdisk) PutRapid(ctx context.Context, dstDir model.Obj, stream mo
return fileToObj(newFile), nil
}

// Put
// Put 文件上传,支持断点续传和 uploadid 过期重试
//
// **注意**: 截至 2024/04/20 百度云盘 api 接口返回的时间永远是当前时间,而不是文件时间。
// 而实际上云盘存储的时间是文件时间,所以此处需要覆盖时间,保证缓存与云盘的数据一致
func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
// rapid upload
// rapid upload 尝试秒传
if newObj, err := d.PutRapid(ctx, dstDir, stream); err == nil {
return newObj, nil
}
Expand Down Expand Up @@ -212,9 +216,8 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
lastBlockSize = sliceSize
}

//cal md5 for first 256k data
//cal md5 for first 256k data 计算 md5
const SliceSize int64 = 256 * utils.KB
// cal md5
blockList := make([]string, 0, count)
byteSize := sliceSize
fileMd5H := md5.New()
Expand Down Expand Up @@ -244,7 +247,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
}
if tmpF != nil {
if written != streamSize {
return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %d, expect = %d ", written, streamSize)
return nil, errs.NewErr(err, "CreateTempFile failed, size mismatch: %d != %d ", written, streamSize)
}
_, err = tmpF.Seek(0, io.SeekStart)
if err != nil {
Expand All @@ -258,78 +261,94 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
mtime := stream.ModTime().Unix()
ctime := stream.CreateTime().Unix()

// step.1 预上传
// 尝试获取之前的进度
// step.1 尝试读取已保存进度
precreateResp, ok := base.GetUploadProgress[*PrecreateResp](d, d.AccessToken, contentMd5)
if !ok {
params := map[string]string{
"method": "precreate",
}
form := map[string]string{
"path": path,
"size": strconv.FormatInt(streamSize, 10),
"isdir": "0",
"autoinit": "1",
"rtype": "3",
"block_list": blockListStr,
"content-md5": contentMd5,
"slice-md5": sliceMd5,
}
joinTime(form, ctime, mtime)

log.Debugf("[baidu_netdisk] precreate data: %s", form)
_, err = d.postForm("/xpan/file", params, form, &precreateResp)
// 没有进度,走预上传
precreateResp, err = d.precreate(ctx, path, streamSize, blockListStr, contentMd5, sliceMd5, ctime, mtime)
if err != nil {
return nil, err
}
log.Debugf("%+v", precreateResp)
if precreateResp.ReturnType == 2 {
//rapid upload, since got md5 match from baidu server
// 修复时间,具体原因见 Put 方法注释的 **注意**
precreateResp.File.Ctime = ctime
precreateResp.File.Mtime = mtime
return fileToObj(precreateResp.File), nil
}
}

// step.2 上传分片
threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,
retry.Attempts(1),
retry.Delay(time.Second),
retry.DelayType(retry.BackOffDelay))

for i, partseq := range precreateResp.BlockList {
if utils.IsCanceled(upCtx) {
break
uploadLoop:
for attempt := 0; attempt < 2; attempt++ {
threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,
retry.Attempts(1),
retry.Delay(time.Second),
retry.DelayType(retry.BackOffDelay))

cacheReaderAt, okReaderAt := cache.(io.ReaderAt)
if !okReaderAt {
return nil, fmt.Errorf("cache object must implement io.ReaderAt interface for upload operations")
}

i, partseq, offset, byteSize := i, partseq, int64(partseq)*sliceSize, sliceSize
if partseq+1 == count {
byteSize = lastBlockSize
}
threadG.Go(func(ctx context.Context) error {
params := map[string]string{
"method": "upload",
"access_token": d.AccessToken,
"type": "tmpfile",
"path": path,
"uploadid": precreateResp.Uploadid,
"partseq": strconv.Itoa(partseq),
totalParts := len(precreateResp.BlockList)
doneParts := 0 // 新增计数器

for i, partseq := range precreateResp.BlockList {
if utils.IsCanceled(upCtx) || partseq < 0 {
continue
}
err := d.uploadSlice(ctx, params, stream.GetName(),
driver.NewLimitedUploadStream(ctx, io.NewSectionReader(cache, offset, byteSize)))
if err != nil {
return err
i, partseq := i, partseq
offset, size := int64(partseq)*sliceSize, sliceSize
if partseq+1 == count {
size = lastBlockSize
}
up(float64(threadG.Success()) * 100 / float64(len(precreateResp.BlockList)))
precreateResp.BlockList[i] = -1
return nil
})
}
if err = threadG.Wait(); err != nil {
// 如果属于用户主动取消,则保存上传进度
threadG.Go(func(ctx context.Context) error {
params := map[string]string{
"method": "upload",
"access_token": d.AccessToken,
"type": "tmpfile",
"path": path,
"uploadid": precreateResp.Uploadid,
"partseq": strconv.Itoa(partseq),
}
section := io.NewSectionReader(cacheReaderAt, offset, size)
err := d.uploadSlice(ctx, params, stream.GetName(), driver.NewLimitedUploadStream(ctx, section))
if err != nil {
return err
}
precreateResp.BlockList[i] = -1
doneParts++ // 直接递增计数器
if totalParts > 0 {
up(float64(doneParts) * 100.0 / float64(totalParts))
}
return nil
})
}

err = threadG.Wait()
if err == nil {
break uploadLoop
}

// 保存进度(所有错误都会保存)
precreateResp.BlockList = utils.SliceFilter(precreateResp.BlockList, func(s int) bool { return s >= 0 })
base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5)

if errors.Is(err, context.Canceled) {
precreateResp.BlockList = utils.SliceFilter(precreateResp.BlockList, func(s int) bool { return s >= 0 })
return nil, err
}
if errors.Is(err, ErrUploadIDExpired) {
log.Warn("[baidu_netdisk] uploadid expired, will restart from scratch")
// 重新 precreate(所有分片都要重传)
newPre, err2 := d.precreate(ctx, path, streamSize, blockListStr, "", "", ctime, mtime)
if err2 != nil {
return nil, err2
}
if newPre.ReturnType == 2 {
return fileToObj(newPre.File), nil
}
precreateResp = newPre
// 覆盖掉旧的进度
base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5)
continue uploadLoop
}
return nil, err
}
Expand All @@ -343,9 +362,46 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
// 修复时间,具体原因见 Put 方法注释的 **注意**
newFile.Ctime = ctime
newFile.Mtime = mtime
// 上传成功清理进度
base.SaveUploadProgress(d, nil, d.AccessToken, contentMd5)
return fileToObj(newFile), nil
}

// precreate 执行预上传操作,支持首次上传和 uploadid 过期重试
func (d *BaiduNetdisk) precreate(ctx context.Context, path string, streamSize int64, blockListStr, contentMd5, sliceMd5 string, ctime, mtime int64) (*PrecreateResp, error) {
params := map[string]string{"method": "precreate"}
form := map[string]string{
"path": path,
"size": strconv.FormatInt(streamSize, 10),
"isdir": "0",
"autoinit": "1",
"rtype": "3",
"block_list": blockListStr,
}

// 只有在首次上传时才包含 content-md5 和 slice-md5
if contentMd5 != "" && sliceMd5 != "" {
form["content-md5"] = contentMd5
form["slice-md5"] = sliceMd5
}

joinTime(form, ctime, mtime)

var precreateResp PrecreateResp
_, err := d.postForm("/xpan/file", params, form, &precreateResp)
if err != nil {
return nil, err
}

// 修复时间,具体原因见 Put 方法注释的 **注意**
if precreateResp.ReturnType == 2 {
precreateResp.File.Ctime = ctime
precreateResp.File.Mtime = mtime
}

return &precreateResp, nil
}

func (d *BaiduNetdisk) uploadSlice(ctx context.Context, params map[string]string, fileName string, file io.Reader) error {
res, err := base.RestyClient.R().
SetContext(ctx).
Expand All @@ -358,8 +414,16 @@ func (d *BaiduNetdisk) uploadSlice(ctx context.Context, params map[string]string
log.Debugln(res.RawResponse.Status + res.String())
errCode := utils.Json.Get(res.Body(), "error_code").ToInt()
errNo := utils.Json.Get(res.Body(), "errno").ToInt()
respStr := res.String()
lower := strings.ToLower(respStr)
// 合并 uploadid 过期检测逻辑
if strings.Contains(lower, "uploadid") &&
(strings.Contains(lower, "invalid") || strings.Contains(lower, "expired") || strings.Contains(lower, "not found")) {
return ErrUploadIDExpired
}

if errCode != 0 || errNo != 0 {
return errs.NewErr(errs.StreamIncomplete, "error in uploading to baidu, will retry. response=%s", res.String())
return errs.NewErr(errs.StreamIncomplete, "error uploading to baidu, response=%s", res.String())
}
return nil
}
Expand Down