@@ -8,13 +8,15 @@ import (
88 "io"
99 "math"
1010 "net/url"
11+ "os"
1112 stdpath "path"
1213 "strconv"
1314 "time"
1415
1516 "golang.org/x/sync/semaphore"
1617
1718 "github.com/alist-org/alist/v3/drivers/base"
19+ "github.com/alist-org/alist/v3/internal/conf"
1820 "github.com/alist-org/alist/v3/internal/driver"
1921 "github.com/alist-org/alist/v3/internal/errs"
2022 "github.com/alist-org/alist/v3/internal/model"
@@ -176,18 +178,28 @@ func (d *BaiduNetdisk) PutRapid(ctx context.Context, dstDir model.Obj, stream mo
176178//
177179// **注意**: 截至 2024/04/20 百度云盘 api 接口返回的时间永远是当前时间,而不是文件时间。
178180// 而实际上云盘存储的时间是文件时间,所以此处需要覆盖时间,保证缓存与云盘的数据一致
179- func (d * BaiduNetdisk ) Put (ctx context.Context , dstDir model.Obj , stream model.FileStreamer , up driver.UpdateProgress ) (model.Obj , error ) {
181+ func (d * BaiduNetdisk ) Put (ctx context.Context , dstDir model.Obj , file model.FileStreamer , up driver.UpdateProgress ) (model.Obj , error ) {
180182 // rapid upload
181- if newObj , err := d .PutRapid (ctx , dstDir , stream ); err == nil {
183+ if newObj , err := d .PutRapid (ctx , dstDir , file ); err == nil {
182184 return newObj , nil
183185 }
184186
185- tempFile , err := stream .CacheFullInTempFile ()
186- if err != nil {
187- return nil , err
187+ var readerAt = file .GetCache ()
188+ var (
189+ tmpF * os.File
190+ err error
191+ )
192+ writers := make ([]io.Writer , 0 , 4 )
193+ if _ , ok := readerAt .(io.ReaderAt ); ! ok {
194+ tmpF , err = os .CreateTemp (conf .Conf .TempDir , "file-*" )
195+ if err != nil {
196+ return nil , err
197+ }
198+ writers = append (writers , tmpF )
199+ readerAt = tmpF
188200 }
189201
190- streamSize := stream .GetSize ()
202+ streamSize := file .GetSize ()
191203 sliceSize := d .getSliceSize (streamSize )
192204 count := int (math .Max (math .Ceil (float64 (streamSize )/ float64 (sliceSize )), 1 ))
193205 lastBlockSize := streamSize % sliceSize
@@ -204,6 +216,8 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
204216 sliceMd5H := md5 .New ()
205217 sliceMd5H2 := md5 .New ()
206218 slicemd5H2Write := utils .LimitWriter (sliceMd5H2 , SliceSize )
219+ writers = append (writers , fileMd5H , sliceMd5H , slicemd5H2Write )
220+ written := int64 (0 )
207221
208222 for i := 1 ; i <= count ; i ++ {
209223 if utils .IsCanceled (ctx ) {
@@ -212,19 +226,32 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
212226 if i == count {
213227 byteSize = lastBlockSize
214228 }
215- _ , err := utils .CopyWithBufferN (io .MultiWriter (fileMd5H , sliceMd5H , slicemd5H2Write ), tempFile , byteSize )
229+ n , err := utils .CopyWithBufferN (io .MultiWriter (fileMd5H , sliceMd5H , slicemd5H2Write ), file , byteSize )
230+ written += n
216231 if err != nil && err != io .EOF {
217232 return nil , err
218233 }
219234 blockList = append (blockList , hex .EncodeToString (sliceMd5H .Sum (nil )))
220235 sliceMd5H .Reset ()
221236 }
237+ if tmpF != nil {
238+ if written != streamSize {
239+ _ = os .Remove (tmpF .Name ())
240+ return nil , errs .NewErr (err , "CreateTempFile failed, incoming stream actual size= %d, expect = %d " , written , streamSize )
241+ }
242+ _ , err = tmpF .Seek (0 , io .SeekStart )
243+ if err != nil {
244+ _ = os .Remove (tmpF .Name ())
245+ return nil , errs .NewErr (err , "CreateTempFile failed, can't seek to 0 " )
246+ }
247+ file .SetTmpFile (tmpF )
248+ }
222249 contentMd5 := hex .EncodeToString (fileMd5H .Sum (nil ))
223250 sliceMd5 := hex .EncodeToString (sliceMd5H2 .Sum (nil ))
224251 blockListStr , _ := utils .Json .MarshalToString (blockList )
225- path := stdpath .Join (dstDir .GetPath (), stream .GetName ())
226- mtime := stream .ModTime ().Unix ()
227- ctime := stream .CreateTime ().Unix ()
252+ path := stdpath .Join (dstDir .GetPath (), file .GetName ())
253+ mtime := file .ModTime ().Unix ()
254+ ctime := file .CreateTime ().Unix ()
228255
229256 // step.1 预上传
230257 // 尝试获取之前的进度
@@ -284,8 +311,8 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
284311 "uploadid" : precreateResp .Uploadid ,
285312 "partseq" : strconv .Itoa (partseq ),
286313 }
287- err := d .uploadSlice (ctx , params , stream .GetName (),
288- driver .NewLimitedUploadStream (ctx , io .NewSectionReader (tempFile , offset , byteSize )))
314+ err := d .uploadSlice (ctx , params , file .GetName (),
315+ driver .NewLimitedUploadStream (ctx , io .NewSectionReader (readerAt , offset , byteSize )))
289316 if err != nil {
290317 return err
291318 }
0 commit comments