@@ -265,6 +265,8 @@ def test_jumpstart_model_register(setup):
265265
266266 response = predictor .predict ("hello world!" )
267267
268+ predictor .delete_predictor ()
269+
268270 assert response is not None
269271
270272
@@ -291,3 +293,59 @@ def test_proprietary_jumpstart_model(setup):
291293 response = predictor .predict (payload )
292294
293295 assert response is not None
296+
297+
298+ @pytest .mark .skipif (
299+ True ,
300+ reason = "Only enable if test account is subscribed to the proprietary model" ,
301+ )
302+ def test_register_proprietary_jumpstart_model (setup ):
303+
304+ model_id = "ai21-jurassic-2-light"
305+
306+ model = JumpStartModel (
307+ model_id = model_id ,
308+ model_version = "2.0.004" ,
309+ role = get_sm_session ().get_caller_identity_arn (),
310+ sagemaker_session = get_sm_session (),
311+ )
312+ model_package = model .register ()
313+
314+ predictor = model_package .deploy (
315+ tags = [{"Key" : JUMPSTART_TAG , "Value" : os .environ [ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID ]}]
316+ )
317+ payload = {"prompt" : "To be, or" , "maxTokens" : 4 , "temperature" : 0 , "numResults" : 1 }
318+
319+ response = predictor .predict (payload )
320+
321+ predictor .delete_predictor ()
322+
323+ assert response is not None
324+
325+
326+ @pytest .mark .skipif (
327+ True ,
328+ reason = "Only enable if test account is subscribed to the proprietary model" ,
329+ )
330+ def test_register_gated_jumpstart_model (setup ):
331+
332+ model_id = "meta-textgenerationneuron-llama-2-7b"
333+ model = JumpStartModel (
334+ model_id = model_id ,
335+ model_version = "1.1.0" ,
336+ role = get_sm_session ().get_caller_identity_arn (),
337+ sagemaker_session = get_sm_session (),
338+ )
339+ model_package = model .register (accept_eula = True )
340+
341+ predictor = model_package .deploy (
342+ tags = [{"Key" : JUMPSTART_TAG , "Value" : os .environ [ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID ]}],
343+ accept_eula = True ,
344+ )
345+ payload = {"prompt" : "To be, or" , "maxTokens" : 4 , "temperature" : 0 , "numResults" : 1 }
346+
347+ response = predictor .predict (payload )
348+
349+ predictor .delete_predictor ()
350+
351+ assert response is not None
0 commit comments