44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import logging
78from typing import Dict , Tuple
89
910import pytest
2627)
2728
2829
30+ logger : logging .Logger = logging .getLogger (__name__ )
31+ logger .setLevel (logging .INFO )
32+
33+
2934def convert_state_dict (saved_sd ):
3035 """
3136 Convert transformers state dict to vLLM format.
@@ -119,8 +124,6 @@ def validate_loaded_tensors_equals_original(
119124 For tensor parallel cases, instead of gathering sharded tensors, we shard
120125 the original tensor and compare it with the loaded shard.
121126 """
122- validation_errors = []
123-
124127 for param_name , loaded_tensor in loaded_state_dict .items ():
125128 if param_name in original_state_dict :
126129 expected_tensor = original_state_dict [param_name ]
@@ -137,22 +140,25 @@ def validate_loaded_tensors_equals_original(
137140 else :
138141 tensor_to_compare = expected_tensor .cpu ().float ()
139142
143+ # Tensors emitted by the trainer and then loaded are bfloat16, whereas the
144+ # HF-loaded (expected) tensors are float16, hence the looser rtol/atol below.
140145 if not torch .allclose (
141146 loaded_tensor .float (),
142147 tensor_to_compare ,
143- rtol = 1e-5 ,
144- atol = 1e-8 ,
148+ rtol = 1e-2 ,
149+ atol = 1e-3 ,
145150 ):
146- validation_errors .append (
151+ logger .warning (
152+ f"Loaded tensor { param_name } does not equal original. \n dtype = { loaded_tensor .dtype } vs { expected_tensor .dtype } \n "
153+ f"shape= { loaded_tensor .shape } vs { expected_tensor .shape } \n , values = { copy_of_loaded_tensor } vs { copy_of_expected_tensor } "
154+ )
155+ raise ValueError (
147156 f"Loaded tensor { param_name } does not equal original "
148157 f"(shapes: loaded={ loaded_tensor .shape } , expected={ tensor_to_compare .shape } )"
149158 )
150159 else :
151160 print (f"Loaded tensor { param_name } correctly validated" )
152161
153- if validation_errors :
154- raise ValueError (f"Validation failed: { validation_errors } " )
155-
156162 print (
157163 f"Successfully validated that all { len (loaded_state_dict )} loaded tensors equal original"
158164 )
@@ -265,7 +271,7 @@ async def test_llama3_policy_update_single(setup_test):
265271
266272 # validating for single resource case.
267273 validate_loaded_tensors_equals_original (
268- loaded_state_dict , expected_state_dict , tensor_parallel_size = 0 , rank = 0
274+ loaded_state_dict , expected_state_dict , tensor_parallel_size = 0 , rank = 0
269275 )
270276 print (
271277 "Single GPU test passed! Llama 3.1 8B-Instruct model successfully loaded into Policy via TorchStore!"
0 commit comments