Skip to content

Commit fe0d924

Browse files
committed
test is passing
1 parent 32ac7f3 commit fe0d924

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

tests/integration_tests/test_policy_update.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import logging
78
from typing import Dict, Tuple
89

910
import pytest
@@ -26,6 +27,10 @@
2627
)
2728

2829

30+
logger: logging.Logger = logging.getLogger(__name__)
31+
logger.setLevel(logging.INFO)
32+
33+
2934
def convert_state_dict(saved_sd):
3035
"""
3136
Convert transformers state dict to vLLM format.
@@ -119,8 +124,6 @@ def validate_loaded_tensors_equals_original(
119124
For tensor parallel cases, instead of gathering sharded tensors, we shard
120125
the original tensor and compare it with the loaded shard.
121126
"""
122-
validation_errors = []
123-
124127
for param_name, loaded_tensor in loaded_state_dict.items():
125128
if param_name in original_state_dict:
126129
expected_tensor = original_state_dict[param_name]
@@ -137,22 +140,25 @@ def validate_loaded_tensors_equals_original(
137140
else:
138141
tensor_to_compare = expected_tensor.cpu().float()
139142

143+
# Tensors emitted by the trainer and loaded here are of type bfloat16,
144+
# whereas the tensor loaded from the HF model (expected) has type float16.
140145
if not torch.allclose(
141146
loaded_tensor.float(),
142147
tensor_to_compare,
143-
rtol=1e-5,
144-
atol=1e-8,
148+
rtol=1e-2,
149+
atol=1e-3,
145150
):
146-
validation_errors.append(
151+
logger.warning(
152+
f"Loaded tensor {param_name} does not equal original. \ndtype = {loaded_tensor.dtype} vs {expected_tensor.dtype}\n"
153+
f"shape= {loaded_tensor.shape} vs {expected_tensor.shape}\n, values = {copy_of_loaded_tensor} vs {copy_of_expected_tensor}"
154+
)
155+
raise ValueError(
147156
f"Loaded tensor {param_name} does not equal original "
148157
f"(shapes: loaded={loaded_tensor.shape}, expected={tensor_to_compare.shape})"
149158
)
150159
else:
151160
print(f"Loaded tensor {param_name} correctly validated")
152161

153-
if validation_errors:
154-
raise ValueError(f"Validation failed: {validation_errors}")
155-
156162
print(
157163
f"Successfully validated that all {len(loaded_state_dict)} loaded tensors equal original"
158164
)
@@ -265,7 +271,7 @@ async def test_llama3_policy_update_single(setup_test):
265271

266272
# validating for single resource case.
267273
validate_loaded_tensors_equals_original(
268-
loaded_state_dict, expected_state_dict, tensor_parallel_size=0, rank=0
274+
loaded_state_dict, expected_state_dict, tensor_parallel_size=0, rank=0
269275
)
270276
print(
271277
"Single GPU test passed! Llama 3.1 8B-Instruct model successfully loaded into Policy via TorchStore!"

0 commit comments

Comments
 (0)