Skip to content

Commit 2f6573c

Browse files
author
Julian Bright
committed
Add compute type argument
1 parent 6a3ec9c commit 2f6573c

File tree

3 files changed

+23
-5
lines changed

3 files changed

+23
-5
lines changed

sagemaker_studio_image_build/builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,12 @@ def delete_zip_file(bucket, key):
6464
s3.delete_object(Bucket=bucket, Key=key)
6565

6666

67-
def build_image(repository, role, bucket, extra_args, log=True):
67+
def build_image(repository, role, bucket, compute_type, extra_args, log=True):
6868
bucket, key = upload_zip_file(repository, bucket, " ".join(extra_args))
6969
try:
7070
from sagemaker_studio_image_build.codebuild import TempCodeBuildProject
7171

72-
with TempCodeBuildProject(f"{bucket}/{key}", role, repository=repository) as p:
72+
with TempCodeBuildProject(f"{bucket}/{key}", role, repository=repository, compute_type=compute_type) as p:
7373
p.build(log)
7474
finally:
7575
delete_zip_file(bucket, key)

sagemaker_studio_image_build/cli.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ def validate_args(args, extra_args):
2525
f"The value of the -f/file argument [{file_value}] is outside the working directory [{os.getcwd()}]"
2626
)
2727

28+
# Validate arg compute_type
29+
if args.compute_type:
30+
if not args.compute_type in ['BUILD_GENERAL1_SMALL', 'BUILD_GENERAL1_MEDIUM',
31+
'BUILD_GENERAL1_LARGE', 'BUILD_GENERAL1_LARGE', 'BUILD_GENERAL1_LARGE', 'BUILD_GENERAL1_2XLARGE']:
32+
raise ValueError(
33+
f'Error parsing reference: "{args.repository}" is not a valid repository/tag'
34+
)
35+
2836

2937
def get_role(args):
3038
if args.role:
@@ -50,7 +58,7 @@ def build_image(args, extra_args):
5058
validate_args(args, extra_args)
5159

5260
builder.build_image(
53-
args.repository, get_role(args), args.bucket, extra_args, log=not args.no_logs
61+
args.repository, get_role(args), args.bucket, args.compute_type, extra_args, log=not args.no_logs
5462
)
5563

5664

@@ -70,6 +78,15 @@ def main():
7078
"--repository",
7179
help="The ECR repository:tag for the image (default: sagemaker-studio-${domain_id}:latest)",
7280
)
81+
build_parser.add_argument(
82+
"--image",
83+
help="The ECR repository:tag for the image (default: sagemaker-studio-${domain_id}:latest)",
84+
)
85+
build_parser.add_argument(
86+
"--compute-type",
87+
help="The code build compute type (default: BUILD_GENERAL1_SMALL)",
88+
default="BUILD_GENERAL1_SMALL"
89+
)
7390
build_parser.add_argument(
7491
"--role",
7592
help=f"The IAM role name for CodeBuild to use (default: the Studio execution role).",

sagemaker_studio_image_build/codebuild.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111

1212

1313
class TempCodeBuildProject:
14-
def __init__(self, s3_location, role, repository=None):
14+
def __init__(self, s3_location, role, repository=None, compute_type=None):
1515
self.s3_location = s3_location
1616
self.role = role
1717

1818
self.session = boto3.session.Session()
1919
self.domain_id, self.user_profile_name = self._get_studio_metadata()
2020
self.repo_name = None
21+
self.compute_type = compute_type or 'BUILD_GENERAL1_SMALL'
2122

2223
if repository:
2324
self.repo_name, self.tag = repository.split(":", maxsplit=1)
@@ -62,7 +63,7 @@ def __enter__(self):
6263
"environment": {
6364
"type": "LINUX_CONTAINER",
6465
"image": "aws/codebuild/standard:4.0",
65-
"computeType": "BUILD_GENERAL1_SMALL",
66+
"computeType": self.compute_type,
6667
"environmentVariables": [
6768
{"name": "AWS_DEFAULT_REGION", "value": region},
6869
{"name": "AWS_ACCOUNT_ID", "value": account},

0 commit comments

Comments
 (0)