@@ -86,7 +86,7 @@ def test_init_enable_network_isolation(sagemaker_session):
8686 num_components = 55 ,
8787 sagemaker_session = sagemaker_session ,
8888 enable_network_isolation = True ,
89- ** COMMON_ARGS
89+ ** COMMON_ARGS ,
9090 )
9191 assert pca .num_components == 55
9292 assert pca .enable_network_isolation () is True
@@ -99,7 +99,7 @@ def test_init_all_pca_hyperparameters(sagemaker_session):
9999 subtract_mean = True ,
100100 extra_components = 33 ,
101101 sagemaker_session = sagemaker_session ,
102- ** COMMON_ARGS
102+ ** COMMON_ARGS ,
103103 )
104104 assert pca .num_components == 55
105105 assert pca .algorithm_mode == "randomized"
@@ -112,7 +112,7 @@ def test_init_estimator_args(sagemaker_session):
112112 max_run = 1234 ,
113113 sagemaker_session = sagemaker_session ,
114114 data_location = "s3://some-bucket/some-key/" ,
115- ** COMMON_ARGS
115+ ** COMMON_ARGS ,
116116 )
117117 assert pca .instance_type == COMMON_ARGS ["instance_type" ]
118118 assert pca .instance_count == COMMON_ARGS ["instance_count" ]
@@ -133,7 +133,7 @@ def test_data_location_does_not_call_default_bucket(sagemaker_session):
133133 num_components = 2 ,
134134 sagemaker_session = sagemaker_session ,
135135 data_location = data_location ,
136- ** COMMON_ARGS
136+ ** COMMON_ARGS ,
137137 )
138138 assert pca .data_location == data_location
139139 assert not sagemaker_session .default_bucket .called
@@ -205,7 +205,7 @@ def test_fit_ndarray(time, sagemaker_session):
205205 num_components = 55 ,
206206 sagemaker_session = sagemaker_session ,
207207 data_location = "s3://{}/key-prefix/" .format (BUCKET_NAME ),
208- ** kwargs
208+ ** kwargs ,
209209 )
210210 train = [[1.0 , 2.0 , 3.0 ], [4.0 , 5.0 , 6.0 ], [7.0 , 8.0 , 8.0 ], [44.0 , 55.0 , 66.0 ]]
211211 labels = [99 , 85 , 87 , 2 ]
@@ -233,7 +233,7 @@ def test_fit_pass_experiment_config(sagemaker_session):
233233 num_components = 55 ,
234234 sagemaker_session = sagemaker_session ,
235235 data_location = "s3://{}/key-prefix/" .format (BUCKET_NAME ),
236- ** kwargs
236+ ** kwargs ,
237237 )
238238 train = [[1.0 , 2.0 , 3.0 ], [4.0 , 5.0 , 6.0 ], [7.0 , 8.0 , 8.0 ], [44.0 , 55.0 , 66.0 ]]
239239 labels = [99 , 85 , 87 , 2 ]
0 commit comments