99 "mime"
1010 "mime/multipart"
1111 "net/http"
12- "os"
13- "path/filepath"
1412 "strconv"
1513
1614 "github.com/dstackai/dstack/runner/internal/api"
@@ -19,6 +17,9 @@ import (
1917 "github.com/dstackai/dstack/runner/internal/schemas"
2018)
2119
20+ // TODO: set some reasonable value; (optional) make configurable
21+ const maxBodySize = math .MaxInt64
22+
2223func (s * Server ) healthcheckGetHandler (w http.ResponseWriter , r * http.Request ) (interface {}, error ) {
2324 return & schemas.HealthcheckResponse {
2425 Service : "dstack-runner" ,
@@ -84,13 +85,16 @@ func (s *Server) uploadArchivePostHandler(w http.ResponseWriter, r *http.Request
8485 return nil , & api.Error {Status : http .StatusBadRequest , Msg : "missing boundary" }
8586 }
8687
87- r .Body = http .MaxBytesReader (w , r .Body , math . MaxInt64 )
88+ r .Body = http .MaxBytesReader (w , r .Body , maxBodySize )
8889 formReader := multipart .NewReader (r .Body , boundary )
8990 part , err := formReader .NextPart ()
90- if errors .Is (err , io .EOF ) {
91- return nil , & api.Error {Status : http .StatusBadRequest , Msg : "empty form" }
92- }
9391 if err != nil {
92+ if errors .Is (err , io .EOF ) {
93+ return nil , & api.Error {Status : http .StatusBadRequest , Msg : "empty form" }
94+ }
95+ if isMaxBytesError (err ) {
96+ return nil , & api.Error {Status : http .StatusRequestEntityTooLarge }
97+ }
9498 return nil , fmt .Errorf ("read multipart form: %w" , err )
9599 }
96100 defer func () { _ = part .Close () }()
@@ -106,8 +110,11 @@ func (s *Server) uploadArchivePostHandler(w http.ResponseWriter, r *http.Request
106110 if archiveId == "" {
107111 return nil , & api.Error {Status : http .StatusBadRequest , Msg : "missing file name" }
108112 }
109- if err := s .executor .AddFileArchive (archiveId , part ); err != nil {
110- return nil , fmt .Errorf ("add file archive: %w" , err )
113+ if err := s .executor .WriteFileArchive (archiveId , part ); err != nil {
114+ if isMaxBytesError (err ) {
115+ return nil , & api.Error {Status : http .StatusRequestEntityTooLarge }
116+ }
117+ return nil , fmt .Errorf ("write file archive: %w" , err )
111118 }
112119 if _ , err := formReader .NextPart (); ! errors .Is (err , io .EOF ) {
113120 return nil , & api.Error {Status : http .StatusBadRequest , Msg : "extra form field(s)" }
@@ -123,21 +130,17 @@ func (s *Server) uploadCodePostHandler(w http.ResponseWriter, r *http.Request) (
123130 return nil , & api.Error {Status : http .StatusConflict }
124131 }
125132
126- r .Body = http .MaxBytesReader (w , r .Body , math .MaxInt64 )
127- codePath := filepath .Join (s .tempDir , "code" ) // todo random name?
128- file , err := os .Create (codePath )
129- if err != nil {
130- return nil , fmt .Errorf ("create code file: %w" , err )
131- }
132- defer func () { _ = file .Close () }()
133- if _ , err = io .Copy (file , r .Body ); err != nil {
134- if err .Error () == "http: request body too large" {
133+ r .Body = http .MaxBytesReader (w , r .Body , maxBodySize )
134+
135+ if err := s .executor .WriteRepoBlob (r .Body ); err != nil {
136+ if isMaxBytesError (err ) {
135137 return nil , & api.Error {Status : http .StatusRequestEntityTooLarge }
136138 }
137139 return nil , fmt .Errorf ("copy request body: %w" , err )
138140 }
139141
140- s .executor .SetCodePath (codePath )
142+ s .executor .SetRunnerState (executor .WaitRun )
143+
141144 return nil , nil
142145}
143146
@@ -181,3 +184,8 @@ func (s *Server) stopPostHandler(w http.ResponseWriter, r *http.Request) (interf
181184 s .stop ()
182185 return nil , nil
183186}
187+
188+ func isMaxBytesError (err error ) bool {
189+ var maxBytesError * http.MaxBytesError
190+ return errors .As (err , & maxBytesError )
191+ }
0 commit comments