@@ -34,45 +34,41 @@ class TestTensorFlowServingNumpy2:
3434 def test_tensorflow_serving_validation_with_numpy2 (self , sagemaker_session ):
3535 """Test TensorFlow Serving validation works with numpy 2.0."""
3636 logger .info (f"Testing TensorFlow Serving validation with numpy { np .__version__ } " )
37-
37+
3838 # Create a simple schema builder with numpy 2.0 arrays
3939 input_data = np .array ([[1.0 , 2.0 , 3.0 ]], dtype = np .float32 )
4040 output_data = np .array ([4.0 ], dtype = np .float32 )
41-
42- schema_builder = SchemaBuilder (
43- sample_input = input_data ,
44- sample_output = output_data
45- )
46-
41+
42+ schema_builder = SchemaBuilder (sample_input = input_data , sample_output = output_data )
43+
4744 # Test without MLflow model - should raise validation error
4845 model_builder = ModelBuilder (
4946 mode = Mode .SAGEMAKER_ENDPOINT ,
5047 model_server = ModelServer .TENSORFLOW_SERVING ,
5148 schema_builder = schema_builder ,
5249 sagemaker_session = sagemaker_session ,
5350 )
54-
55- with pytest .raises (ValueError , match = "Tensorflow Serving is currently only supported for mlflow models" ):
51+
52+ with pytest .raises (
53+ ValueError , match = "Tensorflow Serving is currently only supported for mlflow models"
54+ ):
5655 model_builder ._validate_for_tensorflow_serving ()
57-
56+
5857 logger .info ("TensorFlow Serving validation test passed" )
5958
6059 def test_tensorflow_serving_with_sample_mlflow_model (self , sagemaker_session ):
6160 """Test TensorFlow Serving builder initialization with sample MLflow model."""
6261 logger .info ("Testing TensorFlow Serving with sample MLflow model" )
63-
62+
6463 # Use constant MLflow model structure from test data
6564 mlflow_model_dir = os .path .join (DATA_DIR , "serve_resources" , "mlflow" , "tensorflow_numpy2" )
66-
65+
6766 # Create schema builder with numpy 2.0 arrays
6867 input_data = np .array ([[1.0 , 2.0 , 3.0 , 4.0 ]], dtype = np .float32 )
6968 output_data = np .array ([5.0 ], dtype = np .float32 )
70-
71- schema_builder = SchemaBuilder (
72- sample_input = input_data ,
73- sample_output = output_data
74- )
75-
69+
70+ schema_builder = SchemaBuilder (sample_input = input_data , sample_output = output_data )
71+
7672 # Create ModelBuilder - this should not raise validation errors
7773 model_builder = ModelBuilder (
7874 mode = Mode .SAGEMAKER_ENDPOINT ,
@@ -82,18 +78,18 @@ def test_tensorflow_serving_with_sample_mlflow_model(self, sagemaker_session):
8278 model_metadata = {"MLFLOW_MODEL_PATH" : mlflow_model_dir },
8379 role_arn = "arn:aws:iam::123456789012:role/SageMakerRole" ,
8480 )
85-
81+
8682 # Initialize MLflow handling to set _is_mlflow_model flag
8783 model_builder ._handle_mlflow_input ()
88-
84+
8985 # Test validation passes
9086 model_builder ._validate_for_tensorflow_serving ()
9187 logger .info ("TensorFlow Serving with sample MLflow model test passed" )
9288
9389 def test_numpy2_custom_payload_translators (self ):
9490 """Test custom payload translators work with numpy 2.0."""
9591 logger .info (f"Testing custom payload translators with numpy { np .__version__ } " )
96-
92+
9793 class Numpy2RequestTranslator (CustomPayloadTranslator ):
9894 def serialize_payload_to_bytes (self , payload : object ) -> bytes :
9995 buffer = io .BytesIO ()
@@ -115,49 +111,46 @@ def deserialize_payload_from_stream(self, stream) -> object:
115111 # Test data
116112 test_input = np .array ([[1.0 , 2.0 , 3.0 ]], dtype = np .float32 )
117113 test_output = np .array ([4.0 ], dtype = np .float32 )
118-
114+
119115 # Create translators
120116 request_translator = Numpy2RequestTranslator ()
121117 response_translator = Numpy2ResponseTranslator ()
122-
118+
123119 # Test request translator
124120 serialized_input = request_translator .serialize_payload_to_bytes (test_input )
125121 assert isinstance (serialized_input , bytes )
126-
122+
127123 deserialized_input = request_translator .deserialize_payload_from_stream (
128124 io .BytesIO (serialized_input )
129125 )
130126 np .testing .assert_array_equal (test_input , deserialized_input )
131-
127+
132128 # Test response translator
133129 serialized_output = response_translator .serialize_payload_to_bytes (test_output )
134130 assert isinstance (serialized_output , bytes )
135-
131+
136132 deserialized_output = response_translator .deserialize_payload_from_stream (
137133 io .BytesIO (serialized_output )
138134 )
139135 np .testing .assert_array_equal (test_output , deserialized_output )
140-
136+
141137 logger .info ("Custom payload translators test passed" )
142138
143139 def test_numpy2_schema_builder_creation (self ):
144140 """Test SchemaBuilder creation with numpy 2.0 arrays."""
145141 logger .info (f"Testing SchemaBuilder with numpy { np .__version__ } " )
146-
142+
147143 # Create test data with numpy 2.0
148144 input_data = np .array ([[1.0 , 2.0 , 3.0 , 4.0 , 5.0 ]], dtype = np .float32 )
149145 output_data = np .array ([10.0 ], dtype = np .float32 )
150-
146+
151147 # Create SchemaBuilder
152- schema_builder = SchemaBuilder (
153- sample_input = input_data ,
154- sample_output = output_data
155- )
156-
148+ schema_builder = SchemaBuilder (sample_input = input_data , sample_output = output_data )
149+
157150 # Verify schema builder properties
158151 assert schema_builder .sample_input is not None
159152 assert schema_builder .sample_output is not None
160-
153+
161154 # Test with custom translators
162155 class TestTranslator (CustomPayloadTranslator ):
163156 def serialize_payload_to_bytes (self , payload : object ) -> bytes :
@@ -167,42 +160,42 @@ def serialize_payload_to_bytes(self, payload: object) -> bytes:
167160
168161 def deserialize_payload_from_stream (self , stream ) -> object :
169162 return np .load (io .BytesIO (stream .read ()), allow_pickle = False )
170-
163+
171164 translator = TestTranslator ()
172165 schema_builder_with_translator = SchemaBuilder (
173166 sample_input = input_data ,
174167 sample_output = output_data ,
175168 input_translator = translator ,
176- output_translator = translator
169+ output_translator = translator ,
177170 )
178-
171+
179172 assert schema_builder_with_translator .custom_input_translator is not None
180173 assert schema_builder_with_translator .custom_output_translator is not None
181-
174+
182175 logger .info ("SchemaBuilder creation test passed" )
183176
184177 def test_numpy2_basic_operations (self ):
185178 """Test basic numpy 2.0 operations used in TensorFlow Serving."""
186179 logger .info (f"Testing basic numpy 2.0 operations. Version: { np .__version__ } " )
187-
180+
188181 # Test array creation
189182 arr = np .array ([1.0 , 2.0 , 3.0 , 4.0 ], dtype = np .float32 )
190183 assert arr .dtype == np .float32
191184 assert arr .shape == (4 ,)
192-
185+
193186 # Test array operations
194187 arr_2d = np .array ([[1.0 , 2.0 ], [3.0 , 4.0 ]], dtype = np .float32 )
195188 assert arr_2d .shape == (2 , 2 )
196-
189+
197190 # Test serialization without pickle (numpy 2.0 safe)
198191 buffer = io .BytesIO ()
199192 np .save (buffer , arr_2d , allow_pickle = False )
200193 buffer .seek (0 )
201194 loaded_arr = np .load (buffer , allow_pickle = False )
202-
195+
203196 np .testing .assert_array_equal (arr_2d , loaded_arr )
204-
197+
205198 # Test dtype preservation
206199 assert loaded_arr .dtype == np .float32
207-
200+
208201 logger .info ("Basic numpy 2.0 operations test passed" )
0 commit comments