@@ -2175,48 +2175,46 @@ class TestAcceptEulaModelAccessConfig(TestCase):
21752175 MOCK_PUBLIC_MODEL_ID = "mock_public_model_id"
21762176 MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL = [
21772177 {
2178- ' ChannelName' : ' draft_model' ,
2179- ' S3DataSource' : {
2180- ' CompressionType' : ' None' ,
2181- ' S3DataType' : ' S3Prefix' ,
2182- ' S3Uri' : ' s3://jumpstart_bucket/path/to/public/resources/'
2178+ " ChannelName" : " draft_model" ,
2179+ " S3DataSource" : {
2180+ " CompressionType" : " None" ,
2181+ " S3DataType" : " S3Prefix" ,
2182+ " S3Uri" : " s3://jumpstart_bucket/path/to/public/resources/" ,
21832183 },
2184- ' HostingEulaKey' : None
2184+ " HostingEulaKey" : None ,
21852185 }
21862186 ]
21872187 MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL = [
21882188 {
2189- ' ChannelName' : ' draft_model' ,
2190- ' S3DataSource' : {
2191- ' CompressionType' : ' None' ,
2192- ' S3DataType' : ' S3Prefix' ,
2193- ' S3Uri' : ' s3://jumpstart_bucket/path/to/public/resources/'
2194- }
2189+ " ChannelName" : " draft_model" ,
2190+ " S3DataSource" : {
2191+ " CompressionType" : " None" ,
2192+ " S3DataType" : " S3Prefix" ,
2193+ " S3Uri" : " s3://jumpstart_bucket/path/to/public/resources/" ,
2194+ },
21952195 }
21962196 ]
21972197 MOCK_GATED_MODEL_ID = "mock_gated_model_id"
21982198 MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL = [
21992199 {
2200- ' ChannelName' : ' draft_model' ,
2201- ' S3DataSource' : {
2202- ' CompressionType' : ' None' ,
2203- ' S3DataType' : ' S3Prefix' ,
2204- ' S3Uri' : ' s3://jumpstart_bucket/path/to/gated/resources/'
2200+ " ChannelName" : " draft_model" ,
2201+ " S3DataSource" : {
2202+ " CompressionType" : " None" ,
2203+ " S3DataType" : " S3Prefix" ,
2204+ " S3Uri" : " s3://jumpstart_bucket/path/to/gated/resources/" ,
22052205 },
2206- ' HostingEulaKey' : "fmhMetadata/eula/llama3_2Eula.txt"
2206+ " HostingEulaKey" : "fmhMetadata/eula/llama3_2Eula.txt" ,
22072207 }
22082208 ]
22092209 MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL = [
22102210 {
2211- 'ChannelName' : 'draft_model' ,
2212- 'S3DataSource' : {
2213- 'CompressionType' : 'None' ,
2214- 'S3DataType' : 'S3Prefix' ,
2215- 'S3Uri' : 's3://jumpstart_bucket/path/to/gated/resources/' ,
2216- 'ModelAccessConfig' : {
2217- "AcceptEula" : True
2218- }
2219- }
2211+ "ChannelName" : "draft_model" ,
2212+ "S3DataSource" : {
2213+ "CompressionType" : "None" ,
2214+ "S3DataType" : "S3Prefix" ,
2215+ "S3Uri" : "s3://jumpstart_bucket/path/to/gated/resources/" ,
2216+ "ModelAccessConfig" : {"AcceptEula" : True },
2217+ },
22202218 }
22212219 ]
22222220
@@ -2232,14 +2230,17 @@ def test_public_additional_model_data_source_should_pass_through(self):
22322230 )
22332231
22342232 # THEN
2235- assert additional_model_data_sources == self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2233+ assert (
2234+ additional_model_data_sources
2235+ == self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2236+ )
22362237
22372238 def test_multiple_public_additional_model_data_source_should_pass_through_both (self ):
22382239 # WHERE / WHEN
22392240 additional_model_data_sources = utils ._add_model_access_configs_to_model_data_sources (
22402241 model_data_sources = (
2241- self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL +
22422242 self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2243+ + self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
22432244 ),
22442245 model_access_configs = None ,
22452246 model_id = self .MOCK_PUBLIC_MODEL_ID ,
@@ -2248,23 +2249,24 @@ def test_multiple_public_additional_model_data_source_should_pass_through_both(s
22482249
22492250 # THEN
22502251 assert additional_model_data_sources == (
2251- self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL +
22522252 self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2253+ + self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
22532254 )
22542255
22552256 def test_public_additional_model_data_source_with_model_access_config_should_ignored_it (self ):
22562257 # WHERE / WHEN
22572258 additional_model_data_sources = utils ._add_model_access_configs_to_model_data_sources (
22582259 model_data_sources = self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL ,
2259- model_access_configs = {
2260- self .MOCK_GATED_MODEL_ID :ModelAccessConfig (accept_eula = True )
2261- },
2260+ model_access_configs = {self .MOCK_GATED_MODEL_ID : ModelAccessConfig (accept_eula = True )},
22622261 model_id = self .MOCK_GATED_MODEL_ID ,
22632262 region = JUMPSTART_DEFAULT_REGION_NAME ,
22642263 )
22652264
22662265 # THEN
2267- assert additional_model_data_sources == self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2266+ assert (
2267+ additional_model_data_sources
2268+ == self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2269+ )
22682270
22692271 def test_no_additional_model_data_source_should_pass_through (self ):
22702272 # WHERE / WHEN
@@ -2284,62 +2286,65 @@ def test_gated_additional_model_data_source_should_accept_it(self):
22842286 # WHERE / WHEN
22852287 additional_model_data_sources = utils ._add_model_access_configs_to_model_data_sources (
22862288 model_data_sources = self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL ,
2287- model_access_configs = {
2288- self .MOCK_GATED_MODEL_ID :ModelAccessConfig (accept_eula = True )
2289- },
2289+ model_access_configs = {self .MOCK_GATED_MODEL_ID : ModelAccessConfig (accept_eula = True )},
22902290 model_id = self .MOCK_GATED_MODEL_ID ,
22912291 region = JUMPSTART_DEFAULT_REGION_NAME ,
22922292 )
22932293
22942294 # THEN
2295- assert additional_model_data_sources == self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2295+ assert (
2296+ additional_model_data_sources
2297+ == self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2298+ )
22962299
22972300 def test_multiple_gated_additional_model_data_source_should_accept_both (self ):
22982301 # WHERE / WHEN
22992302 additional_model_data_sources = utils ._add_model_access_configs_to_model_data_sources (
23002303 model_data_sources = (
2301- self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL +
23022304 self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2305+ + self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
23032306 ),
23042307 model_access_configs = {
2305- self .MOCK_GATED_MODEL_ID :ModelAccessConfig (accept_eula = True ),
2306- self .MOCK_GATED_MODEL_ID :ModelAccessConfig (accept_eula = True )
2308+ self .MOCK_GATED_MODEL_ID : ModelAccessConfig (accept_eula = True ),
2309+ self .MOCK_GATED_MODEL_ID : ModelAccessConfig (accept_eula = True ),
23072310 },
23082311 model_id = self .MOCK_GATED_MODEL_ID ,
23092312 region = JUMPSTART_DEFAULT_REGION_NAME ,
23102313 )
23112314
23122315 # THEN
23132316 assert additional_model_data_sources == (
2314- self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL +
23152317 self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2318+ + self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
23162319 )
23172320
23182321 # Mixed Positive Cases
23192322
2320- def test_multiple_mixed_additional_model_data_source_should_pass_through_one_accept_the_other (self ):
2323+ def test_multiple_mixed_additional_model_data_source_should_pass_through_one_accept_the_other (
2324+ self ,
2325+ ):
23212326 # WHERE / WHEN
23222327 additional_model_data_sources = utils ._add_model_access_configs_to_model_data_sources (
23232328 model_data_sources = (
2324- self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL +
2325- self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2329+ self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2330+ + self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
23262331 ),
2327- model_access_configs = {
2328- self .MOCK_GATED_MODEL_ID :ModelAccessConfig (accept_eula = True )
2329- },
2332+ model_access_configs = {self .MOCK_GATED_MODEL_ID : ModelAccessConfig (accept_eula = True )},
23302333 model_id = self .MOCK_GATED_MODEL_ID ,
23312334 region = JUMPSTART_DEFAULT_REGION_NAME ,
23322335 )
23332336
23342337 # THEN
23352338 assert additional_model_data_sources == (
2336- self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL +
2337- self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2339+ self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2340+ + self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
23382341 )
23392342
23402343 # Test Gated Negative Tests
23412344
2342- def test_gated_additional_model_data_source_no_model_access_config_should_raise_value_error (self ):
2345+ def test_gated_additional_model_data_source_no_model_access_config_should_raise_value_error (
2346+ self ,
2347+ ):
23432348 # WHERE / WHEN / THEN
23442349 with self .assertRaises (ValueError ):
23452350 utils ._add_model_access_configs_to_model_data_sources (
@@ -2354,33 +2359,37 @@ def test_multiple_mixed_additional_no_model_data_source_should_raise_value_error
23542359 with self .assertRaises (ValueError ):
23552360 utils ._add_model_access_configs_to_model_data_sources (
23562361 model_data_sources = (
2357- self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL +
2358- self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2362+ self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2363+ + self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
23592364 ),
23602365 model_access_configs = None ,
23612366 model_id = self .MOCK_GATED_MODEL_ID ,
23622367 region = JUMPSTART_DEFAULT_REGION_NAME ,
23632368 )
23642369
2365- def test_gated_additional_model_data_source_wrong_model_access_config_should_raise_value_error (self ):
2370+ def test_gated_additional_model_data_source_wrong_model_access_config_should_raise_value_error (
2371+ self ,
2372+ ):
23662373 # WHERE / WHEN / THEN
23672374 with self .assertRaises (ValueError ):
23682375 utils ._add_model_access_configs_to_model_data_sources (
23692376 model_data_sources = self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL ,
23702377 model_access_configs = {
2371- self .MOCK_PUBLIC_MODEL_ID :ModelAccessConfig (accept_eula = True )
2378+ self .MOCK_PUBLIC_MODEL_ID : ModelAccessConfig (accept_eula = True )
23722379 },
23732380 model_id = self .MOCK_GATED_MODEL_ID ,
23742381 region = JUMPSTART_DEFAULT_REGION_NAME ,
23752382 )
23762383
2377- def test_gated_additional_model_data_source_false_model_access_config_should_raise_value_error (self ):
2384+ def test_gated_additional_model_data_source_false_model_access_config_should_raise_value_error (
2385+ self ,
2386+ ):
23782387 # WHERE / WHEN / THEN
23792388 with self .assertRaises (ValueError ):
23802389 utils ._add_model_access_configs_to_model_data_sources (
23812390 model_data_sources = self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL ,
23822391 model_access_configs = {
2383- self .MOCK_GATED_MODEL_ID :ModelAccessConfig (accept_eula = False )
2392+ self .MOCK_GATED_MODEL_ID : ModelAccessConfig (accept_eula = False )
23842393 },
23852394 model_id = self .MOCK_GATED_MODEL_ID ,
23862395 region = JUMPSTART_DEFAULT_REGION_NAME ,
0 commit comments