@@ -2175,48 +2175,46 @@ class TestAcceptEulaModelAccessConfig(TestCase):
2175
2175
MOCK_PUBLIC_MODEL_ID = "mock_public_model_id"
2176
2176
MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL = [
2177
2177
{
2178
- ' ChannelName' : ' draft_model' ,
2179
- ' S3DataSource' : {
2180
- ' CompressionType' : ' None' ,
2181
- ' S3DataType' : ' S3Prefix' ,
2182
- ' S3Uri' : ' s3://jumpstart_bucket/path/to/public/resources/'
2178
+ " ChannelName" : " draft_model" ,
2179
+ " S3DataSource" : {
2180
+ " CompressionType" : " None" ,
2181
+ " S3DataType" : " S3Prefix" ,
2182
+ " S3Uri" : " s3://jumpstart_bucket/path/to/public/resources/" ,
2183
2183
},
2184
- ' HostingEulaKey' : None
2184
+ " HostingEulaKey" : None ,
2185
2185
}
2186
2186
]
2187
2187
MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL = [
2188
2188
{
2189
- ' ChannelName' : ' draft_model' ,
2190
- ' S3DataSource' : {
2191
- ' CompressionType' : ' None' ,
2192
- ' S3DataType' : ' S3Prefix' ,
2193
- ' S3Uri' : ' s3://jumpstart_bucket/path/to/public/resources/'
2194
- }
2189
+ " ChannelName" : " draft_model" ,
2190
+ " S3DataSource" : {
2191
+ " CompressionType" : " None" ,
2192
+ " S3DataType" : " S3Prefix" ,
2193
+ " S3Uri" : " s3://jumpstart_bucket/path/to/public/resources/" ,
2194
+ },
2195
2195
}
2196
2196
]
2197
2197
MOCK_GATED_MODEL_ID = "mock_gated_model_id"
2198
2198
MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL = [
2199
2199
{
2200
- ' ChannelName' : ' draft_model' ,
2201
- ' S3DataSource' : {
2202
- ' CompressionType' : ' None' ,
2203
- ' S3DataType' : ' S3Prefix' ,
2204
- ' S3Uri' : ' s3://jumpstart_bucket/path/to/gated/resources/'
2200
+ " ChannelName" : " draft_model" ,
2201
+ " S3DataSource" : {
2202
+ " CompressionType" : " None" ,
2203
+ " S3DataType" : " S3Prefix" ,
2204
+ " S3Uri" : " s3://jumpstart_bucket/path/to/gated/resources/" ,
2205
2205
},
2206
- ' HostingEulaKey' : "fmhMetadata/eula/llama3_2Eula.txt"
2206
+ " HostingEulaKey" : "fmhMetadata/eula/llama3_2Eula.txt" ,
2207
2207
}
2208
2208
]
2209
2209
MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL = [
2210
2210
{
2211
- 'ChannelName' : 'draft_model' ,
2212
- 'S3DataSource' : {
2213
- 'CompressionType' : 'None' ,
2214
- 'S3DataType' : 'S3Prefix' ,
2215
- 'S3Uri' : 's3://jumpstart_bucket/path/to/gated/resources/' ,
2216
- 'ModelAccessConfig' : {
2217
- "AcceptEula" : True
2218
- }
2219
- }
2211
+ "ChannelName" : "draft_model" ,
2212
+ "S3DataSource" : {
2213
+ "CompressionType" : "None" ,
2214
+ "S3DataType" : "S3Prefix" ,
2215
+ "S3Uri" : "s3://jumpstart_bucket/path/to/gated/resources/" ,
2216
+ "ModelAccessConfig" : {"AcceptEula" : True },
2217
+ },
2220
2218
}
2221
2219
]
2222
2220
@@ -2232,14 +2230,17 @@ def test_public_additional_model_data_source_should_pass_through(self):
2232
2230
)
2233
2231
2234
2232
# THEN
2235
- assert additional_model_data_sources == self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2233
+ assert (
2234
+ additional_model_data_sources
2235
+ == self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2236
+ )
2236
2237
2237
2238
def test_multiple_public_additional_model_data_source_should_pass_through_both (self ):
2238
2239
# WHERE / WHEN
2239
2240
additional_model_data_sources = utils ._add_model_access_configs_to_model_data_sources (
2240
2241
model_data_sources = (
2241
- self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL +
2242
2242
self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2243
+ + self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2243
2244
),
2244
2245
model_access_configs = None ,
2245
2246
model_id = self .MOCK_PUBLIC_MODEL_ID ,
@@ -2248,23 +2249,24 @@ def test_multiple_public_additional_model_data_source_should_pass_through_both(s
2248
2249
2249
2250
# THEN
2250
2251
assert additional_model_data_sources == (
2251
- self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL +
2252
2252
self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2253
+ + self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2253
2254
)
2254
2255
2255
2256
def test_public_additional_model_data_source_with_model_access_config_should_ignored_it (self ):
2256
2257
# WHERE / WHEN
2257
2258
additional_model_data_sources = utils ._add_model_access_configs_to_model_data_sources (
2258
2259
model_data_sources = self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL ,
2259
- model_access_configs = {
2260
- self .MOCK_GATED_MODEL_ID :ModelAccessConfig (accept_eula = True )
2261
- },
2260
+ model_access_configs = {self .MOCK_GATED_MODEL_ID : ModelAccessConfig (accept_eula = True )},
2262
2261
model_id = self .MOCK_GATED_MODEL_ID ,
2263
2262
region = JUMPSTART_DEFAULT_REGION_NAME ,
2264
2263
)
2265
2264
2266
2265
# THEN
2267
- assert additional_model_data_sources == self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2266
+ assert (
2267
+ additional_model_data_sources
2268
+ == self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2269
+ )
2268
2270
2269
2271
def test_no_additional_model_data_source_should_pass_through (self ):
2270
2272
# WHERE / WHEN
@@ -2284,62 +2286,65 @@ def test_gated_additional_model_data_source_should_accept_it(self):
2284
2286
# WHERE / WHEN
2285
2287
additional_model_data_sources = utils ._add_model_access_configs_to_model_data_sources (
2286
2288
model_data_sources = self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL ,
2287
- model_access_configs = {
2288
- self .MOCK_GATED_MODEL_ID :ModelAccessConfig (accept_eula = True )
2289
- },
2289
+ model_access_configs = {self .MOCK_GATED_MODEL_ID : ModelAccessConfig (accept_eula = True )},
2290
2290
model_id = self .MOCK_GATED_MODEL_ID ,
2291
2291
region = JUMPSTART_DEFAULT_REGION_NAME ,
2292
2292
)
2293
2293
2294
2294
# THEN
2295
- assert additional_model_data_sources == self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2295
+ assert (
2296
+ additional_model_data_sources
2297
+ == self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2298
+ )
2296
2299
2297
2300
def test_multiple_gated_additional_model_data_source_should_accept_both (self ):
2298
2301
# WHERE / WHEN
2299
2302
additional_model_data_sources = utils ._add_model_access_configs_to_model_data_sources (
2300
2303
model_data_sources = (
2301
- self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL +
2302
2304
self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2305
+ + self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2303
2306
),
2304
2307
model_access_configs = {
2305
- self .MOCK_GATED_MODEL_ID :ModelAccessConfig (accept_eula = True ),
2306
- self .MOCK_GATED_MODEL_ID :ModelAccessConfig (accept_eula = True )
2308
+ self .MOCK_GATED_MODEL_ID : ModelAccessConfig (accept_eula = True ),
2309
+ self .MOCK_GATED_MODEL_ID : ModelAccessConfig (accept_eula = True ),
2307
2310
},
2308
2311
model_id = self .MOCK_GATED_MODEL_ID ,
2309
2312
region = JUMPSTART_DEFAULT_REGION_NAME ,
2310
2313
)
2311
2314
2312
2315
# THEN
2313
2316
assert additional_model_data_sources == (
2314
- self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL +
2315
2317
self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2318
+ + self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2316
2319
)
2317
2320
2318
2321
# Mixed Positive Cases
2319
2322
2320
- def test_multiple_mixed_additional_model_data_source_should_pass_through_one_accept_the_other (self ):
2323
+ def test_multiple_mixed_additional_model_data_source_should_pass_through_one_accept_the_other (
2324
+ self ,
2325
+ ):
2321
2326
# WHERE / WHEN
2322
2327
additional_model_data_sources = utils ._add_model_access_configs_to_model_data_sources (
2323
2328
model_data_sources = (
2324
- self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL +
2325
- self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2329
+ self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2330
+ + self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2326
2331
),
2327
- model_access_configs = {
2328
- self .MOCK_GATED_MODEL_ID :ModelAccessConfig (accept_eula = True )
2329
- },
2332
+ model_access_configs = {self .MOCK_GATED_MODEL_ID : ModelAccessConfig (accept_eula = True )},
2330
2333
model_id = self .MOCK_GATED_MODEL_ID ,
2331
2334
region = JUMPSTART_DEFAULT_REGION_NAME ,
2332
2335
)
2333
2336
2334
2337
# THEN
2335
2338
assert additional_model_data_sources == (
2336
- self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL +
2337
- self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2339
+ self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2340
+ + self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2338
2341
)
2339
2342
2340
2343
# Test Gated Negative Tests
2341
2344
2342
- def test_gated_additional_model_data_source_no_model_access_config_should_raise_value_error (self ):
2345
+ def test_gated_additional_model_data_source_no_model_access_config_should_raise_value_error (
2346
+ self ,
2347
+ ):
2343
2348
# WHERE / WHEN / THEN
2344
2349
with self .assertRaises (ValueError ):
2345
2350
utils ._add_model_access_configs_to_model_data_sources (
@@ -2354,33 +2359,37 @@ def test_multiple_mixed_additional_no_model_data_source_should_raise_value_error
2354
2359
with self .assertRaises (ValueError ):
2355
2360
utils ._add_model_access_configs_to_model_data_sources (
2356
2361
model_data_sources = (
2357
- self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL +
2358
- self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2362
+ self .MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2363
+ + self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2359
2364
),
2360
2365
model_access_configs = None ,
2361
2366
model_id = self .MOCK_GATED_MODEL_ID ,
2362
2367
region = JUMPSTART_DEFAULT_REGION_NAME ,
2363
2368
)
2364
2369
2365
- def test_gated_additional_model_data_source_wrong_model_access_config_should_raise_value_error (self ):
2370
+ def test_gated_additional_model_data_source_wrong_model_access_config_should_raise_value_error (
2371
+ self ,
2372
+ ):
2366
2373
# WHERE / WHEN / THEN
2367
2374
with self .assertRaises (ValueError ):
2368
2375
utils ._add_model_access_configs_to_model_data_sources (
2369
2376
model_data_sources = self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL ,
2370
2377
model_access_configs = {
2371
- self .MOCK_PUBLIC_MODEL_ID :ModelAccessConfig (accept_eula = True )
2378
+ self .MOCK_PUBLIC_MODEL_ID : ModelAccessConfig (accept_eula = True )
2372
2379
},
2373
2380
model_id = self .MOCK_GATED_MODEL_ID ,
2374
2381
region = JUMPSTART_DEFAULT_REGION_NAME ,
2375
2382
)
2376
2383
2377
- def test_gated_additional_model_data_source_false_model_access_config_should_raise_value_error (self ):
2384
+ def test_gated_additional_model_data_source_false_model_access_config_should_raise_value_error (
2385
+ self ,
2386
+ ):
2378
2387
# WHERE / WHEN / THEN
2379
2388
with self .assertRaises (ValueError ):
2380
2389
utils ._add_model_access_configs_to_model_data_sources (
2381
2390
model_data_sources = self .MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL ,
2382
2391
model_access_configs = {
2383
- self .MOCK_GATED_MODEL_ID :ModelAccessConfig (accept_eula = False )
2392
+ self .MOCK_GATED_MODEL_ID : ModelAccessConfig (accept_eula = False )
2384
2393
},
2385
2394
model_id = self .MOCK_GATED_MODEL_ID ,
2386
2395
region = JUMPSTART_DEFAULT_REGION_NAME ,
0 commit comments