Skip to content

Commit 3370bb8

Browse files
committed
fix: integ tests
1 parent 02358a4 commit 3370bb8

File tree

3 files changed

+22
-14
lines changed

3 files changed

+22
-14
lines changed

tests/integ/sagemaker/jumpstart/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
4646
TRAINING_DATASET_MODEL_DICT = {
4747
("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"),
4848
("huggingface-spc-bert-base-cased", "1.2.3"): ("training-datasets/QNLI-tiny/"),
49+
("huggingface-spc-bert-base-cased", "2.0.3"): ("training-datasets/QNLI-tiny/"),
4950
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI-tiny/"),
5051
("js-trainable-model", "*"): ("training-datasets/QNLI-tiny/"),
5152
("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"),

tests/integ/sagemaker/jumpstart/retrieve_uri/inference.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ def package_artifacts(self):
7777

7878
self.model_name = self.get_model_name()
7979

80+
if self.script_uri is None:
81+
print("No script uri provided. Not performing prepack")
82+
return self.model_uri
83+
8084
cache_bucket_uri = f"s3://{get_test_artifact_bucket()}"
8185
repacked_model_uri = "/".join(
8286
[
@@ -147,16 +151,26 @@ def get_model_name(self) -> str:
147151
return f"{non_timestamped_name}{self.suffix}"
148152

149153
def create_model(self) -> None:
154+
primary_container = {
155+
"Image": self.image_uri,
156+
"Mode": "SingleModel",
157+
"Environment": self.environment_variables,
158+
}
159+
if self.repacked_model_uri.endswith(".tar.gz"):
160+
primary_container["ModelDataUrl"] = self.repacked_model_uri
161+
else:
162+
primary_container["ModelDataSource"] = {
163+
"S3DataSource": {
164+
"S3Uri": self.repacked_model_uri,
165+
"S3DataType": "S3Prefix",
166+
"CompressionType": "None",
167+
}
168+
}
150169
self.sagemaker_client.create_model(
151170
ModelName=self.model_name,
152171
EnableNetworkIsolation=True,
153172
ExecutionRoleArn=self.execution_role,
154-
PrimaryContainer={
155-
"Image": self.image_uri,
156-
"ModelDataUrl": self.repacked_model_uri,
157-
"Mode": "SingleModel",
158-
"Environment": self.environment_variables,
159-
},
173+
PrimaryContainer=primary_container,
160174
)
161175

162176
def create_endpoint_config(self) -> None:

tests/integ/sagemaker/jumpstart/retrieve_uri/test_inference.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,6 @@ def test_jumpstart_inference_retrieve_functions(setup):
4646
tolerate_vulnerable_model=True,
4747
)
4848

49-
script_uri = script_uris.retrieve(
50-
model_id=model_id,
51-
model_version=model_version,
52-
script_scope="inference",
53-
tolerate_vulnerable_model=True,
54-
)
55-
5649
model_uri = model_uris.retrieve(
5750
model_id=model_id,
5851
model_version=model_version,
@@ -68,7 +61,7 @@ def test_jumpstart_inference_retrieve_functions(setup):
6861

6962
inference_job = InferenceJobLauncher(
7063
image_uri=image_uri,
71-
script_uri=script_uri,
64+
script_uri=None,
7265
model_uri=model_uri,
7366
instance_type=instance_type,
7467
base_name="catboost",

0 commit comments

Comments
 (0)