Skip to content

Commit a04689f

Browse files
author
Roja Reddy Sareddy
committed
numpy fixes
1 parent 330fdb1 commit a04689f

File tree

8 files changed

+242
-0
lines changed

8 files changed

+242
-0
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
artifact_path: model
2+
flavors:
3+
python_function:
4+
env:
5+
conda: conda.yaml
6+
virtualenv: python_env.yaml
7+
loader_module: mlflow.tensorflow
8+
python_version: 3.10.0
9+
tensorflow:
10+
saved_model_dir: tf2model
11+
model_type: tf2-module
12+
mlflow_version: 2.20.3
13+
model_uuid: test-uuid-numpy2
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
channels:
2+
- conda-forge
3+
dependencies:
4+
- python=3.10
5+
- pip
6+
- pip:
7+
- numpy>=2.0.0
8+
- tensorflow==2.19.0
9+
- scikit-learn
10+
- mlflow
11+
name: mlflow-env
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
tensorflow.keras
Binary file not shown.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
tf
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
python: 3.10.0
2+
build_dependencies:
3+
- pip
4+
dependencies:
5+
- -r requirements.txt
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
numpy>=2.0.0
2+
tensorflow==2.19.0
3+
scikit-learn
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Simple integration test for TensorFlow Serving builder with numpy 2.0 compatibility."""
14+
15+
from __future__ import absolute_import
16+
17+
import pytest
18+
import io
19+
import os
20+
import numpy as np
21+
import logging
22+
from tests.integ import DATA_DIR
23+
24+
from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
25+
from sagemaker.serve.builder.schema_builder import SchemaBuilder, CustomPayloadTranslator
26+
from sagemaker.serve.utils.types import ModelServer
27+
28+
logger = logging.getLogger(__name__)
29+
30+
31+
class TestTensorFlowServingNumpy2:
32+
"""Simple integration tests for TensorFlow Serving with numpy 2.0."""
33+
34+
def test_tensorflow_serving_validation_with_numpy2(self, sagemaker_session):
35+
"""Test TensorFlow Serving validation works with numpy 2.0."""
36+
logger.info(f"Testing TensorFlow Serving validation with numpy {np.__version__}")
37+
38+
# Create a simple schema builder with numpy 2.0 arrays
39+
input_data = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
40+
output_data = np.array([4.0], dtype=np.float32)
41+
42+
schema_builder = SchemaBuilder(
43+
sample_input=input_data,
44+
sample_output=output_data
45+
)
46+
47+
# Test without MLflow model - should raise validation error
48+
model_builder = ModelBuilder(
49+
mode=Mode.SAGEMAKER_ENDPOINT,
50+
model_server=ModelServer.TENSORFLOW_SERVING,
51+
schema_builder=schema_builder,
52+
sagemaker_session=sagemaker_session,
53+
)
54+
55+
with pytest.raises(ValueError, match="Tensorflow Serving is currently only supported for mlflow models"):
56+
model_builder._validate_for_tensorflow_serving()
57+
58+
logger.info("TensorFlow Serving validation test passed")
59+
60+
def test_tensorflow_serving_with_sample_mlflow_model(self, sagemaker_session):
61+
"""Test TensorFlow Serving builder initialization with sample MLflow model."""
62+
logger.info("Testing TensorFlow Serving with sample MLflow model")
63+
64+
# Use constant MLflow model structure from test data
65+
mlflow_model_dir = os.path.join(DATA_DIR, "serve_resources", "mlflow", "tensorflow_numpy2")
66+
67+
# Create schema builder with numpy 2.0 arrays
68+
input_data = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
69+
output_data = np.array([5.0], dtype=np.float32)
70+
71+
schema_builder = SchemaBuilder(
72+
sample_input=input_data,
73+
sample_output=output_data
74+
)
75+
76+
# Create ModelBuilder - this should not raise validation errors
77+
model_builder = ModelBuilder(
78+
mode=Mode.SAGEMAKER_ENDPOINT,
79+
model_server=ModelServer.TENSORFLOW_SERVING,
80+
schema_builder=schema_builder,
81+
sagemaker_session=sagemaker_session,
82+
model_metadata={"MLFLOW_MODEL_PATH": mlflow_model_dir},
83+
role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
84+
)
85+
86+
# Initialize MLflow handling to set _is_mlflow_model flag
87+
model_builder._handle_mlflow_input()
88+
89+
# Test validation passes
90+
model_builder._validate_for_tensorflow_serving()
91+
logger.info("TensorFlow Serving with sample MLflow model test passed")
92+
93+
def test_numpy2_custom_payload_translators(self):
94+
"""Test custom payload translators work with numpy 2.0."""
95+
logger.info(f"Testing custom payload translators with numpy {np.__version__}")
96+
97+
class Numpy2RequestTranslator(CustomPayloadTranslator):
98+
def serialize_payload_to_bytes(self, payload: object) -> bytes:
99+
buffer = io.BytesIO()
100+
np.save(buffer, payload, allow_pickle=False)
101+
return buffer.getvalue()
102+
103+
def deserialize_payload_from_stream(self, stream) -> object:
104+
return np.load(io.BytesIO(stream.read()), allow_pickle=False)
105+
106+
class Numpy2ResponseTranslator(CustomPayloadTranslator):
107+
def serialize_payload_to_bytes(self, payload: object) -> bytes:
108+
buffer = io.BytesIO()
109+
np.save(buffer, np.array(payload), allow_pickle=False)
110+
return buffer.getvalue()
111+
112+
def deserialize_payload_from_stream(self, stream) -> object:
113+
return np.load(io.BytesIO(stream.read()), allow_pickle=False)
114+
115+
# Test data
116+
test_input = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
117+
test_output = np.array([4.0], dtype=np.float32)
118+
119+
# Create translators
120+
request_translator = Numpy2RequestTranslator()
121+
response_translator = Numpy2ResponseTranslator()
122+
123+
# Test request translator
124+
serialized_input = request_translator.serialize_payload_to_bytes(test_input)
125+
assert isinstance(serialized_input, bytes)
126+
127+
deserialized_input = request_translator.deserialize_payload_from_stream(
128+
io.BytesIO(serialized_input)
129+
)
130+
np.testing.assert_array_equal(test_input, deserialized_input)
131+
132+
# Test response translator
133+
serialized_output = response_translator.serialize_payload_to_bytes(test_output)
134+
assert isinstance(serialized_output, bytes)
135+
136+
deserialized_output = response_translator.deserialize_payload_from_stream(
137+
io.BytesIO(serialized_output)
138+
)
139+
np.testing.assert_array_equal(test_output, deserialized_output)
140+
141+
logger.info("Custom payload translators test passed")
142+
143+
def test_numpy2_schema_builder_creation(self):
144+
"""Test SchemaBuilder creation with numpy 2.0 arrays."""
145+
logger.info(f"Testing SchemaBuilder with numpy {np.__version__}")
146+
147+
# Create test data with numpy 2.0
148+
input_data = np.array([[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=np.float32)
149+
output_data = np.array([10.0], dtype=np.float32)
150+
151+
# Create SchemaBuilder
152+
schema_builder = SchemaBuilder(
153+
sample_input=input_data,
154+
sample_output=output_data
155+
)
156+
157+
# Verify schema builder properties
158+
assert schema_builder.sample_input is not None
159+
assert schema_builder.sample_output is not None
160+
161+
# Test with custom translators
162+
class TestTranslator(CustomPayloadTranslator):
163+
def serialize_payload_to_bytes(self, payload: object) -> bytes:
164+
buffer = io.BytesIO()
165+
np.save(buffer, payload, allow_pickle=False)
166+
return buffer.getvalue()
167+
168+
def deserialize_payload_from_stream(self, stream) -> object:
169+
return np.load(io.BytesIO(stream.read()), allow_pickle=False)
170+
171+
translator = TestTranslator()
172+
schema_builder_with_translator = SchemaBuilder(
173+
sample_input=input_data,
174+
sample_output=output_data,
175+
input_translator=translator,
176+
output_translator=translator
177+
)
178+
179+
assert schema_builder_with_translator.custom_input_translator is not None
180+
assert schema_builder_with_translator.custom_output_translator is not None
181+
182+
logger.info("SchemaBuilder creation test passed")
183+
184+
def test_numpy2_basic_operations(self):
185+
"""Test basic numpy 2.0 operations used in TensorFlow Serving."""
186+
logger.info(f"Testing basic numpy 2.0 operations. Version: {np.__version__}")
187+
188+
# Test array creation
189+
arr = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
190+
assert arr.dtype == np.float32
191+
assert arr.shape == (4,)
192+
193+
# Test array operations
194+
arr_2d = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
195+
assert arr_2d.shape == (2, 2)
196+
197+
# Test serialization without pickle (numpy 2.0 safe)
198+
buffer = io.BytesIO()
199+
np.save(buffer, arr_2d, allow_pickle=False)
200+
buffer.seek(0)
201+
loaded_arr = np.load(buffer, allow_pickle=False)
202+
203+
np.testing.assert_array_equal(arr_2d, loaded_arr)
204+
205+
# Test dtype preservation
206+
assert loaded_arr.dtype == np.float32
207+
208+
logger.info("Basic numpy 2.0 operations test passed")

0 commit comments

Comments
 (0)