Skip to content

Commit 26cd4dd

Browse files
committed
修改百度云盘所有情况都支持断点续传
1 parent 8bbdb27 commit 26cd4dd

File tree

1 file changed

+108
-53
lines changed

1 file changed

+108
-53
lines changed

drivers/baidu_netdisk/driver.go

Lines changed: 108 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@ import (
55
"crypto/md5"
66
"encoding/hex"
77
"errors"
8+
"fmt"
89
"io"
910
"net/url"
1011
"os"
1112
stdpath "path"
1213
"strconv"
14+
"strings"
1315
"time"
1416

1517
"github.com/OpenListTeam/OpenList/v4/drivers/base"
@@ -31,6 +33,8 @@ type BaiduNetdisk struct {
3133
vipType int // 会员类型,0普通用户(4G/4M)、1普通会员(10G/16M)、2超级会员(20G/32M)
3234
}
3335

36+
var ErrUploadIDExpired = errors.New("uploadid expired")
37+
3438
func (d *BaiduNetdisk) Config() driver.Config {
3539
return config
3640
}
@@ -168,18 +172,15 @@ func (d *BaiduNetdisk) PutRapid(ctx context.Context, dstDir model.Obj, stream mo
168172
if err != nil {
169173
return nil, err
170174
}
171-
// 修复时间,具体原因见 Put 方法注释的 **注意**
175+
// 修复时间
172176
newFile.Ctime = stream.CreateTime().Unix()
173177
newFile.Mtime = stream.ModTime().Unix()
174178
return fileToObj(newFile), nil
175179
}
176180

177-
// Put
178-
//
179-
// **注意**: 截至 2024/04/20 百度云盘 api 接口返回的时间永远是当前时间,而不是文件时间。
180-
// 而实际上云盘存储的时间是文件时间,所以此处需要覆盖时间,保证缓存与云盘的数据一致
181+
// Put 文件上传,支持断点续传和 uploadid 过期重试
181182
func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
182-
// rapid upload
183+
// 尝试秒传
183184
if newObj, err := d.PutRapid(ctx, dstDir, stream); err == nil {
184185
return newObj, nil
185186
}
@@ -212,9 +213,8 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
212213
lastBlockSize = sliceSize
213214
}
214215

215-
//cal md5 for first 256k data
216+
// 计算 md5
216217
const SliceSize int64 = 256 * utils.KB
217-
// cal md5
218218
blockList := make([]string, 0, count)
219219
byteSize := sliceSize
220220
fileMd5H := md5.New()
@@ -244,7 +244,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
244244
}
245245
if tmpF != nil {
246246
if written != streamSize {
247-
return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %d, expect = %d ", written, streamSize)
247+
return nil, errs.NewErr(err, "CreateTempFile failed, size mismatch: %d != %d ", written, streamSize)
248248
}
249249
_, err = tmpF.Seek(0, io.SeekStart)
250250
if err != nil {
@@ -258,13 +258,11 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
258258
mtime := stream.ModTime().Unix()
259259
ctime := stream.CreateTime().Unix()
260260

261-
// step.1 预上传
262-
// 尝试获取之前的进度
261+
// step.1 尝试读取已保存进度
263262
precreateResp, ok := base.GetUploadProgress[*PrecreateResp](d, d.AccessToken, contentMd5)
264263
if !ok {
265-
params := map[string]string{
266-
"method": "precreate",
267-
}
264+
// 没有进度,走预上传
265+
params := map[string]string{"method": "precreate"}
268266
form := map[string]string{
269267
"path": path,
270268
"size": strconv.FormatInt(streamSize, 10),
@@ -276,60 +274,108 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
276274
"slice-md5": sliceMd5,
277275
}
278276
joinTime(form, ctime, mtime)
279-
280-
log.Debugf("[baidu_netdisk] precreate data: %s", form)
281277
_, err = d.postForm("/xpan/file", params, form, &precreateResp)
282278
if err != nil {
283279
return nil, err
284280
}
285-
log.Debugf("%+v", precreateResp)
286281
if precreateResp.ReturnType == 2 {
287-
//rapid upload, since got md5 match from baidu server
288-
// 修复时间,具体原因见 Put 方法注释的 **注意**
289282
precreateResp.File.Ctime = ctime
290283
precreateResp.File.Mtime = mtime
291284
return fileToObj(precreateResp.File), nil
292285
}
293286
}
287+
294288
// step.2 上传分片
295-
threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,
296-
retry.Attempts(1),
297-
retry.Delay(time.Second),
298-
retry.DelayType(retry.BackOffDelay))
299-
300-
for i, partseq := range precreateResp.BlockList {
301-
if utils.IsCanceled(upCtx) {
302-
break
289+
uploadLoop:
290+
for attempt := 0; attempt < 2; attempt++ {
291+
threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,
292+
retry.Attempts(1),
293+
retry.Delay(time.Second),
294+
retry.DelayType(retry.BackOffDelay))
295+
296+
cacheReaderAt, okReaderAt := cache.(io.ReaderAt)
297+
if !okReaderAt {
298+
return nil, fmt.Errorf("cache does not implement io.ReaderAt")
303299
}
304300

305-
i, partseq, offset, byteSize := i, partseq, int64(partseq)*sliceSize, sliceSize
306-
if partseq+1 == count {
307-
byteSize = lastBlockSize
308-
}
309-
threadG.Go(func(ctx context.Context) error {
310-
params := map[string]string{
311-
"method": "upload",
312-
"access_token": d.AccessToken,
313-
"type": "tmpfile",
314-
"path": path,
315-
"uploadid": precreateResp.Uploadid,
316-
"partseq": strconv.Itoa(partseq),
301+
totalParts := len(precreateResp.BlockList)
302+
for i, partseq := range precreateResp.BlockList {
303+
if utils.IsCanceled(upCtx) || partseq < 0 {
304+
continue
317305
}
318-
err := d.uploadSlice(ctx, params, stream.GetName(),
319-
driver.NewLimitedUploadStream(ctx, io.NewSectionReader(cache, offset, byteSize)))
320-
if err != nil {
321-
return err
306+
i, partseq := i, partseq
307+
offset, size := int64(partseq)*sliceSize, sliceSize
308+
if partseq+1 == count {
309+
size = lastBlockSize
322310
}
323-
up(float64(threadG.Success()) * 100 / float64(len(precreateResp.BlockList)))
324-
precreateResp.BlockList[i] = -1
325-
return nil
326-
})
327-
}
328-
if err = threadG.Wait(); err != nil {
329-
// 如果属于用户主动取消,则保存上传进度
311+
threadG.Go(func(ctx context.Context) error {
312+
params := map[string]string{
313+
"method": "upload",
314+
"access_token": d.AccessToken,
315+
"type": "tmpfile",
316+
"path": path,
317+
"uploadid": precreateResp.Uploadid,
318+
"partseq": strconv.Itoa(partseq),
319+
}
320+
section := io.NewSectionReader(cacheReaderAt, offset, size)
321+
err := d.uploadSlice(ctx, params, stream.GetName(), driver.NewLimitedUploadStream(ctx, section))
322+
if err != nil {
323+
return err
324+
}
325+
precreateResp.BlockList[i] = -1
326+
// 进度
327+
done := 0
328+
for _, v := range precreateResp.BlockList {
329+
if v < 0 {
330+
done++
331+
}
332+
}
333+
if totalParts > 0 {
334+
up(float64(done) * 100.0 / float64(totalParts))
335+
}
336+
return nil
337+
})
338+
}
339+
340+
err = threadG.Wait()
341+
if err == nil {
342+
break uploadLoop
343+
}
344+
345+
// 保存进度(所有错误都会保存)
346+
precreateResp.BlockList = utils.SliceFilter(precreateResp.BlockList, func(s int) bool { return s >= 0 })
347+
base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5)
348+
330349
if errors.Is(err, context.Canceled) {
331-
precreateResp.BlockList = utils.SliceFilter(precreateResp.BlockList, func(s int) bool { return s >= 0 })
350+
return nil, err
351+
}
352+
if errors.Is(err, ErrUploadIDExpired) {
353+
log.Warn("[baidu_netdisk] uploadid expired, will restart from scratch")
354+
// 重新 precreate(所有分片都要重传)
355+
params := map[string]string{"method": "precreate"}
356+
form := map[string]string{
357+
"path": path,
358+
"size": strconv.FormatInt(streamSize, 10),
359+
"isdir": "0",
360+
"autoinit": "1",
361+
"rtype": "3",
362+
"block_list": blockListStr,
363+
}
364+
joinTime(form, ctime, mtime)
365+
var newPre PrecreateResp
366+
_, err2 := d.postForm("/xpan/file", params, form, &newPre)
367+
if err2 != nil {
368+
return nil, err2
369+
}
370+
if newPre.ReturnType == 2 {
371+
newPre.File.Ctime = ctime
372+
newPre.File.Mtime = mtime
373+
return fileToObj(newPre.File), nil
374+
}
375+
precreateResp = &newPre
376+
// 覆盖掉旧的进度
332377
base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5)
378+
continue uploadLoop
333379
}
334380
return nil, err
335381
}
@@ -340,9 +386,10 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
340386
if err != nil {
341387
return nil, err
342388
}
343-
// 修复时间,具体原因见 Put 方法注释的 **注意**
344389
newFile.Ctime = ctime
345390
newFile.Mtime = mtime
391+
// 上传成功清理进度
392+
base.SaveUploadProgress(d, nil, d.AccessToken, contentMd5)
346393
return fileToObj(newFile), nil
347394
}
348395

@@ -358,8 +405,16 @@ func (d *BaiduNetdisk) uploadSlice(ctx context.Context, params map[string]string
358405
log.Debugln(res.RawResponse.Status + res.String())
359406
errCode := utils.Json.Get(res.Body(), "error_code").ToInt()
360407
errNo := utils.Json.Get(res.Body(), "errno").ToInt()
408+
respStr := res.String()
409+
lower := strings.ToLower(respStr)
410+
if strings.Contains(lower, "uploadid") && (strings.Contains(lower, "invalid") || strings.Contains(lower, "expired") || strings.Contains(lower, "not found")) {
411+
return ErrUploadIDExpired
412+
}
361413
if errCode != 0 || errNo != 0 {
362-
return errs.NewErr(errs.StreamIncomplete, "error in uploading to baidu, will retry. response=%s", res.String())
414+
if strings.Contains(lower, "invalid uploadid") {
415+
return ErrUploadIDExpired
416+
}
417+
return errs.NewErr(errs.StreamIncomplete, "error uploading to baidu, response=%s", res.String())
363418
}
364419
return nil
365420
}

0 commit comments

Comments
 (0)