Skip to content

Commit 5dedf88

Browse files
author
Roja Reddy Sareddy
committed
numpy fixes
1 parent a04689f commit 5dedf88

File tree

1 file changed

+38
-45
lines changed

1 file changed

+38
-45
lines changed

tests/integ/sagemaker/serve/test_tensorflow_serving_numpy2.py

Lines changed: 38 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -34,45 +34,41 @@ class TestTensorFlowServingNumpy2:
3434
def test_tensorflow_serving_validation_with_numpy2(self, sagemaker_session):
3535
"""Test TensorFlow Serving validation works with numpy 2.0."""
3636
logger.info(f"Testing TensorFlow Serving validation with numpy {np.__version__}")
37-
37+
3838
# Create a simple schema builder with numpy 2.0 arrays
3939
input_data = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
4040
output_data = np.array([4.0], dtype=np.float32)
41-
42-
schema_builder = SchemaBuilder(
43-
sample_input=input_data,
44-
sample_output=output_data
45-
)
46-
41+
42+
schema_builder = SchemaBuilder(sample_input=input_data, sample_output=output_data)
43+
4744
# Test without MLflow model - should raise validation error
4845
model_builder = ModelBuilder(
4946
mode=Mode.SAGEMAKER_ENDPOINT,
5047
model_server=ModelServer.TENSORFLOW_SERVING,
5148
schema_builder=schema_builder,
5249
sagemaker_session=sagemaker_session,
5350
)
54-
55-
with pytest.raises(ValueError, match="Tensorflow Serving is currently only supported for mlflow models"):
51+
52+
with pytest.raises(
53+
ValueError, match="Tensorflow Serving is currently only supported for mlflow models"
54+
):
5655
model_builder._validate_for_tensorflow_serving()
57-
56+
5857
logger.info("TensorFlow Serving validation test passed")
5958

6059
def test_tensorflow_serving_with_sample_mlflow_model(self, sagemaker_session):
6160
"""Test TensorFlow Serving builder initialization with sample MLflow model."""
6261
logger.info("Testing TensorFlow Serving with sample MLflow model")
63-
62+
6463
# Use constant MLflow model structure from test data
6564
mlflow_model_dir = os.path.join(DATA_DIR, "serve_resources", "mlflow", "tensorflow_numpy2")
66-
65+
6766
# Create schema builder with numpy 2.0 arrays
6867
input_data = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
6968
output_data = np.array([5.0], dtype=np.float32)
70-
71-
schema_builder = SchemaBuilder(
72-
sample_input=input_data,
73-
sample_output=output_data
74-
)
75-
69+
70+
schema_builder = SchemaBuilder(sample_input=input_data, sample_output=output_data)
71+
7672
# Create ModelBuilder - this should not raise validation errors
7773
model_builder = ModelBuilder(
7874
mode=Mode.SAGEMAKER_ENDPOINT,
@@ -82,18 +78,18 @@ def test_tensorflow_serving_with_sample_mlflow_model(self, sagemaker_session):
8278
model_metadata={"MLFLOW_MODEL_PATH": mlflow_model_dir},
8379
role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
8480
)
85-
81+
8682
# Initialize MLflow handling to set _is_mlflow_model flag
8783
model_builder._handle_mlflow_input()
88-
84+
8985
# Test validation passes
9086
model_builder._validate_for_tensorflow_serving()
9187
logger.info("TensorFlow Serving with sample MLflow model test passed")
9288

9389
def test_numpy2_custom_payload_translators(self):
9490
"""Test custom payload translators work with numpy 2.0."""
9591
logger.info(f"Testing custom payload translators with numpy {np.__version__}")
96-
92+
9793
class Numpy2RequestTranslator(CustomPayloadTranslator):
9894
def serialize_payload_to_bytes(self, payload: object) -> bytes:
9995
buffer = io.BytesIO()
@@ -115,49 +111,46 @@ def deserialize_payload_from_stream(self, stream) -> object:
115111
# Test data
116112
test_input = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
117113
test_output = np.array([4.0], dtype=np.float32)
118-
114+
119115
# Create translators
120116
request_translator = Numpy2RequestTranslator()
121117
response_translator = Numpy2ResponseTranslator()
122-
118+
123119
# Test request translator
124120
serialized_input = request_translator.serialize_payload_to_bytes(test_input)
125121
assert isinstance(serialized_input, bytes)
126-
122+
127123
deserialized_input = request_translator.deserialize_payload_from_stream(
128124
io.BytesIO(serialized_input)
129125
)
130126
np.testing.assert_array_equal(test_input, deserialized_input)
131-
127+
132128
# Test response translator
133129
serialized_output = response_translator.serialize_payload_to_bytes(test_output)
134130
assert isinstance(serialized_output, bytes)
135-
131+
136132
deserialized_output = response_translator.deserialize_payload_from_stream(
137133
io.BytesIO(serialized_output)
138134
)
139135
np.testing.assert_array_equal(test_output, deserialized_output)
140-
136+
141137
logger.info("Custom payload translators test passed")
142138

143139
def test_numpy2_schema_builder_creation(self):
144140
"""Test SchemaBuilder creation with numpy 2.0 arrays."""
145141
logger.info(f"Testing SchemaBuilder with numpy {np.__version__}")
146-
142+
147143
# Create test data with numpy 2.0
148144
input_data = np.array([[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=np.float32)
149145
output_data = np.array([10.0], dtype=np.float32)
150-
146+
151147
# Create SchemaBuilder
152-
schema_builder = SchemaBuilder(
153-
sample_input=input_data,
154-
sample_output=output_data
155-
)
156-
148+
schema_builder = SchemaBuilder(sample_input=input_data, sample_output=output_data)
149+
157150
# Verify schema builder properties
158151
assert schema_builder.sample_input is not None
159152
assert schema_builder.sample_output is not None
160-
153+
161154
# Test with custom translators
162155
class TestTranslator(CustomPayloadTranslator):
163156
def serialize_payload_to_bytes(self, payload: object) -> bytes:
@@ -167,42 +160,42 @@ def serialize_payload_to_bytes(self, payload: object) -> bytes:
167160

168161
def deserialize_payload_from_stream(self, stream) -> object:
169162
return np.load(io.BytesIO(stream.read()), allow_pickle=False)
170-
163+
171164
translator = TestTranslator()
172165
schema_builder_with_translator = SchemaBuilder(
173166
sample_input=input_data,
174167
sample_output=output_data,
175168
input_translator=translator,
176-
output_translator=translator
169+
output_translator=translator,
177170
)
178-
171+
179172
assert schema_builder_with_translator.custom_input_translator is not None
180173
assert schema_builder_with_translator.custom_output_translator is not None
181-
174+
182175
logger.info("SchemaBuilder creation test passed")
183176

184177
def test_numpy2_basic_operations(self):
185178
"""Test basic numpy 2.0 operations used in TensorFlow Serving."""
186179
logger.info(f"Testing basic numpy 2.0 operations. Version: {np.__version__}")
187-
180+
188181
# Test array creation
189182
arr = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
190183
assert arr.dtype == np.float32
191184
assert arr.shape == (4,)
192-
185+
193186
# Test array operations
194187
arr_2d = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
195188
assert arr_2d.shape == (2, 2)
196-
189+
197190
# Test serialization without pickle (numpy 2.0 safe)
198191
buffer = io.BytesIO()
199192
np.save(buffer, arr_2d, allow_pickle=False)
200193
buffer.seek(0)
201194
loaded_arr = np.load(buffer, allow_pickle=False)
202-
195+
203196
np.testing.assert_array_equal(arr_2d, loaded_arr)
204-
197+
205198
# Test dtype preservation
206199
assert loaded_arr.dtype == np.float32
207-
200+
208201
logger.info("Basic numpy 2.0 operations test passed")

0 commit comments

Comments
 (0)