Skip to content

Commit d71fe71

Browse files
author
Pravali Uppugunduri
committed
Auto-capture requirements
1 parent d18e41b commit d71fe71

File tree

3 files changed

+265
-56
lines changed

3 files changed

+265
-56
lines changed

albatross_test.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import sys
2+
print(sys.path)
3+
sys.path.append("/home/upravali/telemetry/sagemaker-python-sdk/src/sagemaker")
4+
sys.path.append('/home/upravali/langchain/langchain-aws/libs/aws/')
5+
print("Updated sys.path: ", sys.path)
6+
7+
import json
8+
import os
9+
import time
10+
11+
from sagemaker.serve.builder.model_builder import ModelBuilder
12+
from sagemaker.serve.builder.schema_builder import SchemaBuilder
13+
from sagemaker.serve.spec.inference_spec import InferenceSpec
14+
import langchain_aws
15+
import langchain_core
16+
from langchain_aws import ChatBedrockConverse
17+
from langchain_core.prompts import ChatPromptTemplate
18+
from langchain_core.output_parsers import StrOutputParser
19+
20+
# Deployment settings used by deploy(): one entry per device flavour holding
# the inference container image and the instance type, plus a SERVICE entry
# holding the IAM role ARN passed to ModelBuilder.
INPUTS = {
    "CPU": {
        "INFERENCE_IMAGE": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.4.0-cpu-py311-ubuntu22.04-sagemaker",
        "INSTANCE_TYPE": "ml.m5.xlarge",
    },
    "GPU": {
        "INFERENCE_IMAGE": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.4.0-gpu-py311-cu124-ubuntu22.04-sagemaker",
        "INSTANCE_TYPE": "ml.g5.xlarge",
    },
    "SERVICE": {
        "ROLE": "arn:aws:iam::971812153697:role/upravali-test-role",
    },
}
33+
34+
def deploy(device):
    """Build the PoC model with ModelBuilder and deploy it to an endpoint.

    Args:
        device: Key into ``INPUTS`` ('CPU' or 'GPU') selecting the inference
            image and the instance type.

    Returns:
        tuple: ``(model, endpoint)`` where ``model`` is the result of
        ``ModelBuilder(...).build()`` and ``endpoint`` is the result of
        ``model.deploy(...)``.
    """

    class CustomerInferenceSpec(InferenceSpec):
        """Inference spec wrapping a prompt | Bedrock chat model | parser chain."""

        def load(self, model_dir):
            """Assemble and return the LangChain pipeline (``model_dir`` is unused)."""
            # Imports stay local to the method, as in the original; presumably so
            # they resolve where the spec is loaded for serving -- TODO confirm.
            from langchain_aws import ChatBedrockConverse
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_core.output_parsers import StrOutputParser

            prompt = ChatPromptTemplate.from_messages(
                [
                    (
                        "system",
                        "You are a verbose assistant that gives long-winded responses at least 500 words long for every comment/question.",
                    ),
                    ("human", "{input}"),
                ]
            )
            chat_model = ChatBedrockConverse(
                model='anthropic.claude-3-sonnet-20240229-v1:0',
                temperature=0,
                region_name='us-west-2',
            )
            return prompt | chat_model | StrOutputParser()

        def invoke(self, x, model):
            """Run the chain on ``x['input']``; stream when ``x['stream']`` is 'true'."""
            wants_stream = x['stream'].lower() == 'true'
            payload = {'input': x['input']}
            if wants_stream:
                return model.stream(payload)
            return model.invoke(payload)

    ##################################################################
    # can be service or customer who defines these
    ##################################################################
    model_name = f'model-{int(time.time())}'

    ##################################################################
    # service should define these
    ##################################################################
    image_uri = INPUTS[device]['INFERENCE_IMAGE']
    env_vars = {
        'TS_DISABLE_TOKEN_AUTHORIZATION': 'true'  # ABSOLUTELY NECESSARY
    }

    ##################################################################
    # customer should define these
    ##################################################################
    sample_input = json.dumps({
        'stream': 'true',
        'input': 'hello'
    })
    schema_builder = SchemaBuilder(sample_input, "<EOF>")
    # Won't be pickled correctly if Python version locally and DLC don't match
    inference_spec = CustomerInferenceSpec()
    dependencies = {
        "auto": True,
        # 'requirements' : './inference/code/requirements2.txt'
    }
    role_arn = INPUTS['SERVICE']['ROLE']

    built_model = ModelBuilder(
        name=model_name,
        image_uri=image_uri,
        env_vars=env_vars,
        schema_builder=schema_builder,
        inference_spec=inference_spec,
        dependencies=dependencies,
        role_arn=role_arn,
    ).build()

    deployed_endpoint = built_model.deploy(
        initial_instance_count=1,
        instance_type=INPUTS[device]['INSTANCE_TYPE'],
    )
    return (built_model, deployed_endpoint)
101+
102+
103+
###################################################################################################
104+
#
105+
#
106+
# PoC DEMO CODE ONLY
107+
#
108+
# Note: invoke vs invoke_stream matters
109+
###################################################################################################
110+
def invoke(endpoint, x):
    """Send ``x`` to ``endpoint.predict`` and return whatever it returns."""
    return endpoint.predict(x)
113+
114+
def invoke_stream(endpoint, x):
    """Call ``endpoint.predict_stream(x)``, print its ``str()`` form, and return it.

    The returned object is iterated by the caller to obtain the chunks.
    """
    stream = endpoint.predict_stream(x)
    print(str(stream))  # Generator
    return stream
118+
119+
def clean(model, endpoint):
    """Best-effort teardown: delete the endpoint first, then the model.

    Any exception raised by either deletion is printed and swallowed, so a
    failure to delete the endpoint does not prevent the model deletion attempt.
    """
    # Lambdas keep the attribute lookup inside the try block, so a missing
    # method is reported the same way as a failed deletion.
    teardown_steps = (
        lambda: endpoint.delete_endpoint(),
        lambda: model.delete_model(),
    )
    for step in teardown_steps:
        try:
            step()
        except Exception as e:
            print(e)
131+
132+
def main(device):
    """Deploy the PoC model, run an interactive prompt loop, then tear down.

    Each line read from stdin is expected to be a JSON object with a
    ``stream`` key ('true' selects the streaming path) and is forwarded
    unchanged to the endpoint. Typing ``exit`` ends the loop. Errors raised
    while handling a single line are printed and the loop continues.

    Args:
        device: Key into ``INPUTS`` ('CPU' or 'GPU') passed through to ``deploy``.
    """
    print("before deploying")
    model, endpoint = deploy(device)
    print("after deploying")

    # Tear down in ``finally`` so the endpoint and model are also deleted when
    # the loop is left by Ctrl-C or EOF on stdin rather than by typing 'exit'.
    try:
        while True:
            x = input(">>> ")
            if x == 'exit':
                break
            try:
                if json.loads(x)['stream'].lower() == 'true':
                    for chunk in invoke_stream(endpoint, x):
                        print(
                            str(chunk, encoding='utf-8'),
                            end="",
                            flush=True
                        )
                    print()
                else:
                    print(invoke(endpoint, x))
            except Exception as e:
                print(e)
    finally:
        clean(model, endpoint)
156+
157+
if __name__ == '__main__':
    # Set the default AWS region for this process before running the CPU demo.
    os.environ.update({'AWS_DEFAULT_REGION': 'us-west-2'})
    main('CPU')
Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,19 @@
1-
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2-
#
3-
# Licensed under the Apache License, Version 2.0 (the "License"). You
4-
# may not use this file except in compliance with the License. A copy of
5-
# the License is located at
6-
#
7-
# http://aws.amazon.com/apache2.0/
8-
#
9-
# or in the "license" file accompanying this file. This file is
10-
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11-
# ANY KIND, either express or implied. See the License for the specific
12-
# language governing permissions and limitations under the License.
13-
"""SageMaker model builder dependency managing module.
14-
15-
This must be kept independent of SageMaker PySDK
16-
"""
17-
18-
from __future__ import absolute_import
19-
20-
from pathlib import Path
211
import logging
222
import subprocess
233
import sys
244
import re
5+
from pathlib import Path
256

267
_SUPPORTED_SUFFIXES = [".txt"]
27-
# TODO : Move PKL_FILE_NAME to common location
288
PKL_FILE_NAME = "serve.pkl"
299

3010
logger = logging.getLogger(__name__)
3111

3212

3313
def capture_dependencies(dependencies: dict, work_dir: Path, capture_all: bool = False):
34-
"""Placeholder docstring"""
14+
"""Capture dependencies and print output."""
15+
print(f"Capturing dependencies: {dependencies}, work_dir: {work_dir}, capture_all: {capture_all}")
16+
3517
path = work_dir.joinpath("requirements.txt")
3618
if "auto" in dependencies and dependencies["auto"]:
3719
command = [
@@ -45,6 +27,8 @@ def capture_dependencies(dependencies: dict, work_dir: Path, capture_all: bool =
4527

4628
if capture_all:
4729
command.append("--capture_all")
30+
31+
print(f"Running subprocess with command: {command}")
4832

4933
subprocess.run(
5034
command,
@@ -55,62 +39,83 @@ def capture_dependencies(dependencies: dict, work_dir: Path, capture_all: bool =
5539
with open(path, "r") as f:
5640
autodetect_depedencies = f.read().splitlines()
5741
autodetect_depedencies.append("sagemaker[huggingface]>=2.199")
42+
print(f"Auto-detected dependencies: {autodetect_depedencies}")
5843
else:
5944
autodetect_depedencies = ["sagemaker[huggingface]>=2.199"]
45+
print(f"No auto-detection, using default dependencies: {autodetect_depedencies}")
6046

6147
module_version_dict = _parse_dependency_list(autodetect_depedencies)
48+
print(f"Parsed auto-detected dependencies: {module_version_dict}")
6249

6350
if "requirements" in dependencies:
6451
module_version_dict = _process_customer_provided_requirements(
6552
requirements_file=dependencies["requirements"], module_version_dict=module_version_dict
6653
)
54+
print(f"After processing customer-provided requirements: {module_version_dict}")
55+
6756
if "custom" in dependencies:
6857
module_version_dict = _process_custom_dependencies(
6958
custom_dependencies=dependencies.get("custom"), module_version_dict=module_version_dict
7059
)
60+
print(f"After processing custom dependencies: {module_version_dict}")
61+
7162
with open(path, "w") as f:
7263
for module, version in module_version_dict.items():
7364
f.write(f"{module}{version}\n")
65+
print(f"Final dependencies written to {path}")
7466

7567

7668
def _process_custom_dependencies(custom_dependencies: list, module_version_dict: dict):
    """Merge custom dependency specifiers into ``module_version_dict``.

    Args:
        custom_dependencies: Dependency strings to parse with
            ``_parse_dependency_list``.
        module_version_dict: Mapping to update; entries from
            ``custom_dependencies`` override existing ones with the same key.

    Returns:
        dict: The same ``module_version_dict`` object, updated in place.
    """
    custom_module_version_dict = _parse_dependency_list(custom_dependencies)
    # Diagnostics go through the module logger rather than print() so that
    # library callers control verbosity and stdout stays clean.
    logger.debug("Parsed custom dependencies: %s", custom_module_version_dict)

    module_version_dict.update(custom_module_version_dict)
    return module_version_dict
8179

8280

8381
def _process_customer_provided_requirements(requirements_file: str, module_version_dict: dict):
    """Merge the entries of a customer requirements file into ``module_version_dict``.

    Args:
        requirements_file: Path to the requirements file.
        module_version_dict: Mapping to update; entries read from the file
            override existing ones with the same key.

    Returns:
        dict: The same ``module_version_dict`` object, updated in place.

    Raises:
        Exception: If ``requirements_file`` is not an existing file or is
            rejected by ``_is_valid_requirement_file``.
    """
    requirements_file = Path(requirements_file)
    if not requirements_file.is_file() or not _is_valid_requirement_file(requirements_file):
        raise Exception(f"Path: {requirements_file} to requirements.txt doesn't exist")

    logger.debug("Packaging provided requirements.txt from %s", requirements_file)
    with open(requirements_file, "r") as f:
        custom_dependencies = f.read().splitlines()

    # Diagnostics go through the module logger rather than print() so that
    # library callers control verbosity and stdout stays clean.
    logger.debug("Customer-provided dependencies: %s", custom_dependencies)

    module_version_dict.update(_parse_dependency_list(custom_dependencies))
    return module_version_dict
9499

95100

96101
def _is_valid_requirement_file(path):
    """Return True when the file name of ``path`` ends with a supported suffix.

    Only the name is checked against ``_SUPPORTED_SUFFIXES``; the content of
    the file is not inspected.

    Args:
        path: Object exposing a ``name`` string attribute (e.g. ``pathlib.Path``).

    Returns:
        bool: True if the name ends with one of the supported suffixes,
        otherwise False.
    """
    # str.endswith accepts a tuple of candidates, replacing the manual loop.
    is_valid = path.name.endswith(tuple(_SUPPORTED_SUFFIXES))
    # Logged at debug level instead of printed so stdout stays clean.
    logger.debug("Requirements file %s has a supported suffix: %s", path, is_valid)
    return is_valid
103112

104113

105114
def _parse_dependency_list(depedency_list: list) -> dict:
106-
"""Placeholder docstring"""
107-
108-
# Divide a string into 2 part, first part is the module name
109-
# and second part is its version constraint or the url
110-
# checkout tests/unit/sagemaker/serve/detector/test_dependency_manager.py
111-
# for examples
115+
"""Parse the dependency list and print output."""
116+
print(f"Parsing dependency list: {depedency_list}")
117+
112118
pattern = r"^([\w.-]+)(@[^,\n]+|((?:[<>=!~]=?[\w.*-]+,?)+)?)$"
113-
114119
module_version_dict = {}
115120

116121
for dependency in depedency_list:
@@ -119,10 +124,10 @@ def _parse_dependency_list(depedency_list: list) -> dict:
119124
match = re.match(pattern, dependency)
120125
if match:
121126
package = match.group(1)
122-
# Group 2 is either a URL or version constraint, if present
123127
url_or_version = match.group(2) if match.group(2) else ""
124128
module_version_dict.update({package: url_or_version})
125129
else:
126130
module_version_dict.update({dependency: ""})
127-
131+
132+
print(f"Parsed module_version_dict: {module_version_dict}")
128133
return module_version_dict

0 commit comments

Comments
 (0)