Commit 53d4c87
feat(datasets): allow multipart uploads for large datasets (#384)
This falls back to a multipart upload strategy with presigned URLs when a dataset file is larger than 500MB (the overall flow is sketched below).
1 parent c0e0625 commit 53d4c87
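The new code path follows the standard three-step S3 multipart upload (create, upload parts, complete), driven entirely by presigned URLs. A minimal sketch of that flow, assuming a hypothetical `presign(method, **params)` helper that wraps the `/s3/preSignedUrls` endpoint and returns the parsed response; the names here are illustrative, not the SDK's API:

```python
# Illustrative sketch only -- not the SDK's API. `presign(method, **params)`
# is a hypothetical helper that asks the presigned-URL endpoint for the
# requested S3 operation and returns the parsed response.
import requests

PART_SIZE = 15 * 1000 * 1000  # S3 parts must be at least ~5MB; 15MB here


def multipart_put(path, key, presign):
    # intended for the large-file branch only (non-empty files > 500MB)
    with requests.Session() as session, open(path, 'rb') as f:
        # 1. create the multipart upload and remember its UploadId
        upload_id = presign('createMultipartUpload', Key=key)['UploadId']

        # 2. PUT each chunk to its own presigned URL, recording the ETags
        parts, part_number = [], 1
        while True:
            chunk = f.read(PART_SIZE)
            if not chunk:
                break
            url = presign('uploadPart', Key=key, UploadId=upload_id,
                          PartNumber=part_number)['url']
            r = session.put(url, data=chunk)
            parts.append({'ETag': r.headers['ETag'].strip('"'),
                          'PartNumber': part_number})
            part_number += 1

        # 3. complete the upload by sending back the ordered part list
        presign('completeMultipartUpload', Key=key, UploadId=upload_id,
                MultipartUpload={'Parts': parts})
```

The committed code differs in two ways: it precomputes the part count with floor division rather than reading until EOF, and it routes every presigned-URL request through the SDK's internal `http_client.API` rather than a standalone helper.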

File tree

1 file changed: +108 -10 lines


gradient/commands/datasets.py

Lines changed: 108 additions & 10 deletions
@@ -12,6 +12,9 @@
 import Queue as queue
 from xml.etree import ElementTree
 from urllib.parse import urlparse
+from ..api_sdk.clients import http_client
+from ..api_sdk.config import config
+from ..cli_constants import CLI_PS_CLIENT_NAME
 
 import halo
 import requests
@@ -557,24 +560,114 @@ def update_status():
 
 class PutDatasetFilesCommand(BaseDatasetFilesCommand):
 
-    @classmethod
-    def _put(cls, path, url, content_type):
+    # @classmethod
+    def _put(self, path, url, content_type, dataset_version_id=None, key=None):
         size = os.path.getsize(path)
         with requests.Session() as session:
             headers = {'Content-Type': content_type}
 
             try:
-                if size > 0:
+                if size <= 0:
+                    headers.update({'Content-Size': '0'})
+                    r = session.put(url, data='', headers=headers, timeout=5)
+                # for files under half a GB
+                elif size <= (10e8) / 2:
                     with open(path, 'rb') as f:
                         r = session.put(
                             url, data=f, headers=headers, timeout=5)
+                # for chonky files, use a multipart upload
                 else:
-                    headers.update({'Content-Size': '0'})
-                    r = session.put(url, data='', headers=headers, timeout=5)
-
-                cls.validate_s3_response(r)
+                    # Chunks need to be at least 5MB or AWS throws an
+                    # EntityTooSmall error; we'll arbitrarily choose a
+                    # 15MB chunk size
+                    #
+                    # Note also that AWS limits the max number of chunks
+                    # in a multipart upload to 10000, so this setting
+                    # currently enforces a hard limit of 150GB per file.
+                    #
+                    # We can dynamically assign a larger part size if needed,
+                    # but for the majority of use cases we should be fine
+                    # as-is
+                    part_minsize = int(15e6)
+                    dataset_id, _, version = dataset_version_id.partition(":")
+                    mpu_url = f'/datasets/{dataset_id}/versions/{version}/s3/preSignedUrls'
+
+                    api_client = http_client.API(
+                        api_url=config.CONFIG_HOST,
+                        api_key=self.api_key,
+                        ps_client_name=CLI_PS_CLIENT_NAME
+                    )
+
+                    mpu_create_res = api_client.post(
+                        url=mpu_url,
+                        json={
+                            'datasetId': dataset_id,
+                            'version': version,
+                            'calls': [{
+                                'method': 'createMultipartUpload',
+                                'params': {'Key': key}
+                            }]
+                        }
+                    )
+                    mpu_data = json.loads(mpu_create_res.text)[0]['url']
+
+                    parts = []
+                    with open(path, 'rb') as f:
+                        # we +2 the number of parts since we're doing floor
+                        # division, which will cut off any trailing part
+                        # less than the part_minsize, AND we want to 1-index
+                        # our range to match what AWS expects for part
+                        # numbers
+                        for part in range(1, (size // part_minsize) + 2):
+                            presigned_url_res = api_client.post(
+                                url=mpu_url,
+                                json={
+                                    'datasetId': dataset_id,
+                                    'version': version,
+                                    'calls': [{
+                                        'method': 'uploadPart',
+                                        'params': {
+                                            'Key': key,
+                                            'UploadId': mpu_data['UploadId'],
+                                            'PartNumber': part
+                                        }
+                                    }]
+                                }
+                            )
+
+                            presigned_url = json.loads(
+                                presigned_url_res.text
+                            )[0]['url']
+
+                            chunk = f.read(part_minsize)
+                            part_res = session.put(
+                                presigned_url,
+                                data=chunk,
+                                timeout=5)
+                            etag = part_res.headers['ETag'].replace('"', '')
+                            parts.append({'ETag': etag, 'PartNumber': part})
+
+                    r = api_client.post(
+                        url=mpu_url,
+                        json={
+                            'datasetId': dataset_id,
+                            'version': version,
+                            'calls': [{
+                                'method': 'completeMultipartUpload',
+                                'params': {
+                                    'Key': key,
+                                    'UploadId': mpu_data['UploadId'],
+                                    'MultipartUpload': {'Parts': parts}
+                                }
+                            }]
+                        }
+                    )
+
+                self.validate_s3_response(r)
             except requests.exceptions.ConnectionError as e:
-                return cls.report_connection_error(e)
+                return self.report_connection_error(e)
+            except Exception as e:
+                return e
 
     @staticmethod
     def _list_files(source_path):
@@ -599,8 +692,13 @@ def _sign_and_put(self, dataset_version_id, pool, results, update_status):
 
         for pre_signed, result in zip(pre_signeds, results):
            update_status()
-            pool.put(self._put, url=pre_signed.url,
-                     path=result['path'], content_type=result['mimetype'])
+            pool.put(
+                self._put,
+                url=pre_signed.url,
+                path=result['path'],
+                content_type=result['mimetype'],
+                dataset_version_id=dataset_version_id,
+                key=result['key'])
 
     def execute(self, dataset_version_id, source_paths, target_path):
         self.assert_supported(dataset_version_id)
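The in-diff comments note that a fixed 15MB part size, combined with S3's 10,000-part ceiling, caps a single file at roughly 150GB, and that a larger part size could be assigned dynamically if that ever becomes a problem. A minimal sketch of that idea; `choose_part_size` is hypothetical and not part of this commit:

```python
# Sketch of dynamic part sizing (not in this commit). With the default
# 15MB parts, 10,000 parts * 15MB ~= 150GB is the per-file hard cap.
import math

MIN_PART_SIZE = 5 * 1000 * 1000       # S3 rejects smaller parts (EntityTooSmall)
DEFAULT_PART_SIZE = 15 * 1000 * 1000  # the commit's fixed choice
MAX_PARTS = 10_000                    # S3's per-upload part limit


def choose_part_size(file_size):
    """Smallest part size, at least the default, that fits file_size in MAX_PARTS parts."""
    required = math.ceil(file_size / MAX_PARTS)
    return max(DEFAULT_PART_SIZE, required, MIN_PART_SIZE)
```

For example, a 200GB file would get 20MB parts, keeping the upload at exactly 10,000 parts instead of failing the 150GB cap.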
