55import torch
66from transformers import AutoTokenizer
77
8- from trinity . common . config import load_config
8+ from tests . tools import RayUnittestBase , get_template_config
99from trinity .common .models import create_rollout_models
1010from trinity .common .models .model import ModelWrapper
1111from trinity .common .models .utils import (
1212 tokenize_and_mask_messages_default ,
1313 tokenize_and_mask_messages_hf ,
1414)
1515
16- config_dir = os .path .join (os .path .dirname (__file__ ), "tmp" , "template_config.yaml" )
17-
1816
1917def get_model_path () -> str :
2018 path = os .environ .get ("MODEL_PATH" )
@@ -101,7 +99,12 @@ def test_generate(self):
10199 ]
102100 results = self .model_wrapper .chat (messages )
103101 self .assertEqual (len (results ), self .config .explorer .repeat_times )
104- logprobs = self .model_wrapper .logprobs (results [0 ].tokens )
102+ for result in results :
103+ input_logprobs = result .logprobs [: result .prompt_length ]
104+ output_logprobs = result .logprobs [result .prompt_length :]
105+ self .assertTrue (torch .all (input_logprobs == 0 ))
106+ self .assertTrue (torch .any (output_logprobs != 0 ))
107+ logprobs = self .model_wrapper .logprobs (results [0 ].tokens .tolist ())
105108 self .assertEqual (logprobs .shape [0 ], results [0 ].tokens .shape [0 ])
106109 messages .append (
107110 {
@@ -126,10 +129,10 @@ def test_generate(self):
126129 self .assertTrue (torch .equal (result_dict ["input_ids" ][0 ], exp .tokens ))
127130
128131
129- class TestModelWrapperSync (BaseTestModelWrapper , unittest . TestCase ):
132+ class TestModelWrapperSync (BaseTestModelWrapper , RayUnittestBase ):
130133 def setUp (self ):
131134 ray .init (ignore_reinit_error = True )
132- self .config = load_config ( config_dir )
135+ self .config = get_template_config ( )
133136 self .config .model .model_path = get_model_path ()
134137 self .config .explorer .engine_type = "vllm"
135138 self .config .explorer .engine_num = 1
@@ -138,10 +141,18 @@ def setUp(self):
138141 self .model_wrapper = ModelWrapper (self .engines [0 ], model_type = "vllm" )
139142
140143
141- class TestModelWrapperAsync (BaseTestModelWrapper , unittest .TestCase ):
144+ class TestModelWrapperAsync (BaseTestModelWrapper , RayUnittestBase ):
145+ @classmethod
146+ def setUpClass (cls ):
147+ ray .init (ignore_reinit_error = True )
148+
149+ @classmethod
150+ def tearDownClass (cls ):
151+ ray .shutdown ()
152+
142153 def setUp (self ):
143154 ray .init (ignore_reinit_error = True )
144- self .config = load_config ( config_dir )
155+ self .config = get_template_config ( )
145156 self .config .model .model_path = get_model_path ()
146157 self .config .explorer .engine_type = "vllm_async"
147158 self .config .explorer .engine_num = 1
@@ -151,6 +162,14 @@ def setUp(self):
151162
152163
153164class TestTokenizer (unittest .TestCase ):
165+ @classmethod
166+ def setUpClass (cls ):
167+ ray .init (ignore_reinit_error = True )
168+
169+ @classmethod
170+ def tearDownClass (cls ):
171+ ray .shutdown ()
172+
154173 def test_assistant_token_mask (self ):
155174 messages = [
156175 {"role" : "system" , "content" : "You are a helpful assistant." },
0 commit comments