@@ -1219,3 +1219,91 @@ async def test_generate(self):
12191219 response .prompt_length , 40960
12201220 ) # If not long enough, please add more files to prompt
12211221 self .assertGreater (response .logprobs .shape [0 ], 1000 )
1222+
1223+
class TestTinkerAPI(RayUnittestBaseAysnc):
    """Test the Tinker API integration with the vLLM engine."""

    def setUp(self):
        """Build an explore-mode config and spin up a single vLLM engine.

        Creates one rollout engine (tensor_parallel_size=1) with the OpenAI
        API enabled and wraps it in a history-tracking ModelWrapper.
        """
        self.config = get_template_config()
        self.config.mode = "explore"
        self.config.model.model_path = get_model_path()
        self.config.explorer.rollout_model.engine_type = "vllm"
        self.config.explorer.rollout_model.engine_num = 1
        self.config.explorer.rollout_model.tensor_parallel_size = 1
        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
        self.config.explorer.rollout_model.enable_openai_api = True

        self.config.check_and_update()
        self.engines, self.auxiliary_engines = create_inference_models(self.config)
        self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)

    async def test_tinker_api(self):
        """Exercise the Tinker-style `sample` endpoint of the vLLM engine.

        Covers three call patterns:
        1. sampling without prompt logprobs,
        2. sampling with per-token and top-k prompt logprobs,
        3. a compute-logprob-style call (max_tokens=1, prompt logprobs only).
        """
        from tinker import types
        from transformers import AutoTokenizer

        engine = self.engines[0]
        tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path)
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is your name?"},
        ]
        result_dict = tokenizer.apply_chat_template(
            messages,
            chat_template=CHAT_TEMPLATE,
            add_generation_prompt=False,
            padding=False,
            truncation=True,
            return_tensors="pt",
            add_special_tokens=False,
            return_assistant_tokens_mask=True,
            return_dict=True,
        )
        prompt = types.ModelInput.from_ints(
            result_dict["input_ids"][0].tolist(),
        )

        # sample api without prompt logprobs
        num_samples = 4
        response = await engine.sample.remote(
            prompt=prompt,
            num_samples=num_samples,
            sampling_params=types.SamplingParams(temperature=0.7),  # no limit on length
        )
        self.assertEqual(len(response.sequences), num_samples)
        for sequence in response.sequences:
            # Each generated token must carry exactly one logprob.
            self.assertEqual(len(sequence.tokens), len(sequence.logprobs))
            self.assertEqual(sequence.stop_reason, "stop")
        # Prompt logprobs were not requested, so both fields stay unset.
        self.assertIsNone(response.prompt_logprobs)
        self.assertIsNone(response.topk_prompt_logprobs)

        # sample api with prompt logprobs
        num_samples = 2
        topk_prompt_logprobs = 3
        response = await engine.sample.remote(
            prompt=prompt,
            num_samples=num_samples,
            sampling_params=types.SamplingParams(temperature=0.7, max_tokens=8),
            include_prompt_logprobs=True,
            topk_prompt_logprobs=topk_prompt_logprobs,
        )
        self.assertEqual(len(response.sequences), num_samples)
        for sequence in response.sequences:
            self.assertEqual(len(sequence.tokens), len(sequence.logprobs))
            # max_tokens=8 forces truncation, so the stop reason is "length".
            self.assertEqual(sequence.stop_reason, "length")
        # One prompt-logprob entry per prompt token; the first token has no
        # preceding context, hence its entry is None.
        self.assertEqual(len(response.prompt_logprobs), len(prompt.to_ints()))
        self.assertIsNone(response.prompt_logprobs[0])
        self.assertEqual(len(response.topk_prompt_logprobs), len(prompt.to_ints()))
        self.assertIsNone(response.topk_prompt_logprobs[0])
        for topk_logprobs in response.topk_prompt_logprobs[1:]:
            self.assertIsNotNone(topk_logprobs)
            self.assertEqual(len(topk_logprobs), topk_prompt_logprobs)

        # compute_logprob api (emulated via sample with max_tokens=1)
        response = await engine.sample.remote(
            prompt=prompt,
            num_samples=1,
            sampling_params=types.SamplingParams(max_tokens=1),
            include_prompt_logprobs=True,
        )
        self.assertEqual(len(response.sequences), 1)
        self.assertEqual(response.sequences[0].stop_reason, "length")
        self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
        # top-k prompt logprobs were not requested here.
        self.assertIsNone(response.topk_prompt_logprobs)